From c21d344952ae545bc5a65e50501121baf4ddda8b Mon Sep 17 00:00:00 2001
From: lyg95
Date: Wed, 7 Jun 2023 16:57:29 +0800
Subject: [PATCH] Add the config file of mixed precision quantization for
ResNet50
Signed-off-by: lyg95
---
README.md | 48 ++-
.../main_quantization_lsq_distillation.py | 307 ++++++++++++++++++
...ization_lsq_first_last_layer_int8.prototxt | 2 +-
...first_last_layer_int8_load_weight.prototxt | 2 +-
...ast_layer_int8_load_weight_resume.prototxt | 2 +-
..._first_last_layer_int8_per_tensor.prototxt | 2 +-
...ayer_int8_per_tensor_distillation.prototxt | 2 +-
...layer_int8_per_tensor_load_weight.prototxt | 2 +-
.../efficientnetb0_quantization_ptq.prototxt | 2 +-
...onsinelr_distillation_load_weight.prototxt | 2 +-
...esnet50_autoslim_quantization_ptq.prototxt | 2 +-
.../resnet/resnet50_distillation.prototxt | 45 +++
...r_distillation_load_weight_80.754.prototxt | 68 ++++
run_cli.sh | 0
.../proto/model_optimizer_torch.proto | 2 +-
src/model_optimizer/utils/__init__.py | 1 +
16 files changed, 472 insertions(+), 17 deletions(-)
create mode 100755 examples/classifier_imagenet/main_quantization_lsq_distillation.py
create mode 100644 examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt
create mode 100644 examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt
mode change 100644 => 100755 run_cli.sh
diff --git a/README.md b/README.md
index 00a8249..e48e6d4 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Model Optimizer
-[](https://dev.azure.com/Adlik/GitHub/_build/results?buildId=3472&view=results)
+[](https://dev.azure.com/Adlik/GitHub/_build/latest?definitionId=2&branchName=main)
[](https://app.bors.tech/repositories/65566)
[](https://opensource.org/licenses/Apache-2.0)
@@ -22,9 +22,17 @@ of the models. Note that activations use the same observer as weights unless oth
| resnet50 | 76.13 | 75.580 | 75.612 | 75.99 |
| mobilenetv2 | 71.878 | 70.730(act=percentile) | 70.816 | 71.11 |
+For quantization, we also explored low-bit quantization, using the LSQ algorithm to quantize the weights of resnet50
+to 3 bits and its activations to 4 bits. Compared with the FP32 model, the quantized resnet50 has no accuracy loss; in
+fact, it achieves 77.34% accuracy on the ImageNet dataset, which is higher than the original model. We supply the
+quantized model in [model_zoo](https://github.com/Adlik/model_zoo) for testing. We have also submitted the quantized
+model to [paperswithcode](https://paperswithcode.com/sota/quantization-on-imagenet?tag_filter=447), where it is
+currently the state-of-the-art result in low-bit quantization.
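The 3w4a scheme mentioned above relies on LSQ's quantize-dequantize step with a learnable step size. The following is a minimal sketch, not the repo's implementation; the gradient-scale constant and straight-through rounding follow the LSQ paper, and all names here are illustrative:

```python
import torch

def lsq_fake_quant(x: torch.Tensor, step: torch.Tensor, nbits: int, signed: bool = True):
    """Quantize-dequantize x with a learnable step size (LSQ sketch)."""
    if signed:  # e.g. 3-bit weights: integer range [-4, 3]
        qmin, qmax = -(2 ** (nbits - 1)), 2 ** (nbits - 1) - 1
    else:       # e.g. 4-bit unsigned activations: integer range [0, 15]
        qmin, qmax = 0, 2 ** nbits - 1
    # scale the gradient w.r.t. the step size to keep its updates well conditioned
    g = 1.0 / float(x.numel() * qmax) ** 0.5
    s = step * g + (step - step * g).detach()  # forward value = step, grad scaled by g
    q = torch.clamp(x / s, qmin, qmax)
    q = q + (q.round() - q).detach()           # straight-through estimator for round()
    return q * s

w = torch.randn(8, 8)
w_q = lsq_fake_quant(w, torch.tensor(0.1), nbits=3)  # values are multiples of 0.1 in [-0.4, 0.3]
```

With `nbits=3` and step 0.1, every fake-quantized weight lands on one of at most eight grid points, which is what makes 3-bit training so sensitive to the learned step size.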
+
For AutoSlim, we give the pruning effect of the resnet50 model on the ImageNet dataset.
-| FLOPs(G) |Params(M) |Size(MB)| Top-1| Acc |Input Size|
+| Model | FLOPs(G) |Params(M) |Size(MB)| Top-1 Acc |Input Size|
|:---:| :---: | :---: | :---: | :---: | :---: |
|ResNet-50 |4.12 |25.56 |98 |77.39% |224|
|ResNet-50 0.75× |2.35 |14.77 |57 |75.87% |224|
@@ -37,8 +45,8 @@ For AutoSlim, we give the pruning effect of the resnet50 model on the ImageNet d
The following table shows the effect of AutoSlim on YOLOv5m backbone pruning.
-| FLOPs(G) |Params(M) |Size(MB)|mAPval 0.5:0.95| Input Size|
-| :---: | :---: | :---: | :---: | :---: |
+| Model | FLOPs(G) |Params(M) |Size(MB)|mAPval 0.5:0.95| Input Size|
+| :---: | :---: | :---: | :---: | :---: | :---: |
|YOLOv5m |24.5 |21.2 |81 |44.4 |640|
|AutoSlim-YOLOv5m | 16.7(-31.8%)| 17.8(-16%)| 69(-14.8%)| 42.0(-2.4)| 640|
@@ -108,7 +116,7 @@ Refer to the paper [Distilling the Knowledge in a Neural Network](https://arxiv.
### 1.4 Pruning
AutoSlim can prune the model automatically, so that it can achieve better model accuracy under
-limited resource conditions (such as FLOPs, latency, memory footprint, or model size)。In AutoSlim,
+limited resource conditions (such as FLOPs, latency, memory footprint, or model size). In AutoSlim,
it can be divided into several steps. The first step is to train a slimmable model for a few epochs
(e.g., 10% to 20% of full training epochs) to quickly get a benchmark performance estimator; then we
evaluate the trained slimmable model and greedily slim the layer with minimal accuracy drop on a
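The greedy slimming step described in this hunk can be sketched as a toy loop that repeatedly shrinks the layer whose reduction costs the least accuracy, until a resource budget is met. All helper names and cost models below are hypothetical stand-ins for the real estimators:

```python
# Toy sketch of AutoSlim-style greedy slimming (hypothetical cost models).
def greedy_slim(channels, flops, accuracy, flops_budget, step=0.75):
    channels = dict(channels)
    while flops(channels) > flops_budget:
        # try shrinking each layer, keep the candidate with the smallest accuracy drop
        candidates = []
        for name in channels:
            trial = dict(channels)
            trial[name] = max(1, int(trial[name] * step))
            candidates.append((accuracy(trial), trial))
        _, channels = max(candidates, key=lambda c: c[0])
    return channels

# toy proxies: FLOPs ~ sum of squared widths, accuracy ~ saturating in width
flops = lambda ch: sum(c * c for c in ch.values())
accuracy = lambda ch: sum(min(c, 64) for c in ch.values())
slimmed = greedy_slim({"layer1": 64, "layer2": 128, "layer3": 256},
                      flops, accuracy, flops_budget=40000)
```

In the real method the "accuracy" oracle is the trained slimmable model evaluated on a validation subset, which is why the cheap pre-training pass matters.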
@@ -157,14 +165,14 @@ python -m pip install -r requirements.txt
There are two installation methods.
-1、Python wheel installation
+- Python wheel installation
```sh
cd model_optimizer
python setup.py install
```
-2、Developer mode installation
+- Developer mode installation
```sh
chmod +x *.sh
@@ -358,6 +366,32 @@ examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_search.prototxt
examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_retrain_100epochs_lr0.4_decay5e05_momentum0.9_ls0.1.prototxt
```
+### 3.5 Low-bit quantization
+
+Here are the detailed steps to reproduce the quantization results of our model. The quantization process can be
+divided into two steps.
+
+(1) Model training
+
+First, we need a fully trained model as the base model for quantization. You can use the following configuration file
+to train a high-accuracy model.
+
+```sh
+examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt
+```
+
+When training the model, we use resnet50d as the teacher model to distill the student model.
+
+(2) Model quantization
+
+During quantization, we load the trained model and quantize it with the LSQ algorithm. To improve the accuracy of the
+quantized model, we also apply distillation during the quantization process. The following configuration file is the
+one we used for model quantization.
+
+```sh
+examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt
+```
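Both steps above lean on temperature-scaled knowledge distillation. A minimal sketch of such a KD loss is shown below; it mirrors the `kl_divergence` block in the configs (temperature, batchmean reduction, loss weight), though the exact formula in `model_optimizer.losses.KLDivergence` may differ in detail:

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temperature=1.0, loss_weight=0.7):
    # temperature-softened KL divergence between teacher and student distributions
    t = temperature
    log_p = F.log_softmax(student_logits / t, dim=1)
    q = F.softmax(teacher_logits / t, dim=1)
    # t*t rescales gradients to be comparable across temperatures (Hinton et al.)
    return loss_weight * (t * t) * F.kl_div(log_p, q, reduction="batchmean")

logits = torch.randn(4, 10)
zero = kd_loss(logits, logits)  # identical distributions give zero loss
```

In training this term is added to the ordinary cross-entropy loss, with the teacher (resnet50d) run in eval mode under `torch.no_grad()`.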
+
## Acknowledgement
We would like to thank them for their excellent open-source work.
diff --git a/examples/classifier_imagenet/main_quantization_lsq_distillation.py b/examples/classifier_imagenet/main_quantization_lsq_distillation.py
new file mode 100755
index 0000000..4220364
--- /dev/null
+++ b/examples/classifier_imagenet/main_quantization_lsq_distillation.py
@@ -0,0 +1,307 @@
+# Copyright 2023 ZTE corporation. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Quantization training with PTQ or QAT (LSQ), with optional knowledge distillation
+"""
+import time
+import copy
+import datetime
+import os
+from torch.backends import cudnn
+import torch.distributed as dist
+from torch import nn
+import torch
+from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx, fuse_fx
+from model_optimizer.core import (get_base_parser, get_hyperparam, get_freer_gpu, main_s1_set_seed,
+ main_s2_start_worker, display_model, process_model, resume_model,
+ distributed_model, get_optimizer, get_summary_writer,
+ get_model_info, get_lr_scheduler, validate, train, save_checkpoint,
+ save_torchscript_model, save_file_on_master, save_yaml_file_on_master,
+ print_on_master)
+from model_optimizer.proto import model_optimizer_torch_pb2 as eppb
+from model_optimizer.quantizer import (get_qconfig, get_lsq_qconfig, get_model_quantize_bit_config,
+ clip_model_weight_in_quant_min_max)
+from model_optimizer.losses import KLDivergence
+from model_optimizer.datasets import DataloaderFactory
+from model_optimizer.models import get_model_from_source, get_teacher_model
+
+best_acc1 = 0
+
+
+def main():
+ """
+ main process
+ Returns:
+
+ """
+ parser = get_base_parser()
+ args = parser.parse_args()
+ hp = get_hyperparam(args)
+ if hp.gpu_id == eppb.GPU.ANY:
+ args.gpu = get_freer_gpu()
+ elif hp.gpu_id == eppb.GPU.NONE:
+ args.gpu = None # TODO: test
+
+ print("Start training")
+ start_time = time.time()
+ main_s1_set_seed(hp)
+ main_s2_start_worker(main_worker, args, hp)
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(f'Training time {total_time_str}')
+
+
+def main_worker(gpu, ngpus_per_node, args): # pylint: disable=too-many-branches,too-many-statements
+ """
+ main worker on per gpu
+ Args:
+ gpu:
+ ngpus_per_node:
+ args:
+
+ Returns:
+
+ """
+ args.gpu = gpu
+ if args.gpu is not None:
+ print(f"Use GPU: {args.gpu} for training")
+ args.hp = get_hyperparam(args)
+
+ if args.hp.quantization.post_training_quantize and args.distributed:
+ raise RuntimeError("Post training quantization should not be performed "
+ "on distributed mode")
+
+ if args.distributed:
+ if args.hp.multi_gpu.dist_url == "env://" and args.hp.multi_gpu.rank == -1:
+ args.hp.multi_gpu.rank = int(os.environ["RANK"])
+ if args.hp.multi_gpu.multiprocessing_distributed:
+ # For multiprocessing distributed training, rank needs to be the
+ # global rank among all the processes
+ args.hp.multi_gpu.rank = args.hp.multi_gpu.rank * ngpus_per_node + gpu
+ dist.init_process_group(backend=args.hp.multi_gpu.dist_backend, init_method=args.hp.multi_gpu.dist_url,
+ world_size=args.world_size, rank=args.hp.multi_gpu.rank)
+ # create model
+ if args.hp.pretrained:
+ print(f"=> using pre-trained model '{args.hp.arch}'")
+ else:
+ print(f"=> creating model '{args.hp.arch}'")
+ load_quantize_model = args.hp.quantization.quantize and (not args.hp.quantization.quantize_fx)
+ model = get_model_from_source(args.hp.arch, args.hp.model_source, args.hp.pretrained, args.hp.width_mult,
+ args.hp.depth_mult, load_quantize_model, args.hp.is_subnet,
+ args.hp.auto_slim.channel_config_path if args.hp.HasField('auto_slim') else None)
+ # Teacher model
+ if args.hp.distill.teacher_model.arch:
+ teacher_model = get_teacher_model(args.hp.distill.teacher_model.arch, args.hp.distill.teacher_model.source)
+ else:
+ teacher_model = None
+ if args.gpu == 0 and args.hp.multi_gpu.rank in [-1, 0]:
+ print(args)
+ print('model:\n=========\n')
+ display_model(model)
+
+ process_model(model, args)
+
+ if args.hp.quantization.quantize:
+ torch.backends.quantized.engine = args.hp.quantization.backend
+
+ if args.hp.quantization.quantize and not args.hp.quantization.post_training_quantize:
+ if args.hp.quantization.quantize_fx:
+ qconfig_dict = get_lsq_qconfig(model, args.hp.quantization, args.hp.quantization.backend)
+ model = prepare_qat_fx(model, qconfig_dict)
+ else:
+ qconfig = get_qconfig(args.hp.quantization, args.hp.quantization.backend)
+ model.fuse_model()
+ model.qconfig = qconfig
+ torch.quantization.prepare_qat(model, inplace=True)
+
+ df_calibrate = DataloaderFactory(args)
+ calibrate_train_loader, _, _ = df_calibrate.product_train_val_loader(
+ df_calibrate.imagenet2012,
+ num_batches=1,
+ use_val_trans=True,
+ batch_same_data=True)
+ model.apply(torch.quantization.enable_observer)
+ model.apply(torch.quantization.disable_fake_quant)
+
+ print('init scale for lsq')
+ with torch.no_grad():
+ for images, _ in calibrate_train_loader:
+ model(images)
+ del calibrate_train_loader
+ del df_calibrate
+ torch.distributed.barrier()
+ # parallel and multi-gpu
+ model = distributed_model(model, ngpus_per_node, args)
+ if args.distributed:
+ model_without_ddp = model.module
+ else:
+ model_without_ddp = model
+
+    if args.hp.distill.teacher_model.arch:
+        teacher_model = distributed_model(teacher_model, ngpus_per_node, args)
+        # define distillation loss function (distill_criterion)
+        distill_criterion = KLDivergence(args.hp.distill.kl_divergence.temperature,
+                                         loss_weight=args.hp.distill.kl_divergence.loss_weight)
+    else:
+        # avoid a NameError below when no teacher model is configured
+        distill_criterion = None
+
+ # define loss function (criterion)
+ criterion = nn.CrossEntropyLoss(label_smoothing=args.hp.label_smoothing).cuda(args.gpu)
+
+ optimizer = get_optimizer(model, args)
+ cudnn.benchmark = True
+
+ if args.hp.quantization.quantize and args.hp.quantization.post_training_quantize:
+
+ dataload_factory = DataloaderFactory(args)
+ data_loader_calibration, val_loader, train_sampler = dataload_factory.product_train_val_loader(
+ dataload_factory.imagenet2012,
+ args.hp.quantization.num_calibration_batches, use_val_trans=True)
+ model.eval()
+ qconfig = get_qconfig(args.hp.quantization, args.hp.quantization.backend)
+ if args.hp.quantization.quantize_fx:
+ qconfig_dict = {"": qconfig}
+ model = prepare_fx(model, qconfig_dict)
+ else:
+ model.fuse_model()
+ model.qconfig = qconfig
+ torch.quantization.prepare(model, inplace=True)
+ validate(data_loader_calibration, model, criterion, args)
+ model.to(torch.device('cpu'))
+ writer = get_summary_writer(args, ngpus_per_node, model)
+ if args.hp.quantization.quantize_fx:
+ if args.hp.save_jit_trace:
+ bit_config = get_model_quantize_bit_config(model)
+ save_yaml_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_bit_config.yaml', bit_config)
+ model = convert_fx(model)
+ model = fuse_fx(model)
+ if args.hp.save_jit_trace:
+ clip_model_weight_in_quant_min_max(model, bit_config)
+ else:
+ torch.quantization.convert(model, inplace=True)
+ if args.hp.save_jit_trace:
+ filename = 'ptq_quant.jit'
+ save_torchscript_model(model, val_loader, prefix=f'{args.log_name}/{args.arch}',
+ filename=filename)
+ save_file_on_master(args, f'{args.log_name}/{args.arch}_quantized_by_ptq.txt', str(model))
+ acc1, acc5 = validate(val_loader, model, criterion, args, device='cpu')
+ if writer is not None:
+ writer.add_scalar('val/acc1', acc1, 0)
+ writer.add_scalar('val/acc5', acc5, 0)
+ writer.close()
+ return
+
+ dataload_factory = DataloaderFactory(args)
+ train_loader, val_loader, train_sampler = dataload_factory.product_train_val_loader(dataload_factory.imagenet2012)
+ writer = get_summary_writer(args, ngpus_per_node, model)
+ if args.hp.evaluate:
+ if writer is not None:
+ get_model_info(copy.deepcopy(model_without_ddp), args, val_loader)
+ args.batch_num = len(train_loader)
+
+ lr_scheduler = get_lr_scheduler(optimizer, args)
+
+ resume_model(model_without_ddp, args, optimizer, lr_scheduler)
+
+ global best_acc1 # pylint: disable=global-statement
+ best_acc1 = args.best_acc1
+ if args.hp.evaluate:
+ validate(val_loader, model, criterion, args)
+ return
+
+ if args.hp.amp:
+ # Automatic mixed precision training
+ scaler = torch.cuda.amp.GradScaler()
+ else:
+ scaler = None
+
+ if args.hp.quantization.quantize:
+ model.apply(torch.quantization.disable_observer)
+ model.apply(torch.quantization.disable_fake_quant)
+ acc1, acc5 = validate(val_loader, model, criterion, args)
+ print_on_master(args, f'validate after calibration: acc1: {acc1}, acc5: {acc5}')
+ model.apply(torch.quantization.enable_fake_quant)
+
+ for epoch in range(args.start_epoch, args.hp.epochs):
+ if args.distributed:
+ train_sampler.set_epoch(epoch)
+ # train for one epoch
+ epoch_start_time = time.time()
+ train(train_loader, model, criterion, optimizer, epoch, args, writer, scaler,
+ teacher_model=teacher_model, distill_criterion=distill_criterion)
+ lr_scheduler.step()
+ epoch_total_time = time.time() - epoch_start_time
+ total_time_str = str(datetime.timedelta(seconds=int(epoch_total_time)))
+
+ if args.hp.multi_gpu.rank in [-1, 0]:
+ print(f'Epoch[{epoch + 1}/{args.hp.epochs}] total time {total_time_str}')
+ with torch.no_grad():
+ if args.hp.quantization.quantize:
+ if epoch >= args.hp.quantization.num_observer_update_epochs:
+ print('Disabling observer for subseq epochs, epoch = ', epoch)
+ model.apply(torch.quantization.disable_observer)
+ if epoch >= args.hp.quantization.num_batch_norm_update_epochs:
+ print('Freezing BN for subseq epochs, epoch = ', epoch)
+ model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+ # evaluate on validation set
+ if args.hp.quantization.quantize:
+ print('Evaluate QAT model')
+ acc1, acc5 = validate(val_loader, model, criterion, args)
+ if writer is not None:
+ writer.add_scalar('val/acc1', acc1, epoch)
+ writer.add_scalar('val/acc5', acc5, epoch)
+ writer.add_scalar('val/lr', optimizer.param_groups[0]['lr'], epoch)
+ # remember best acc@1 and save checkpoint
+
+ is_best = acc1 > best_acc1
+ best_acc1 = max(acc1, best_acc1)
+ if is_best:
+ print_on_master(args, f'** best at epoch: {epoch + 1}, acc1: {acc1}, acc5: {acc5}')
+ if writer is not None:
+ save_checkpoint({
+ 'epoch': epoch + 1,
+ 'arch': args.arch,
+ 'state_dict': model_without_ddp.state_dict(),
+ 'best_acc1': best_acc1,
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ }, is_best, prefix=f'{args.log_name}/{args.arch}')
+ with torch.no_grad():
+ if args.hp.quantization.quantize:
+ save_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_by_qat_prepared.txt',
+ str(model_without_ddp))
+ quantized_eval_model = copy.deepcopy(model_without_ddp)
+ quantized_eval_model.eval()
+ quantized_eval_model.to(torch.device('cpu'))
+ if args.hp.quantization.quantize_fx:
+ if args.hp.save_jit_trace:
+ bit_config = get_model_quantize_bit_config(quantized_eval_model)
+ save_yaml_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_bit_config.yaml',
+ bit_config)
+ quantized_eval_model = convert_fx(quantized_eval_model)
+ if args.hp.save_jit_trace:
+ clip_model_weight_in_quant_min_max(quantized_eval_model, bit_config)
+ else:
+ torch.quantization.convert(quantized_eval_model, inplace=True)
+ save_file_on_master(args, f'{args.log_name}/{args.arch}_quantized_by_qat.txt',
+ str(quantized_eval_model))
+ print('Evaluate Quantized model')
+ validate(val_loader, quantized_eval_model, criterion, args, device='cpu')
+ if is_best and args.hp.save_jit_trace:
+ save_file_on_master(args, f'{args.log_name}/{args.arch}_save_jit_log.txt',
+ f'best at epoch: {epoch+1}, acc1: {acc1}, acc5: {acc5}')
+ if args.hp.quantization.quantize:
+ save_jit_model = quantized_eval_model
+ filename = 'best_qat_quant.jit'
+ else:
+ save_jit_model = model
+ filename = 'best.jit'
+ save_torchscript_model(save_jit_model, val_loader, prefix=f'{args.log_name}/{args.arch}',
+ filename=filename)
+ del quantized_eval_model
+ if writer is not None:
+ writer.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt
index 4fae1f6..9149d00 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt
@@ -41,7 +41,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt
index c49f264..d6c389e 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt
@@ -42,7 +42,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt
index e93f6cf..195a001 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt
@@ -43,7 +43,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt
index cfa65ea..1486632 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt
@@ -41,7 +41,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt
index 6c0e498..093bf70 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt
@@ -46,7 +46,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt
index 9b8f6cf..3667da7 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt
@@ -42,7 +42,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt
index a43f5d7..99e0f05 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt
@@ -31,7 +31,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: true
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_calibration_batches: 120
activation_quantization_observer {
quantization_method: "quantization_error"
diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt
index 56a99b2..573a233 100644
--- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt
+++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt
@@ -47,7 +47,7 @@ quantization {
quantize: true
quantize_fx: true
post_training_quantize: false
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_observer_update_epochs: 4
num_batch_norm_update_epochs: 99999
activation_quantization_observer {
diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt
index 49703ec..02920fe 100644
--- a/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt
+++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt
@@ -39,7 +39,7 @@ quantization {
quantize: true
quantize_fx: false
post_training_quantize: true
- backend: "fbgemm"
+ backend: TORCH_FBGEMM
num_calibration_batches: 120
activation_quantization_observer {
quantization_method: "quantization_error"
diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt
new file mode 100644
index 0000000..fbfa5dd
--- /dev/null
+++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt
@@ -0,0 +1,45 @@
+main_file: "examples/classifier_imagenet/main_distillation.py"
+arch: "resnet50"
+model_source: TorchVision
+log_name: "multi_gpu"
+data: "/data/imagenet/imagenet-torch"
+debug: false
+lr: 0.05
+epochs: 720
+batch_size: 128
+workers: 8
+print_freq: 50
+evaluate: false
+pretrained: false
+seed: 0
+gpu_id: ANY
+multi_gpu {
+ world_size: 1
+ rank: 0
+ dist_url: "tcp://127.0.0.1:23457"
+ dist_backend: "nccl"
+ multiprocessing_distributed: true
+}
+val_batch_size: 256
+warmup {
+ lr_warmup_epochs: 5
+  lr_warmup_decay: 0.01
+}
+lr_scheduler: CosineAnnealingLR
+optimizer: SGD
+sgd {
+  weight_decay: 1e-05
+  momentum: 0.9
+}
+amp: true
+distill {
+ teacher_model {
+ arch: "resnet50d"
+ source: Timm
+ }
+ kl_divergence {
+ temperature: 1.0
+ reduction: "batchmean"
+    loss_weight: 0.7
+ }
+}
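The schedule in this config — 5 linear warmup epochs from `lr * lr_warmup_decay`, then cosine annealing over the remaining epochs — behaves roughly as follows. This is an illustrative sketch; PyTorch's `CosineAnnealingLR` and the repo's warmup wrapper may differ in detail:

```python
import math

def lr_at(epoch, base_lr=0.05, total_epochs=720, warmup_epochs=5, warmup_decay=0.01):
    """Learning rate at the start of a given epoch (sketch)."""
    if epoch < warmup_epochs:
        # linear warmup from base_lr * warmup_decay up to base_lr
        start = base_lr * warmup_decay
        return start + (base_lr - start) * epoch / warmup_epochs
    # cosine annealing from base_lr down to 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```

The warmup keeps the large-batch SGD start stable, and the cosine tail lets the distilled student settle into a low learning rate for the last epochs.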
diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt
new file mode 100644
index 0000000..4a76516
--- /dev/null
+++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt
@@ -0,0 +1,68 @@
+main_file: "examples/classifier_imagenet/main_quantization_lsq_distillation.py"
+arch: "resnet50"
+model_source: TorchVision
+log_name: "quantization_lsq_distillation_first_last_layer_int8_lr_0.001"
+data: "/data/imagenet/imagenet-torch"
+debug: false
+lr: 0.001
+epochs: 50
+batch_size: 32
+workers: 8
+print_freq: 50
+evaluate: false
+pretrained: false
+seed: 0
+gpu_id: ANY
+multi_gpu {
+ world_size: 1
+ rank: 0
+ dist_url: "tcp://127.0.0.1:23457"
+ dist_backend: "nccl"
+ multiprocessing_distributed: true
+}
+weight: "/root/work/model_optimizer_torch_dev_1.03/logger/resnet50_multi_gpu_202206211454_acc_80.754/resnet50best.pth.tar"
+lr_scheduler: CosineAnnealingLR
+optimizer: SGD
+sgd {
+  weight_decay: 0.0001
+  momentum: 0.9
+}
+distill {
+ teacher_model {
+ arch: "resnet50d"
+ source: Timm
+ }
+ kl_divergence {
+ temperature: 1.0
+ reduction: "batchmean"
+    loss_weight: 0.7
+ }
+}
+quantization {
+ quantize: true
+ quantize_fx: true
+ post_training_quantize: false
+ backend: TORCH_FBGEMM
+ num_observer_update_epochs: 4
+ num_batch_norm_update_epochs: 99999
+ activation_quantization_observer {
+ quantization_method: "minmax"
+ per_channel: false
+ symmetric: true
+ reduce_range: false
+ dtype: "quint8"
+ nbits: 4
+ fake_method: "lsq"
+ layers_restrict_to_8bit: "conv1,fc"
+ }
+ weight_quantization_observer {
+ quantization_method: "minmax"
+ per_channel: false
+ symmetric: true
+ reduce_range: false
+ dtype: "qint8"
+ nbits: 3
+ fake_method: "lsq"
+ layers_restrict_to_8bit: "conv1,fc"
+ }
+}
\ No newline at end of file
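The `layers_restrict_to_8bit: "conv1,fc"` fields above keep the first and last layers of resnet50 at 8 bits while every other layer uses 3-bit weights and 4-bit activations. A hypothetical helper illustrating that mapping (not the repo's `get_model_quantize_bit_config`):

```python
def build_bit_config(layer_names, default_bits=(3, 4), int8_layers=("conv1", "fc")):
    """Map each layer name to (weight_bits, activation_bits).

    Layers listed in int8_layers keep full int8 precision; the rest get the
    low-bit default. Names here follow torchvision's resnet50 module names.
    """
    return {name: (8, 8) if name in int8_layers else default_bits
            for name in layer_names}

cfg = build_bit_config(["conv1", "layer1.0.conv1", "layer4.2.conv3", "fc"])
```

Keeping the input and output layers at 8 bits is a common mixed-precision choice because those layers are small but disproportionately sensitive to quantization error.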
diff --git a/run_cli.sh b/run_cli.sh
old mode 100644
new mode 100755
diff --git a/src/model_optimizer/proto/model_optimizer_torch.proto b/src/model_optimizer/proto/model_optimizer_torch.proto
index d8fb55d..77a5371 100644
--- a/src/model_optimizer/proto/model_optimizer_torch.proto
+++ b/src/model_optimizer/proto/model_optimizer_torch.proto
@@ -292,7 +292,7 @@ message ExponentialLRParam{
message CyclicLRParam {
- // 在使用时,原始的CLR是按照 batch iteration 的更新 lr 的,本项目中为了和之前的几个LR统一,
+ // Note: the original CLR updates the lr per batch iteration; in this project, to stay
 // consistent with the other LR schedulers, the lr is updated per epoch iteration
// base_lr (float or list): Initial learning rate which is the
// lower boundary in the cycle for each parameter group.
diff --git a/src/model_optimizer/utils/__init__.py b/src/model_optimizer/utils/__init__.py
index 951ed92..4d36e25 100644
--- a/src/model_optimizer/utils/__init__.py
+++ b/src/model_optimizer/utils/__init__.py
@@ -9,6 +9,7 @@
import copy
import yaml
import torch
+import torch.fx
from torch import nn