
How to export an ONNX model with save_memory=True? #18

@BearCooike


We are trying to convert RevCol to TensorRT via ONNX, but the ONNX export fails when the model is built with save_memory=True.
Here is our conversion test code:

import torch
from models.revcol import *  # provides revcol_tiny

model = revcol_tiny(save_memory=True, inter_supv=False, drop_path=0.1, num_classes=10, kernel_size=3)

# Workaround: turning save_memory off on every subnet makes the export succeed.
# Remove this loop to reproduce the error below.
for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False

x = torch.zeros(1, 3, 224, 224)
torch.onnx.export(model, x, './weights/revcol_tiny.onnx', verbose=False, opset_version=17,
                  training=torch.onnx.TrainingMode.EVAL,
                  do_constant_folding=True,
                  input_names=['images'],
                  output_names=['output'],
                  dynamic_axes=None)

With save_memory=True left enabled on the subnets (i.e. without the loop above), the export fails with the following error:

File d:\SoftWare\anaconda3\envs\torch\lib\site-packages\torch\onnx\utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
...
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

RuntimeError: invalid unordered_map<K, T> key

If we add the following code, the export works, but we would then presumably lose the low memory footprint of the reversible network:

for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False
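For context on the memory concern, here is a minimal sketch of why reversible networks save memory. It uses a generic RevNet-style additive coupling with hypothetical scalar functions `f` and `g`; this is not RevCol's exact multi-column formulation or its `save_memory` code path. Because a block's inputs can be recomputed from its outputs, the block does not need to cache activations for the backward pass. Assuming RevCol's `save_memory` works this way, the saving applies to training. At inference time (and during ONNX export) no activations are kept for backpropagation anyway, so exporting with `save_memory=False` may cost little or no inference memory. That assumption is worth confirming by measuring peak memory on the actual model.

```python
import math
from typing import Tuple


def f(x: float) -> float:
    # Hypothetical residual branch; stands in for a conv block and need not be invertible.
    return 0.5 * math.tanh(x)


def g(x: float) -> float:
    # Hypothetical second residual branch.
    return 0.1 * x * x


def rev_forward(x1: float, x2: float) -> Tuple[float, float]:
    # Generic RevNet-style additive coupling (not RevCol's exact formula).
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def rev_inverse(y1: float, y2: float) -> Tuple[float, float]:
    # Reconstruct the inputs from the outputs alone, so they need not be
    # stored during the forward pass and can be recomputed during backward.
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


x1, x2 = 1.5, -0.3
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)
# The recovered inputs match the originals up to floating-point rounding.
print(math.isclose(r1, x1), math.isclose(r2, x2))
```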

Is there a way to export the model to ONNX while keeping save_memory=True?
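One way to keep the workaround confined to export time is a small context manager. It switches `save_memory` off on every subnet and restores the original flags afterwards, even if the export raises. This is a minimal sketch: the helper name `save_memory_disabled` is hypothetical, and it assumes only what the snippets in this issue already use, namely a `num_subnet` attribute and submodules named `subnet0`, `subnet1`, ..., each with a `save_memory` attribute. It does not make the `save_memory=True` path itself exportable. It only scopes the workaround so the same model object keeps `save_memory=True` for training.

```python
from contextlib import contextmanager


@contextmanager
def save_memory_disabled(model):
    """Temporarily set save_memory=False on every subnet of a RevCol-style model.

    Hypothetical helper: assumes `model.num_subnet` and attributes
    `subnet0`, `subnet1`, ... that each expose a `save_memory` flag.
    """
    subnets = [getattr(model, f"subnet{i}") for i in range(model.num_subnet)]
    previous = [s.save_memory for s in subnets]
    try:
        for s in subnets:
            s.save_memory = False
        yield model
    finally:
        # Restore the original flags even if the body (e.g. the export) raised.
        for s, flag in zip(subnets, previous):
            s.save_memory = flag


# Usage (torch and the RevCol model are assumed to be set up as in the issue):
# with save_memory_disabled(model):
#     torch.onnx.export(model, x, './weights/revcol_tiny.onnx', opset_version=17)
```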
