pytorch保存onnx模型
因为一些原因,需要用pytorch去创建、训练和保存模型。pytorch保存的模型通常为pth、pt、pkl的格式,但这种类型的模型不能在其他框架(tensorflow)下直接加载,因此需要将模型保存为其他格式的。在网上进行相应的学习后,总结出一下两点:
- pytorch可以直接将模型保存为onnx的,并且可以通过onnx转换为其他格式的模型(pb);
- pytorch也可以直接将模型保存为caffemodel,但是需要一定的代码量去实现;
环境配置:
| 操作系统 | ubuntu18.04 |
|---|---|
| GPU型号 | GTX960 |
| cuda版本 | 10.0 |
| pytorch版本 | 1.2.0 |
实现代码:
import torch
import torch.onnx
from torch.autograd import Variable
x = Variable(torch.randn(1, 3, 32, 32)).cuda()
torch_out = torch.onnx.export(model, x,
"test.onnx",
export_params=True,
verbose=True)
API export说明:
export(model, args, f, export_params=True, verbose=False, training=False,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None)
参数说明:
Export a model into ONNX format. This exporter runs your model
once in order to get a trace of its execution to be exported;
at the moment, it supports a limited set of dynamic models (e.g., RNNs.)
See also: :ref:`onnx-export`
Arguments:
model (torch.nn.Module): the model to be exported.
args

本文详细介绍如何使用PyTorch将模型保存为ONNX格式,包括环境配置、转换代码及API参数说明,适用于跨框架模型迁移。

2万+

被折叠的 条评论
为什么被折叠?



