-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexport.py
More file actions
48 lines (34 loc) · 1.2 KB
/
export.py
File metadata and controls
48 lines (34 loc) · 1.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sys
import argparse
import torch
from storage.util import load
def main():
    """Export a trained model checkpoint to an ONNX file.

    Command-line arguments:
        --model-path / -p  Path to the trained model (required).
        --output / -o      Destination path. Defaults to ``<model_path>.onnx``;
                           a ``.onnx`` suffix is appended when missing.
        --opset            ONNX opset version passed to ``torch.onnx.export``
                           (default: 11, the previously hard-coded value).
    """
    parser = argparse.ArgumentParser(
        description="Utility to export model to ONNX")
    parser.add_argument("--model-path", '-p', type=str, required=True,
                        help="path to the trained model")
    parser.add_argument("--output", '-o', type=str, default=None,
                        help="output path, default is <model_path>.onnx")
    parser.add_argument("--opset", type=int, default=11,
                        help="ONNX opset version to target (default: 11)")
    args = parser.parse_args()

    # `load` yields (arch, model, class_names); class names are not needed to
    # produce the ONNX graph.
    arch, model, _class_names = load(args.model_path, inference=True,
                                     device='cpu')
    # Put the model in inference mode before it is traced.
    model.eval()

    if args.output is None:
        output_path = args.model_path + ".onnx"
    else:
        output_path = args.output
        if not output_path.endswith(".onnx"):
            output_path += ".onnx"

    # Example input used by the exporter to trace the graph: batch of 1,
    # 3 leading feature planes, square side of `arch.image_size`
    # (presumably NCHW RGB images -- confirm against the architecture).
    dummy_input = torch.randn(1, 3, arch.image_size, arch.image_size,
                              dtype=torch.float32)
    # Cast the weights to float32 so they match the dummy input's dtype.
    model.to(dtype=torch.float32)
    torch.onnx.export(model, dummy_input, output_path,
                      input_names=arch.input_names(),
                      output_names=arch.output_names(),
                      opset_version=args.opset,
                      do_constant_folding=True,
                      # NOTE(review): presumably kept for compatibility with
                      # older ONNX consumers -- confirm it is still required.
                      keep_initializers_as_inputs=True)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Exit without a traceback on Ctrl-C, but with the conventional
        # 128 + SIGINT(2) status so shell scripts / CI do not mistake an
        # aborted export for a successful one (status 0).
        sys.exit(130)