From c21d344952ae545bc5a65e50501121baf4ddda8b Mon Sep 17 00:00:00 2001 From: lyg95 Date: Wed, 7 Jun 2023 16:57:29 +0800 Subject: [PATCH] Add the config file of mixed precision quantization for ResNet50 Signed-off-by: lyg95 --- README.md | 48 ++- .../main_quantization_lsq_distillation.py | 307 ++++++++++++++++++ ...ization_lsq_first_last_layer_int8.prototxt | 2 +- ...first_last_layer_int8_load_weight.prototxt | 2 +- ...ast_layer_int8_load_weight_resume.prototxt | 2 +- ..._first_last_layer_int8_per_tensor.prototxt | 2 +- ...ayer_int8_per_tensor_distillation.prototxt | 2 +- ...layer_int8_per_tensor_load_weight.prototxt | 2 +- .../efficientnetb0_quantization_ptq.prototxt | 2 +- ...onsinelr_distillation_load_weight.prototxt | 2 +- ...esnet50_autoslim_quantization_ptq.prototxt | 2 +- .../resnet/resnet50_distillation.prototxt | 45 +++ ...r_distillation_load_weight_80.754.prototxt | 68 ++++ run_cli.sh | 0 .../proto/model_optimizer_torch.proto | 2 +- src/model_optimizer/utils/__init__.py | 1 + 16 files changed, 472 insertions(+), 17 deletions(-) create mode 100755 examples/classifier_imagenet/main_quantization_lsq_distillation.py create mode 100644 examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt create mode 100644 examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt mode change 100644 => 100755 run_cli.sh diff --git a/README.md b/README.md index 00a8249..e48e6d4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Model Optimizer -[![Build Status](https://dev.azure.com/Adlik/GitHub/_apis/build/status/Adlik.model_optimizer?branchName=main)](https://dev.azure.com/Adlik/GitHub/_build/results?buildId=3472&view=results) +[![Build Status](https://dev.azure.com/Adlik/GitHub/_apis/build/status/Adlik.model_optimizer?branchName=main)](https://dev.azure.com/Adlik/GitHub/_build/latest?definitionId=2&branchName=main) [![Bors 
enabled](https://bors.tech/images/badge_small.svg)](https://app.bors.tech/repositories/65566) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) @@ -22,9 +22,17 @@ of the models. Note that activations use the same observer as weights unless oth | resnet50 | 76.13 | 75.580 | 75.612 | 75.99 | | mobilenetv2 | 71.878 | 70.730(act=percentile) | 70.816 | 71.11 | +For quantization, we explored low-bit model quantization, where weights are quantized to 3 bits and activations are +quantized to 4 bits, using the lsq algorithm to quantize resnet50. Compared with the FP32 model, the quantized resnet50 +has no accuracy loss and has higher accuracy than the original model. Our quantized resnet50 model achieves 77.34% +accuracy on the ImageNet dataset. Here we supply the quantized model in +[model_zoo](https://github.com/Adlik/model_zoo) for testing. Besides, we submit our quantized model to +[paperswithcode](https://paperswithcode.com/sota/quantization-on-imagenet?tag_filter=447), which is the current +state-of-the-art result in low-bit quantization. + For AutoSlim, we give the pruning effect of the resnet50 model on the ImageNet dataset. -| FLOPs(G) |Params(M) |Size(MB)| Top-1| Acc |Input Size| +| Model | FLOPs(G) |Params(M) |Size(MB)| Top-1 Acc |Input Size| |:---:| :---: | :---: | :---: | :---: | :---: | |ResNet5 |4.12 |25.56 |98 |77.39% |224| |ResNet-50 0.75× |2.35 |14.77 |57 |75.87% |224| @@ -37,8 +45,8 @@ For AutoSlim, we give the pruning effect of the resnet50 model on the ImageNet d The following table shows the effect of AutoSlim on YOLOv5m backbone pruning. 
-| FLOPs(G) |Params(M) |Size(MB)|mAPval 0.5:0.95| Input Size| -| :---: | :---: | :---: | :---: | :---: | +| Model | FLOPs(G) |Params(M) |Size(MB)|mAPval 0.5:0.95| Input Size| +| :---: | :---: | :---: | :---: | :---: | :---: | |YOLOv5m |24.5 |21.2 |81 |44.4 |640| |AutoSlim-YOLOv5m | 16.7(-31.8%)| 17.8(-16%)| 69(-14.8%)| 42.0(-2.4)| 640| @@ -108,7 +116,7 @@ Refer to the paper [Distilling the Knowledge in a Neural Network](https://arxiv. ### 1.4 Pruning AutoSlim can prune the model automatically, which it can achieve better model accuracy under -limited resource conditions (such as FLOPs, latency, memory footprint, or model size)。In AutoSlim, +limited resource conditions (such as FLOPs, latency, memory footprint, or model size). In AutoSlim, it can be divided into several steps. The first step is to train a slimmable model for a few epochs (e.g., 10% to 20% of full training epochs) to quickly get a benchmark performance estimator; Then we evaluate the trained slimmable model and greedily slim the layer with minimal accuracy drop on a @@ -157,14 +165,14 @@ python -m pip install -r requirements.txt There are two installation methods. -1、Python wheel installation +- Python wheel installation ```sh cd model_optimizer python setup.py install ``` -2、Developer mode installation +- Developer mode installation ```sh chmod +x *.sh @@ -358,6 +366,32 @@ examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_search.prototxt examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_retrain_100epochs_lr0.4_decay5e05_momentum0.9_ls0.1.prototxt ``` +### 3.5 Low-bit quantization + +Here is a detailed introduction to the steps to reproduce the quantization results of our model. The quantization process +can be divided into two steps. + +(1) model training + +First of all, we need a fully trained model as the base model for quantization. Here you can use the following +configuration file to train a higher-precision model. 
+ +```sh +examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt +``` + +When training the model, we will use resnet50d as the teacher model to distill the student model. + +(2) model quantization + +When quantizing, load a trained model, then use the LSQ algorithm to quantize the model. To improve +the performance of the quantized model, we will use the distillation method in the quantization process. The following +configuration file is the configuration of our model quantization. + +```sh +examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt +``` + ## Acknowledgement We would like to thank them for their excellent open-source work. diff --git a/examples/classifier_imagenet/main_quantization_lsq_distillation.py b/examples/classifier_imagenet/main_quantization_lsq_distillation.py new file mode 100755 index 0000000..4220364 --- /dev/null +++ b/examples/classifier_imagenet/main_quantization_lsq_distillation.py @@ -0,0 +1,307 @@ +# Copyright 2023 ZTE corporation. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +quantization train with PTQ or QAT +""" +import time +import copy +import datetime +import os +from torch.backends import cudnn +import torch.distributed as dist +from torch import nn +import torch +from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx, fuse_fx +from model_optimizer.core import (get_base_parser, get_hyperparam, get_freer_gpu, main_s1_set_seed, + main_s2_start_worker, display_model, process_model, resume_model, + distributed_model, get_optimizer, get_summary_writer, + get_model_info, get_lr_scheduler, validate, train, save_checkpoint, + save_torchscript_model, save_file_on_master, save_yaml_file_on_master, + print_on_master) +from model_optimizer.proto import model_optimizer_torch_pb2 as eppb +from model_optimizer.quantizer import (get_qconfig, get_lsq_qconfig, get_model_quantize_bit_config, + clip_model_weight_in_quant_min_max) +from model_optimizer.losses import KLDivergence +from model_optimizer.datasets import DataloaderFactory +from model_optimizer.models import get_model_from_source, get_teacher_model + +best_acc1 = 0 + + +def main(): + """ + main process + Returns: + + """ + parser = get_base_parser() + args = parser.parse_args() + hp = get_hyperparam(args) + if hp.gpu_id == eppb.GPU.ANY: + args.gpu = get_freer_gpu() + elif hp.gpu_id == eppb.GPU.NONE: + args.gpu = None # TODO: test + + print("Start training") + start_time = time.time() + main_s1_set_seed(hp) + main_s2_start_worker(main_worker, args, hp) + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f'Training time {total_time_str}') + + +def main_worker(gpu, ngpus_per_node, args): # pylint: disable=too-many-branches,too-many-statements + """ + main worker on per gpu + Args: + gpu: + ngpus_per_node: + args: + + Returns: + + """ + args.gpu = gpu + if args.gpu is not None: + print(f"Use GPU: {args.gpu} for training") + args.hp = get_hyperparam(args) + + 
if args.hp.quantization.post_training_quantize and args.distributed: + raise RuntimeError("Post training quantization should not be performed " + "on distributed mode") + + if args.distributed: + if args.hp.multi_gpu.dist_url == "env://" and args.hp.multi_gpu.rank == -1: + args.hp.multi_gpu.rank = int(os.environ["RANK"]) + if args.hp.multi_gpu.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.hp.multi_gpu.rank = args.hp.multi_gpu.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.hp.multi_gpu.dist_backend, init_method=args.hp.multi_gpu.dist_url, + world_size=args.world_size, rank=args.hp.multi_gpu.rank) + # create model + if args.hp.pretrained: + print(f"=> using pre-trained model '{args.hp.arch}'") + else: + print(f"=> creating model '{args.hp.arch}'") + load_quantize_model = args.hp.quantization.quantize and (not args.hp.quantization.quantize_fx) + model = get_model_from_source(args.hp.arch, args.hp.model_source, args.hp.pretrained, args.hp.width_mult, + args.hp.depth_mult, load_quantize_model, args.hp.is_subnet, + args.hp.auto_slim.channel_config_path if args.hp.HasField('auto_slim') else None) + # Teacher model + if args.hp.distill.teacher_model.arch: + teacher_model = get_teacher_model(args.hp.distill.teacher_model.arch, args.hp.distill.teacher_model.source) + else: + teacher_model = None + if args.gpu == 0 and args.hp.multi_gpu.rank in [-1, 0]: + print(args) + print('model:\n=========\n') + display_model(model) + + process_model(model, args) + + if args.hp.quantization.quantize: + torch.backends.quantized.engine = args.hp.quantization.backend + + if args.hp.quantization.quantize and not args.hp.quantization.post_training_quantize: + if args.hp.quantization.quantize_fx: + qconfig_dict = get_lsq_qconfig(model, args.hp.quantization, args.hp.quantization.backend) + model = prepare_qat_fx(model, qconfig_dict) + else: + qconfig = 
get_qconfig(args.hp.quantization, args.hp.quantization.backend) + model.fuse_model() + model.qconfig = qconfig + torch.quantization.prepare_qat(model, inplace=True) + + df_calibrate = DataloaderFactory(args) + calibrate_train_loader, _, _ = df_calibrate.product_train_val_loader( + df_calibrate.imagenet2012, + num_batches=1, + use_val_trans=True, + batch_same_data=True) + # simple_train_loader = df_simple.product_train_loader(df_simple.imagenet2012) + model.apply(torch.quantization.enable_observer) + model.apply(torch.quantization.disable_fake_quant) + + print('init scale for lsq') + model.apply(torch.quantization.enable_observer) + with torch.no_grad(): + for images, _ in calibrate_train_loader: + model(images) + del calibrate_train_loader + del df_calibrate + torch.distributed.barrier() + # parallel and multi-gpu + model = distributed_model(model, ngpus_per_node, args) + if args.distributed: + model_without_ddp = model.module + else: + model_without_ddp = model + + if args.hp.distill.teacher_model.arch: + teacher_model = distributed_model(teacher_model, ngpus_per_node, args) + # define distillation loss function (distill_criterion) + distill_criterion = KLDivergence(args.hp.distill.kl_divergence.temperature, + loss_weight=args.hp.distill.kl_divergence.loss_weight) + + # define loss function (criterion) + criterion = nn.CrossEntropyLoss(label_smoothing=args.hp.label_smoothing).cuda(args.gpu) + + optimizer = get_optimizer(model, args) + cudnn.benchmark = True + + if args.hp.quantization.quantize and args.hp.quantization.post_training_quantize: + + dataload_factory = DataloaderFactory(args) + data_loader_calibration, val_loader, train_sampler = dataload_factory.product_train_val_loader( + dataload_factory.imagenet2012, + args.hp.quantization.num_calibration_batches, use_val_trans=True) + model.eval() + qconfig = get_qconfig(args.hp.quantization, args.hp.quantization.backend) + if args.hp.quantization.quantize_fx: + qconfig_dict = {"": qconfig} + model = 
prepare_fx(model, qconfig_dict) + else: + model.fuse_model() + model.qconfig = qconfig + torch.quantization.prepare(model, inplace=True) + validate(data_loader_calibration, model, criterion, args) + model.to(torch.device('cpu')) + writer = get_summary_writer(args, ngpus_per_node, model) + if args.hp.quantization.quantize_fx: + if args.hp.save_jit_trace: + bit_config = get_model_quantize_bit_config(model) + save_yaml_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_bit_config.yaml', bit_config) + model = convert_fx(model) + model = fuse_fx(model) + if args.hp.save_jit_trace: + clip_model_weight_in_quant_min_max(model, bit_config) + else: + torch.quantization.convert(model, inplace=True) + if args.hp.save_jit_trace: + filename = 'ptq_quant.jit' + save_torchscript_model(model, val_loader, prefix=f'{args.log_name}/{args.arch}', + filename=filename) + save_file_on_master(args, f'{args.log_name}/{args.arch}_quantized_by_ptq.txt', str(model)) + acc1, acc5 = validate(val_loader, model, criterion, args, device='cpu') + if writer is not None: + writer.add_scalar('val/acc1', acc1, 0) + writer.add_scalar('val/acc5', acc5, 0) + writer.close() + return + + dataload_factory = DataloaderFactory(args) + train_loader, val_loader, train_sampler = dataload_factory.product_train_val_loader(dataload_factory.imagenet2012) + writer = get_summary_writer(args, ngpus_per_node, model) + if args.hp.evaluate: + if writer is not None: + get_model_info(copy.deepcopy(model_without_ddp), args, val_loader) + args.batch_num = len(train_loader) + + lr_scheduler = get_lr_scheduler(optimizer, args) + + resume_model(model_without_ddp, args, optimizer, lr_scheduler) + + global best_acc1 # pylint: disable=global-statement + best_acc1 = args.best_acc1 + if args.hp.evaluate: + validate(val_loader, model, criterion, args) + return + + if args.hp.amp: + # Automatic mixed precision training + scaler = torch.cuda.amp.GradScaler() + else: + scaler = None + + if args.hp.quantization.quantize: + 
model.apply(torch.quantization.disable_observer) + model.apply(torch.quantization.disable_fake_quant) + acc1, acc5 = validate(val_loader, model, criterion, args) + print_on_master(args, f'validate after calibration: acc1: {acc1}, acc5: {acc5}') + model.apply(torch.quantization.enable_fake_quant) + + for epoch in range(args.start_epoch, args.hp.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + # train for one epoch + epoch_start_time = time.time() + train(train_loader, model, criterion, optimizer, epoch, args, writer, scaler, + teacher_model=teacher_model, distill_criterion=distill_criterion) + lr_scheduler.step() + epoch_total_time = time.time() - epoch_start_time + total_time_str = str(datetime.timedelta(seconds=int(epoch_total_time))) + + if args.hp.multi_gpu.rank in [-1, 0]: + print(f'Epoch[{epoch + 1}/{args.hp.epochs}] total time {total_time_str}') + with torch.no_grad(): + if args.hp.quantization.quantize: + if epoch >= args.hp.quantization.num_observer_update_epochs: + print('Disabling observer for subseq epochs, epoch = ', epoch) + model.apply(torch.quantization.disable_observer) + if epoch >= args.hp.quantization.num_batch_norm_update_epochs: + print('Freezing BN for subseq epochs, epoch = ', epoch) + model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) + # evaluate on validation set + if args.hp.quantization.quantize: + print('Evaluate QAT model') + acc1, acc5 = validate(val_loader, model, criterion, args) + if writer is not None: + writer.add_scalar('val/acc1', acc1, epoch) + writer.add_scalar('val/acc5', acc5, epoch) + writer.add_scalar('val/lr', optimizer.param_groups[0]['lr'], epoch) + # remember best acc@1 and save checkpoint + + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + if is_best: + print_on_master(args, f'** best at epoch: {epoch + 1}, acc1: {acc1}, acc5: {acc5}') + if writer is not None: + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model_without_ddp.state_dict(), + 
'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + }, is_best, prefix=f'{args.log_name}/{args.arch}') + with torch.no_grad(): + if args.hp.quantization.quantize: + save_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_by_qat_prepared.txt', + str(model_without_ddp)) + quantized_eval_model = copy.deepcopy(model_without_ddp) + quantized_eval_model.eval() + quantized_eval_model.to(torch.device('cpu')) + if args.hp.quantization.quantize_fx: + if args.hp.save_jit_trace: + bit_config = get_model_quantize_bit_config(quantized_eval_model) + save_yaml_file_on_master(args, f'{args.log_name}/{args.arch}_quantize_bit_config.yaml', + bit_config) + quantized_eval_model = convert_fx(quantized_eval_model) + if args.hp.save_jit_trace: + clip_model_weight_in_quant_min_max(quantized_eval_model, bit_config) + else: + torch.quantization.convert(quantized_eval_model, inplace=True) + save_file_on_master(args, f'{args.log_name}/{args.arch}_quantized_by_qat.txt', + str(quantized_eval_model)) + print('Evaluate Quantized model') + validate(val_loader, quantized_eval_model, criterion, args, device='cpu') + if is_best and args.hp.save_jit_trace: + save_file_on_master(args, f'{args.log_name}/{args.arch}_save_jit_log.txt', + f'best at epoch: {epoch+1}, acc1: {acc1}, acc5: {acc5}') + if args.hp.quantization.quantize: + save_jit_model = quantized_eval_model + filename = 'best_qat_quant.jit' + else: + save_jit_model = model + filename = 'best.jit' + save_torchscript_model(save_jit_model, val_loader, prefix=f'{args.log_name}/{args.arch}', + filename=filename) + del quantized_eval_model + if writer is not None: + writer.close() + + +if __name__ == '__main__': + main() diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt index 4fae1f6..9149d00 
100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8.prototxt @@ -41,7 +41,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt index c49f264..d6c389e 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight.prototxt @@ -42,7 +42,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt index e93f6cf..195a001 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_load_weight_resume.prototxt @@ -43,7 +43,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + 
backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt index cfa65ea..1486632 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor.prototxt @@ -41,7 +41,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt index 6c0e498..093bf70 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_distillation.prototxt @@ -46,7 +46,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt 
b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt index 9b8f6cf..3667da7 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_lsq_first_last_layer_int8_per_tensor_load_weight.prototxt @@ -42,7 +42,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt index a43f5d7..99e0f05 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt +++ b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_ptq.prototxt @@ -31,7 +31,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: true - backend: "fbgemm" + backend: TORCH_FBGEMM num_calibration_batches: 120 activation_quantization_observer { quantization_method: "quantization_error" diff --git a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt index 56a99b2..573a233 100644 --- a/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt +++ 
b/examples/classifier_imagenet/prototxt/efficientnet/efficientnetb0_quantization_qat_8bit_first_last_layer_int8_per_channel_consinelr_distillation_load_weight.prototxt @@ -47,7 +47,7 @@ quantization { quantize: true quantize_fx: true post_training_quantize: false - backend: "fbgemm" + backend: TORCH_FBGEMM num_observer_update_epochs: 4 num_batch_norm_update_epochs: 99999 activation_quantization_observer { diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt index 49703ec..02920fe 100644 --- a/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt +++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_autoslim_quantization_ptq.prototxt @@ -39,7 +39,7 @@ quantization { quantize: true quantize_fx: false post_training_quantize: true - backend: "fbgemm" + backend: TORCH_FBGEMM num_calibration_batches: 120 activation_quantization_observer { quantization_method: "quantization_error" diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt new file mode 100644 index 0000000..fbfa5dd --- /dev/null +++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_distillation.prototxt @@ -0,0 +1,45 @@ +main_file: "examples/classifier_imagenet/main_distillation.py" +arch: "resnet50" +model_source: TorchVision +log_name: "multi_gpu" +data: "/data/imagenet/imagenet-torch" +debug: false +lr: 0.0500000007450581 +epochs: 720 +batch_size: 128 +workers: 8 +print_freq: 50 +evaluate: false +pretrained: false +seed: 0 +gpu_id: ANY +multi_gpu { + world_size: 1 + rank: 0 + dist_url: "tcp://127.0.0.1:23457" + dist_backend: "nccl" + multiprocessing_distributed: true +} +val_batch_size: 256 +warmup { + lr_warmup_epochs: 5 + lr_warmup_decay: 0.009999999776482582 +} +lr_scheduler: CosineAnnealingLR +optimizer: SGD 
+sgd { + weight_decay: 9.999999747378752e-06 + momentum: 0.8999999761581421 +} +amp: true +distill { + teacher_model { + arch: "resnet50d" + source: Timm + } + kl_divergence { + temperature: 1.0 + reduction: "batchmean" + loss_weight: 0.699999988079071 + } +} diff --git a/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt b/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt new file mode 100644 index 0000000..4a76516 --- /dev/null +++ b/examples/classifier_imagenet/prototxt/resnet/resnet50_quantization_lsq_3w4a_first_last_layer_int8_per_tensor_distillation_load_weight_80.754.prototxt @@ -0,0 +1,68 @@ +main_file: "examples/classifier_imagenet/main_quantization_lsq_distillation.py" +arch: "resnet50" +model_source: TorchVision +log_name: "quantization_lsq_distillation_first_last_layer_int8_lr_0.001" +data: "/data/imagenet/imagenet-torch" +debug: false +lr: 0.0010000000474974513 +epochs: 50 +batch_size: 32 +workers: 8 +print_freq: 50 +evaluate: false +pretrained: false +seed: 0 +gpu_id: ANY +multi_gpu { + world_size: 1 + rank: 0 + dist_url: "tcp://127.0.0.1:23457" + dist_backend: "nccl" + multiprocessing_distributed: true +} +weight: "/root/work/model_optimizer_torch_dev_1.03/logger/resnet50_multi_gpu_202206211454_acc_80.754/resnet50best.pth.tar" +lr_scheduler: CosineAnnealingLR +optimizer: SGD +sgd { + weight_decay: 9.999999747378752e-05 + momentum: 0.8999999761581421 +} +distill { + teacher_model { + arch: "resnet50d" + source: Timm + } + kl_divergence { + temperature: 1.0 + reduction: "batchmean" + loss_weight: 0.699999988079071 + } +} +quantization { + quantize: true + quantize_fx: true + post_training_quantize: false + backend: TORCH_FBGEMM + num_observer_update_epochs: 4 + num_batch_norm_update_epochs: 99999 + activation_quantization_observer { + quantization_method: 
"minmax" + per_channel: false + symmetric: true + reduce_range: false + dtype: "quint8" + nbits: 4 + fake_method: "lsq" + layers_restrict_to_8bit: "conv1,fc" + } + weight_quantization_observer { + quantization_method: "minmax" + per_channel: false + symmetric: true + reduce_range: false + dtype: "qint8" + nbits: 3 + fake_method: "lsq" + layers_restrict_to_8bit: "conv1,fc" + } +} \ No newline at end of file diff --git a/run_cli.sh b/run_cli.sh old mode 100644 new mode 100755 diff --git a/src/model_optimizer/proto/model_optimizer_torch.proto b/src/model_optimizer/proto/model_optimizer_torch.proto index d8fb55d..77a5371 100644 --- a/src/model_optimizer/proto/model_optimizer_torch.proto +++ b/src/model_optimizer/proto/model_optimizer_torch.proto @@ -292,7 +292,7 @@ message ExponentialLRParam{ message CyclicLRParam { - // 在使用时,原始的CLR是按照 batch iteration 的更新 lr 的,本项目中为了和之前的几个LR统一, + // 在使用时,原始的 CLR 是按照 batch iteration 的更新 lr 的,本项目中为了和之前的几个 LR 统一, // 使用了 epoch iteration 来进行更新 lr // base_lr (float or list): Initial learning rate which is the // lower boundary in the cycle for each parameter group. diff --git a/src/model_optimizer/utils/__init__.py b/src/model_optimizer/utils/__init__.py index 951ed92..4d36e25 100644 --- a/src/model_optimizer/utils/__init__.py +++ b/src/model_optimizer/utils/__init__.py @@ -9,6 +9,7 @@ import copy import yaml import torch +import torch.fx from torch import nn