diff --git a/configs/deit/README.md b/configs/deit/README.md
new file mode 100644
index 000000000..b8cdc0531
--- /dev/null
+++ b/configs/deit/README.md
@@ -0,0 +1,86 @@
+
+# DeiT
+> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)
+
+## Introduction
+
+DeiT (Data-efficient Image Transformers) shows that vision transformers can be trained to competitive accuracy on ImageNet-1K alone, without large-scale pre-training, by relying on strong data augmentation and regularization. The paper also introduces a transformer-specific distillation procedure based on a distillation token that learns from a teacher network through attention.
+
+## Results
+
+**The implementation and training configs were taken and adjusted from [this repository](https://gitee.com/cvisionlab/models/tree/deit/release/research/cv/DeiT), which implements DeiT models in MindSpore.**
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
+|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| deit_base | Converted from PyTorch | 81.62 | 95.58 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_base_patch16_224.ckpt) |
+| deit_base | 8xRTX3090 | 72.29 | 89.93 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/deit_base_patch16_224_acc%3D0.725.ckpt) |
+| deit_small | Converted from PyTorch | 79.39 | 94.80 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_small_patch16_224.ckpt) |
+| deit_tiny | Converted from PyTorch | 71.58 | 90.76 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/deit/deit_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_tiny_patch16_224.ckpt) |
+
+
+
+#### Notes
+
+- Context: "Converted from PyTorch" means the weights were taken from the [official repository](https://github.com/facebookresearch/deit) and converted to MindSpore format; "8xRTX3090" means the model was trained from scratch on eight RTX 3090 GPUs.
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
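+To load the converted weights in code, you can build the model through the MindCV registry (a minimal sketch; `create_model` is the standard MindCV entry point, and `pretrained=True` downloads the checkpoint URL registered in `mindcv/models/deit.py`):
+
+```python
+from mindcv.models import create_model
+
+# builds deit_base and loads the converted ImageNet-1K weights
+model = create_model("deit_base", pretrained=True, num_classes=1000)
+```
+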
+## Quick Start
+
+### Preparation
+
+#### Installation
+Please refer to the [installation instructions](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
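+
+The config in this folder locates the splits via `train_split`/`val_split`, so an extracted layout along these lines is assumed (a sketch; adjust to your local extraction):
+
+```text
+imagenet/
+├── train/
+│   ├── n01440764/
+│   │   └── *.JPEG
+│   └── ...
+└── val/
+    ├── n01440764/
+    │   └── *.JPEG
+    └── ...
+```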
+
+### Training
+
+* Distributed Training
+
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --distribute True
+```
+
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+The same `mpirun` command applies to both GPU and Ascend devices.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
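+
+For example, this recipe's global batch size is 64 × 8 = 512. To reproduce it on 4 devices, you could double the per-device batch size (a sketch, assuming config values can be overridden from the command line as with `--data_dir` above):
+
+```shell
+# 4 devices × batch_size 128 = global batch size 512, matching 8 × 64
+mpirun -n 4 python train.py --config configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --distribute True --batch_size 128
+```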
+
+* Standalone Training
+
+If you want to train or fine-tune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/deit/deit_b_gpu.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and pass the checkpoint path with `--ckpt_path`.
+
+```shell
+python validate.py -c configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
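+
+For example, to check the converted `deit_base` weights from the table above:
+
+```shell
+wget https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_base_patch16_224.ckpt
+python validate.py -c configs/deit/deit_b_gpu.yaml --data_dir /path/to/imagenet --ckpt_path ./deit_base_patch16_224.ckpt
+```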
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
+
+## References
+
+- Paper: https://arxiv.org/pdf/2012.12877.pdf
+- Official repo: https://github.com/facebookresearch/deit
+- MindSpore implementation: https://gitee.com/cvisionlab/models/tree/deit/release/research/cv/DeiT
diff --git a/configs/deit/deit_b_gpu.yaml b/configs/deit/deit_b_gpu.yaml
new file mode 100644
index 000000000..5224bc752
--- /dev/null
+++ b/configs/deit/deit_b_gpu.yaml
@@ -0,0 +1,72 @@
+# system
+mode: 0
+distribute: False
+num_parallel_workers: 4
+val_while_train: True
+
+# dataset
+dataset: 'imagenet'
+data_dir: 'path/to/imagenet/'
+shuffle: True
+dataset_download: False
+batch_size: 64
+drop_remainder: True
+train_split: train/
+val_split: val/
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+auto_augment: 'randaug-m9-mstd0.5-inc1'
+interpolation: bicubic
+re_prob: 0.25
+re_value: 'random'
+cutmix: 1.0
+mixup: 0.8
+mixup_prob: 1.0
+mixup_mode: batch
+mixup_off_epoch: 0
+switch_prob: 0.5
+crop_pct: 0.875
+color_jitter: 0.3
+
+# model
+model: 'deit_base'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+
+keep_checkpoint_max: 10
+ckpt_save_dir: './ckpt'
+
+epoch_size: 300
+dataset_sink_mode: True
+amp_level: O2
+ema: False
+ema_decay: 0.99996
+clip_grad: False
+clip_value: 5.0
+
+# loss
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler
+lr_scheduler: 'cosine_decay'
+min_lr: 1.0e-5
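+# 0.0005 is the DeiT base rate for a global batch size of 512 (here 64 per device × 8 devices);
+# the paper scales it linearly as lr = 0.0005 * global_batch_size / 512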
+lr: 0.0005
+warmup_epochs: 5
+warmup_factor: 0.002
+decay_epochs: 30
+decay_rate: 0.1
+
+# optimizer
+opt: 'adamw'
+filter_bias_and_bn: True
+eps: 1.0e-8
+momentum: 0.9
+weight_decay: 0.05
+dynamic_loss_scale: True
+use_nesterov: False
diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py
index d0521efff..6108a8ec9 100644
--- a/mindcv/models/__init__.py
+++ b/mindcv/models/__init__.py
@@ -3,6 +3,7 @@
bit,
convit,
convnext,
+ deit,
densenet,
dpn,
edgenext,
@@ -45,6 +46,7 @@
from .bit import *
from .convit import *
from .convnext import *
+from .deit import *
from .densenet import *
from .dpn import *
from .edgenext import *
@@ -91,6 +93,7 @@
__all__.extend(bit.__all__)
__all__.extend(convit.__all__)
__all__.extend(convnext.__all__)
+__all__.extend(deit.__all__)
__all__.extend(densenet.__all__)
__all__.extend(dpn.__all__)
__all__.extend(edgenext.__all__)
diff --git a/mindcv/models/deit.py b/mindcv/models/deit.py
new file mode 100644
index 000000000..c3e6738c0
--- /dev/null
+++ b/mindcv/models/deit.py
@@ -0,0 +1,323 @@
+"""Visual Transformer definition
+
+"""
+import math
+import warnings
+from functools import partial
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import Add, BatchMatMul, Erfinv, Mul, Reshape, Tile, Transpose, Zeros, clip_by_value
+from mindspore.ops.function import concat
+
+from .layers.drop_path import DropPath
+from .layers.identity import Identity
+from .layers.mlp import Mlp
+from .layers.patch_embed import PatchEmbed
+from .registry import register_model
+from .utils import load_pretrained
+
+__all__ = [
+ "deit_base",
+ "deit_small",
+ "deit_tiny"
+]
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000,
+ 'first_conv': 'patch_embed.proj',
+ 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ "deit_base": _cfg(
+ url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_base_patch16_224.ckpt"
+ ),
+ "deit_small": _cfg(
+ url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_small_patch16_224.ckpt"
+ ),
+ "deit_tiny": _cfg(
+ url="https://storage.googleapis.com/huawei-mindspore-hk/DeiT/Converted/deit_tiny_patch16_244.ckpt"
+ )
+}
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ """Trunc normal function"""
+ # Cut & paste from PyTorch official master until it's
+ # in a few official releases - RW
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn('mean is more than 2 std from [a, b] '
+ 'in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ la = norm_cdf((a - mean) / std)
+ ub = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [la, ub], then translate to
+ # [2*la - 1, 2*ub - 1].
+ shape = tensor.shape
+ numpy_tensor = np.random.uniform(2 * la - 1, 2 * ub - 1, shape)
+
+ tensor_tmp = ms.Tensor(numpy_tensor, ms.float32)
+ tensor.assign_value(tensor_tmp)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor = Erfinv()(tensor)
+
+ # Transform to proper mean, std
+ tensor = Mul()(tensor, std * math.sqrt(2.))
+ tensor = Add()(tensor, mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor = clip_by_value(tensor, a, b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ """
+ Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+ Args:
+ tensor: an n-dimensional `mindspore.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+class Attention(nn.Cell):
+ """
+ The Attention layer
+ The PyTorch implementation can be found at:
+ https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L202
+ """
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
+ self.attn_drop = nn.Dropout(1.0 - attn_drop)
+ self.proj = nn.Dense(dim, dim)
+ self.proj_drop = nn.Dropout(1.0 - proj_drop)
+
+ self.reshape = Reshape()
+ self.matmul = BatchMatMul()
+ self.transpose = Transpose()
+ self.softmax = nn.Softmax(axis=-1)
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ b, n, c = x.shape
+ qkv = self.reshape(
+ self.qkv(x), (b, n, 3, self.num_heads, c // self.num_heads)
+ )
+ qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
+ q, k, v = qkv[0], qkv[1], qkv[2]
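+ # q, k, v each have shape (b, num_heads, n, head_dim)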
+
+ attn = self.matmul(q, self.transpose(k, (0, 1, 3, 2))) * self.scale
+
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn)
+
+ x = self.transpose(self.matmul(attn, v), (0, 2, 1, 3))
+ x = self.reshape(x, (b, n, c))
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class Block(nn.Cell):
+ """
+ The Block layer
+ The PyTorch implementation can be found at:
+ https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L240
+ """
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
+ qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=partial(nn.GELU, approximate=False),
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer((dim,))
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ self.norm2 = norm_layer((dim,))
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+ act_layer=act_layer, drop=drop)
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Cell):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3,
+ num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False,
+ qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None,
+ norm_layer=nn.LayerNorm,
+ act_mlp_layer=partial(nn.GELU, approximate=False)
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ if hybrid_backbone is not None:
+ raise NotImplementedError(
+ 'hybrid_backbone is not implemented because none of the DeiT models use it'
+ )
+ self.patch_embed = PatchEmbed(
+ image_size=img_size, patch_size=patch_size,
+ in_chans=in_chans, embed_dim=embed_dim
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = ms.Parameter(
+ Zeros()((1, 1, embed_dim), ms.float32)
+ )
+ self.pos_embed = ms.Parameter(
+ Zeros()((1, num_patches + 1, embed_dim), ms.float32)
+ )
+ self.tile = Tile()
+ self.pos_drop = nn.Dropout(1.0 - drop_rate)
+
+ dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.CellList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
+ norm_layer=norm_layer, act_layer=act_mlp_layer
+ )
+ for i in range(depth)])
+ self.norm = norm_layer((embed_dim,))
+
+ # Classifier head
+ self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self.cells(), self._init_weights)
+
+ def apply(self, layer, fn):
+ """Recursively apply `fn` to every leaf cell in `layer`."""
+ for l_ in layer:
+ if hasattr(l_, 'cells') and len(l_.cells()) >= 1:
+ self.apply(l_.cells(), fn)
+ else:
+ fn(l_)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Dense):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Dense) and m.bias is not None:
+ constant_init = ms.common.initializer.Constant(value=0)
+ constant_init(m.bias)
+
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes):
+ self.num_classes = num_classes
+ self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else Identity()
+
+ def forward_features(self, x):
+ """Forward features"""
+ b = x.shape[0]
+
+ x = self.patch_embed(x)
+ cls_tokens = self.tile(self.cls_token, (b, 1, 1))
+ x = concat((cls_tokens, x), axis=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
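+ # only the class token (position 0) feeds the classification head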
+ return x[:, 0]
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+@register_model
+def deit_base(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, in_chans=in_channels, num_classes=num_classes,
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+ default_cfg = default_cfgs['deit_base']
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
+
+
+@register_model
+def deit_tiny(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, in_chans=in_channels, num_classes=num_classes,
+ embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+ default_cfg = default_cfgs['deit_tiny']
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
+
+
+@register_model
+def deit_small(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, in_chans=in_channels, num_classes=num_classes,
+ embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
+ default_cfg = default_cfgs['deit_small']
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
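+
+
+if __name__ == "__main__":
+ # Minimal smoke test (an illustrative sketch, not part of the module API):
+ # build the smallest variant and run a dummy batch to check the output shape.
+ net = deit_tiny(pretrained=False)
+ dummy = ms.Tensor(np.ones((2, 3, 224, 224)), ms.float32)
+ print(net(dummy).shape)  # expect (2, 1000)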