pytorch保存onnx模型

本文详细介绍如何使用PyTorch将模型保存为ONNX格式,包括环境配置、转换代码及API参数说明,适用于跨框架模型迁移。

pytorch保存onnx模型


因为一些原因,需要用pytorch去创建、训练和保存模型。pytorch保存的模型通常为pth、pt、pkl的格式,但这种类型的模型不能在其他框架(tensorflow)下直接加载,因此需要将模型保存为其他格式的。在网上进行相应的学习后,总结出一下两点:

  1. pytorch可以直接将模型保存为onnx的,并且可以通过onnx转换为其他格式的模型(pb);
  2. pytorch也可以直接将模型保存为caffemodel,但是需要一定的代码量去实现;

环境配置

操作系统 ubuntu18.04
GPU型号 GTX960
cuda版本 10.0
pytorch版本 1.2.0

实现代码

import torch
import torch.onnx
from torch.autograd import Variable

x = Variable(torch.randn(1, 3, 32, 32)).cuda()
torch_out = torch.onnx.export(model, x, 
                              "test.onnx",
                               export_params=True,
                               verbose=True)

API export说明

export(model, args, f, export_params=True, verbose=False, training=False,
           input_names=None, output_names=None, aten=False, export_raw_ir=False,
           operator_export_type=None, opset_version=None, _retain_param_name=True,
           do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None)

参数说明:

Export a model into ONNX format.  This exporter runs your model
    once in order to get a trace of its execution to be exported;
    at the moment, it supports a limited set of dynamic models (e.g., RNNs.)
    See also: :ref:`onnx-export`
    Arguments:
        model (torch.nn.Module): the model to be exported.
        args 
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值