1. 为什么我们需要ONNX-Go?从一次真实的项目实训说起
去年带山东大学软件学院21级的同学做项目实训,有个场景我印象特别深。一个小组想用Go语言写一个微服务,核心功能是图片分类。他们用Python的PyTorch训练了一个挺不错的ResNet模型,准确率很高,但在集成时卡住了。Go程序怎么调用这个.pt文件?难道要在Go里再实现一遍神经网络的前向传播?或者用CGO去调Python解释器?前者工程浩大,后者部署和维护简直是噩梦,依赖复杂,性能也堪忧。
就在大家一筹莫展的时候,我们发现了ONNX和ONNX-Go这个组合拳,一下子把路走通了。简单来说,ONNX就像一个“深度学习模型的通用翻译官”。无论你的模型来自PyTorch、TensorFlow还是其他主流框架,都可以先把它“翻译”成ONNX这个标准格式。然后,ONNX-Go这个库,就是Go语言世界里专门读取和执行这个“标准格式模型”的引擎。
这解决了我们的大问题:训练和部署的语言解耦。你可以继续用Python生态里丰富的工具和库(如PyTorch, TensorFlow)高效地训练和调试模型,享受其灵活性。一旦模型定型,就把它导出为ONNX文件。在部署侧,你可以利用Go语言的优势——编译成单一可执行文件、部署简单、并发性能好、内存开销相对可控——来构建你的Web服务、命令行工具或嵌入式应用。ONNX-Go就是连接这两个世界的稳固桥梁。
所以,如果你也面临类似场景:用Go写应用,但又需要深度学习的智能能力,那么从模型训练到ONNX转换,再到用ONNX-Go集成部署,这条路径非常值得一试。它不是什么纸上谈兵的理论,而是我们实训中验证过的、能跑通生产流程的实战方案。
2. 第一步:训练你的模型并导出为ONNX格式
模型训练本身是个大学问,这里我们假设你已经用PyTorch完成了一个简单的图像分类模型训练。我们的重点,是如何把它“打包”成ONNX这个通用格式。
2.1 一个简单的PyTorch模型示例
为了后续演示清晰,我们构建一个极简的模型。这个模型可能不实用,但能让你看清每一步在做什么。
import torch
import torch.nn as nn
# 定义一个非常简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.fc1 = nn.Linear(32 * 8 * 8, 128) # 假设输入是32x32的图片,经过两次池化后是8x8
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 32 * 8 * 8)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 实例化模型并设置为评估模式
model = SimpleCNN(num_classes=10)
model.eval() # 这是关键!导出ONNX前必须设为eval模式
注意:model.eval() 这一步至关重要。它会让模型中的某些层(如Dropout、BatchNorm)切换到推理模式,确保导出的ONNX模型行为与推理时一致。很多同学导出后精度不对,第一步就先检查这里。
2.2 执行ONNX导出:关键参数详解
导出ONNX的核心函数是 torch.onnx.export。它有几个参数容易让人困惑,我结合踩过的坑给你讲清楚。
import torch.onnx
# 1. 创建一个符合模型输入要求的虚拟输入(dummy input)
# 这个张量的形状必须和你的模型forward函数期望的完全一致
# (batch_size, channels, height, width)
dummy_input = torch.randn(1, 3, 32, 32)
# 2. 指定导出的ONNX文件名
onnx_model_path = "simple_cnn.onnx"
# 3. 执行导出
torch.onnx.export(
model, # 要导出的模型
dummy_input, # 模型输入示例
onnx_model_path, # 输出文件路径
export_params=True, # 是否将模型参数(权重)一起导出。必须为True!
opset_version=13, # ONNX算子集版本。建议用11以上,兼容性更好。
do_constant_folding=True, # 是否进行常量折叠优化。建议True,可以简化计算图。
input_names=['input'], # 输入节点的名称,后续在Go中会用到
output_names=['output'], # 输出节点的名称
dynamic_axes={ # **处理动态形状的关键!**
'input': {0: 'batch_size'}, # 第0维(批大小)是动态的
'output': {0: 'batch_size'}
}
)
print(f"模型已成功导出至: {onnx_model_path}")
这里我特别想强调 dynamic_axes 参数。在实训初期,我们导出的模型只能处理固定 batch_size=1 的输入,这在实际的Web服务中很不灵活。通过设置 dynamic_axes,我们告诉ONNX:“模型的第一个维度(批大小)是可变的”。这样导出的模型,在Go端就能灵活处理单张图片或一批图片的推理了。这是让模型具备实用性的关键一步。
2.3 验证导出的ONNX模型
导出完别急着走,先用Python快速验证一下模型是否正确。这能帮你提前发现大部分问题,避免在Go端调试时抓瞎。
import onnx
import onnxruntime as ort
import numpy as np
# 1. 检查模型格式是否有效
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
print("ONNX模型格式检查通过!")
# 2. 使用ONNX Runtime进行推理,验证结果一致性
ort_session = ort.InferenceSession(onnx_model_path)
# 准备与导出时相同格式的输入
test_input = np.random.randn(1, 3, 32, 32).astype(np.float32)
# 使用PyTorch原模型推理(作为基准)
with torch.no_grad():
torch_output = model(torch.from_numpy(test_input


96

被折叠的 条评论
为什么被折叠?



