import torch
import torch.nn as nn
import torchvision.transforms as transforms

from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.tracer import import_patcher
with import_patcher():
from transformers import ViTForImageClassification
# Preprocessing for an ImageNet-pretrained ViT: resize to the model's fixed
# 224x224 input, convert PIL -> tensor, then normalize with the standard
# ImageNet channel statistics.
device = 'cuda'
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
class ViTWrapper(nn.Module):
    """Adapter exposing a HuggingFace ViT classifier as a plain
    tensor-in / tensor-out module.

    torch.jit.trace (used by the TFLite converter) requires Tensor
    outputs, while HuggingFace models return an output object; this
    wrapper unwraps that object to its ``logits`` tensor.
    """

    def __init__(self, vit):
        super().__init__()
        # The underlying ViTForImageClassification instance.
        self.vit = vit

    def forward(self, x):
        outputs = self.vit(x)
        # Return only the raw classification logits tensor.
        return outputs.logits
# Hub identifier for the pretrained ViT-Base/16 classifier (ImageNet-1k head).
model_url = 'google/vit-base-patch16-224'
# import_patcher() keeps tinynn's tracer compatible with transformers'
# import machinery while the model is constructed.
with import_patcher():
    model = ViTForImageClassification.from_pretrained(model_url)
# Wrap so traced calls take/return plain tensors instead of HF output objects.
Vit = ViTWrapper(model)
################ QAT SET #######################
# Fixed tracing input; ViT-Base/16 expects (N, 3, 224, 224).
dummy = torch.rand([1, 3, 224, 224])
# Channel-averaged ImageNet stats rescaled to the 0-255 domain, consumed by
# the converter via 'quantized_input_stats' to fold input normalization into
# the quantized graph.
# NOTE(review): collapsing the per-channel mean/std into one (mean, std) pair
# loses per-channel accuracy -- confirm this matches how inputs are fed at
# inference time.
mean = (0.485 + 0.456 + 0.406) / 3 * 255
std = (0.229 + 0.224 + 0.225) / 3 * 255
config = {
    'asymmetric': True,
    'per_tensor': False,
    'backend': 'fbgemm',
    'rewrite_graph': False,
    'quantized_input_stats': [(float(mean), float(std))],
}
quantizer = QATQuantizer(Vit, dummy, work_dir='out', config=config)
# quantize() already rewrites the model and installs fake-quant/observer
# modules. Do NOT additionally set a qconfig and call
# torch.quantization.prepare_qat() on the result: that double-prepares the
# modules and breaks quantizer.convert() later on.
Vit = quantizer.quantize()
Vit.cuda()
#####Skip training Step#####
##### Convert tflite Step#####
# The original `if epoch % 1 == 0:` guard is always true, and `epoch` is
# undefined once the training loop is skipped (NameError) -- convert
# unconditionally instead.
with torch.no_grad():
    device = 'cpu'
    # Quantized conversion and TFLite export must run on CPU in eval mode.
    Vit.eval()
    Vit.to(device)
    # Fold the QAT fake-quant model into a truly quantized torch model.
    Vit = quantizer.convert(Vit)
    # Select the quantized kernel backend tinynn chose (e.g. 'fbgemm').
    torch.backends.quantized.engine = quantizer.backend
    converter = TFLiteConverter(
        Vit,
        dummy,
        tflite_path='./qat_model.tflite',
        fuse_quant_dequant=True,
    )
    converter.convert()
Traceback (most recent call last):
File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 100, in <module>
converter.convert()
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 512, in convert
self.init_jit_graph()
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 237, in init_jit_graph
script = torch.jit.trace(self.model, self.dummy_input)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1000, in trace
traced_func = _trace_impl(
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 695, in _trace_impl
return trace_module(
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1275, in trace_module
module._c._create_method_from_trace(
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
result = self.forward(*input, **kwargs)
File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 39, in forward
return self.vit(x).logits
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
result = self.forward(*input, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 789, in forward
outputs = self.vit(
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
result = self.forward(*input, **kwargs)
File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 571, in forward
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
AttributeError: 'function' object has no attribute 'dtype'
I can train the ViT model from the Hugging Face transformers library,
but when converting it to a TFLite model it raises an error that I can't solve.
The following are the tinynn setting and the error message
Transformers version is 4.26.0
The error message: