diff --git a/configs/deit/README.md b/configs/deit/README.md new file mode 100644 index 000000000..b8cdc0531 --- /dev/null +++ b/configs/deit/README.md @@ -0,0 +1,86 @@ + +# DeiT +> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) + +## Introduction + +DeiT: Data-efficient Image Transformers + +## Results + +**Implementation and configs for training were taken and adjusted from [this repository](https://gitee.com/cvisionlab/models/tree/deit/release/research/cv/DeiT), which implements DeiT models in MindSpore.** + +Our reproduced model performance on ImageNet-1K is reported as follows. + +
+ +| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | +|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| +| deit_base | Converted from PyTorch | 81.62 | 95.58 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_base_patch16_224.ckpt) | +| deit_base | 8xRTX3090 | 72.29 | 89.93 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/deit_base_patch16_224_acc%3D0.725.ckpt) | +| deit_small | Converted from PyTorch | 79.39 | 94.80 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_small_patch16_224.ckpt) | +| deit_tiny | Converted from PyTorch | 71.58 | 90.76 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_tiny_patch16_224.ckpt) | + +
+ +#### Notes + +- Context: `Converted from PyTorch` means the weights were taken from the [official repository](https://github.com/facebookresearch/deit) and converted to MindSpore format; `8xRTX3090` denotes the training context (8 RTX 3090 GPUs) of the model trained with this implementation. +- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K. + +## Quick Start + +### Preparation + +#### Installation +Please refer to the [installation instructions](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV. + +#### Dataset Preparation +Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation. + +### Training + +* Distributed Training + + +```shell +# distributed training on multiple GPU/Ascend devices +mpirun -n 8 python train.py --config configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --distribute True +``` + +> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`. + +Similarly, you can train the model on multiple GPU devices with the above `mpirun` command. + +For a detailed description of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py). + +**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size. + +* Standalone Training + +If you want to train or finetune the model on a smaller dataset without distributed training, please run: + +```shell +# standalone training on a CPU/GPU/Ascend device +python train.py --config configs/deit/deit_b_gpu.yaml --data_dir /path/to/dataset --distribute False +``` + +### Validation + +To validate the accuracy of the trained model, you can use `validate.py` and pass the checkpoint path with `--ckpt_path`. 
+
+```shell
+python validate.py -c configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
+
+## References
+
+Paper - https://arxiv.org/pdf/2012.12877.pdf
+
+Official repo - https://github.com/facebookresearch/deit
+
+Mindspore implementation - https://gitee.com/cvisionlab/models/tree/deit/release/research/cv/DeiT
diff --git a/configs/deit/deit_b_gpu.yaml b/configs/deit/deit_b_gpu.yaml
new file mode 100644
index 000000000..5224bc752
--- /dev/null
+++ b/configs/deit/deit_b_gpu.yaml
@@ -0,0 +1,72 @@
+# system
+mode: 0
+distribute: False
+num_parallel_workers: 4
+val_while_train: True
+
+# dataset
+dataset: 'imagenet'
+data_dir: 'path/to/imagenet/'
+shuffle: True
+dataset_download: False
+batch_size: 64
+drop_remainder: True
+train_split: train/  # train on the training split; it must differ from val_split
+val_split: val/
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+interpolation: bicubic
+re_prob: 0.25
+re_value: 'random'
+cutmix: 1.0
+mixup: 0.8
+mixup_prob: 1.0
+mixup_mode: batch
+mixup_off_epoch: 0
+switch_prob: 0.5
+crop_pct: 0.875
+color_jitter: 0.3
+
+# model
+model: 'deit_base'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+
+keep_checkpoint_max: 10
+ckpt_save_dir: './ckpt'
+
+epoch_size : 300
+dataset_sink_mode: True
+amp_level: O2
+ema: False
+ema_decay: 0.99996
+clip_grad: False
+clip_value: 5.0
+
+# loss
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler
+lr_scheduler: 'cosine_decay'
+min_lr: 1.0e-5
+lr: 0.0005
+warmup_epochs: 5
+warmup_factor: 0.002
+decay_epochs: 30
+decay_rate: 0.1
+
+# optimizer
+opt: 'adamw'
+filter_bias_and_bn: True
+eps: 1.0e-8
+momentum: 0.9
+weight_decay: 0.05
+dynamic_loss_scale: True
+use_nesterov: False
diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py
index d0521efff..6108a8ec9 100644
--- a/mindcv/models/__init__.py
+++ b/mindcv/models/__init__.py
@@ -3,6 +3,7 @@
     bit,
     convit,
     convnext,
+    deit,
     densenet,
     dpn,
     edgenext,
@@ -45,6 +46,7 @@
 from .bit import *
 from .convit import *
 from .convnext import *
+from .deit import *
 from .densenet import *
 from .dpn import *
 from .edgenext import *
@@ -91,6 +93,7 @@
 __all__.extend(bit.__all__)
 __all__.extend(convit.__all__)
 __all__.extend(convnext.__all__)
+__all__.extend(deit.__all__)
 __all__.extend(densenet.__all__)
 __all__.extend(dpn.__all__)
 __all__.extend(edgenext.__all__)
diff --git a/mindcv/models/deit.py b/mindcv/models/deit.py
new file mode 100644
index 000000000..c3e6738c0
--- /dev/null
+++ b/mindcv/models/deit.py
@@ -0,0 +1,343 @@
+"""DeiT (Data-efficient image Transformer) model definitions.
+
+Paper: https://arxiv.org/abs/2012.12877
+"""
+import math
+import warnings
+from functools import partial
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import Add, BatchMatMul, Erfinv, Mul, Reshape, Tile, Transpose, Zeros, clip_by_value
+from mindspore.ops.function import concat
+
+from .layers.drop_path import DropPath
+from .layers.identity import Identity
+from .layers.mlp import Mlp
+from .layers.patch_embed import PatchEmbed
+from .registry import register_model
+from .utils import load_pretrained
+
+__all__ = [
+    "deit_base",
+    "deit_small",
+    "deit_tiny"
+]
+
+
+def _cfg(url='', **kwargs):
+    """Return a default pretrained-weights config; ``kwargs`` override or extend the defaults."""
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'first_conv': 'patch_embed.proj',
+        'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    "deit_base": _cfg(
+        url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_base_patch16_224.ckpt"
+    ),
+    "deit_small": _cfg(
+        url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_small_patch16_224.ckpt"
+    ),
+    "deit_tiny": _cfg(
+        url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_tiny_patch16_224.ckpt"
+    )
+}
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    """Fill ``tensor`` in place with samples from a normal(mean, std) truncated to [a, b]; return it."""
+    # Cut & paste from PyTorch official master until it's
+    # in a few official releases - RW
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn('mean is more than 2 std from [a, b] '
+                      'in nn.init.trunc_normal_. '
+                      'The distribution of values may be incorrect.',
+                      stacklevel=2)
+
+    # Values are generated by using a truncated uniform distribution and
+    # then using the inverse CDF for the normal distribution.
+    # Get upper and lower cdf values
+    la = norm_cdf((a - mean) / std)
+    ub = norm_cdf((b - mean) / std)
+
+    # Uniformly sample from [2l-1, 2u-1] into a scratch tensor.
+    uniform = np.random.uniform(2 * la - 1, 2 * ub - 1, tensor.shape)
+    values = ms.Tensor(uniform, ms.float32)
+
+    # Use inverse cdf transform for normal distribution to get truncated
+    # standard normal
+    values = Erfinv()(values)
+
+    # Transform to proper mean, std
+    values = Mul()(values, std * math.sqrt(2.))
+    values = Add()(values, mean)
+
+    # Clamp to ensure it's in the proper range
+    values = clip_by_value(values, a, b)
+
+    # Write the finished samples back into the caller's tensor. Callers ignore
+    # the return value, so assigning only the uniform draw and rebinding a local
+    # name afterwards would leave the parameter holding un-transformed values.
+    tensor.assign_value(values)
+    return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    """
+    Fills the input Tensor with values drawn from a truncated
+    normal distribution.
+
+    Args:
+        tensor: an n-dimensional MindSpore Tensor/Parameter, updated in place
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+
+    Returns:
+        The same ``tensor`` object, after it has been filled.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+class Attention(nn.Cell):
+    """
+    The Attention layer
+    The Pytorch implementation can be found by this link:
+    https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L202
+    """
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
+        # NOTE(review): Dropout receives 1 - drop rate, presumably a keep probability; confirm for your MindSpore version
+        self.attn_drop = nn.Dropout(1.0 - attn_drop)
+        self.proj = nn.Dense(dim, dim)
+        self.proj_drop = nn.Dropout(1.0 - proj_drop)
+
+        self.reshape = Reshape()
+        self.matmul = BatchMatMul()
+        self.transpose = Transpose()
+        self.softmax = nn.Softmax(axis=-1)
+
+    def construct(self, *inputs, **kwargs):
+        """Apply multi-head self-attention to ``inputs[0]``, whose shape is unpacked as (b, n, c)."""
+        x = inputs[0]
+        b, n, c = x.shape
+        # one Dense produces q, k and v together; split them out per head
+        qkv = self.reshape(
+            self.qkv(x), (b, n, 3, self.num_heads, c // self.num_heads)
+        )
+        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        # scaled dot-product scores between every pair of tokens
+        attn = self.matmul(q, self.transpose(k, (0, 1, 3, 2))) * self.scale
+
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+
+        # weight the values, then merge the heads back into the channel axis
+        x = self.transpose(self.matmul(attn, v), (0, 2, 1, 3))
+        x = self.reshape(x, (b, n, c))
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class Block(nn.Cell):
+    """
+    The Block layer
+    The Pytorch implementation can be found by this link:
+    https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L240
+    """
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
+                 qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=partial(nn.GELU, approximate=False),
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer((dim,))
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+        self.norm2 = norm_layer((dim,))
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+
+    def construct(self, *inputs, **kwargs):
+        x = inputs[0]
+        # residual connections around the normalized attention and MLP branches
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class VisionTransformer(nn.Cell):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3,
+                 num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False,
+                 qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., hybrid_backbone=None,
+                 norm_layer=nn.LayerNorm,
+                 act_mlp_layer=partial(nn.GELU, approximate=False)
+                 ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        if hybrid_backbone is not None:
+            raise NotImplementedError(
+                'hybrid_backbone is not supported: no DeiT model uses a hybrid CNN input stage'
+            )
+        self.patch_embed = PatchEmbed(
+            image_size=img_size, patch_size=patch_size,
+            in_chans=in_chans, embed_dim=embed_dim
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = ms.Parameter(
+            Zeros()((1, 1, embed_dim), ms.float32)
+        )
+        self.pos_embed = ms.Parameter(
+            Zeros()((1, num_patches + 1, embed_dim), ms.float32)
+        )
+        self.tile = Tile()
+        self.pos_drop = nn.Dropout(1.0 - drop_rate)
+
+        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.CellList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
+                norm_layer=norm_layer, act_layer=act_mlp_layer
+            )
+            for i in range(depth)])
+        self.norm = norm_layer((embed_dim,))
+
+        # Classifier head
+        self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self.cells(), self._init_weights)
+
+    def apply(self, layer, fn):
+        """Recursively call ``fn`` on every cell in ``layer`` that has no sub-cells."""
+        for l_ in layer:
+            if hasattr(l_, 'cells') and len(l_.cells()) >= 1:
+                self.apply(l_.cells(), fn)
+            else:
+                fn(l_)
+
+    def _init_weights(self, m):
+        """Truncated-normal init (std 0.02) for Dense weights and zero init for Dense biases."""
+        if isinstance(m, nn.Dense):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                # zero the bias via initializer()/set_data() rather than calling an Initializer object on it
+                m.bias.set_data(
+                    ms.common.initializer.initializer('zeros', m.bias.shape, m.bias.dtype)
+                )
+
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes):
+        self.num_classes = num_classes
+        self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else Identity()
+
+    def forward_features(self, x):
+        """Forward features"""
+        b = x.shape[0]
+
+        x = self.patch_embed(x)
+        cls_tokens = self.tile(self.cls_token, (b, 1, 1))
+        x = concat((cls_tokens, x), axis=1)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+        return x[:, 0]
+
+    def construct(self, *inputs, **kwargs):
+        x = inputs[0]
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+@register_model
+def deit_base(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+    """Build DeiT-Base (embed_dim 768, depth 12, 12 heads); load pretrained weights if requested."""
+    model = VisionTransformer(
+        patch_size=16, in_chans=in_channels, num_classes=num_classes,
+        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+    default_cfg = default_cfgs['deit_base']
+
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+    return model
+
+
+@register_model
+def deit_tiny(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+    """Build DeiT-Tiny (embed_dim 192, depth 12, 3 heads); load pretrained weights if requested."""
+    model = VisionTransformer(
+        patch_size=16, in_chans=in_channels, num_classes=num_classes,
+        embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+    default_cfg = default_cfgs['deit_tiny']
+
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+    return model
+
+
+@register_model
+def deit_small(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+    """Build DeiT-Small (embed_dim 384, depth 12, 6 heads); load pretrained weights if requested."""
+    model = VisionTransformer(
+        patch_size=16, in_chans=in_channels, num_classes=num_classes,
+        embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+    default_cfg = default_cfgs['deit_small']
+
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+    return model