PyTorch的torch.onnx.export函数允许导出包含多个输出的模型,只需确保模型的forward方法返回一个元组或列表,包含所有输出。
1. 示例代码
import torch
import torch.nn as nn
# 钩子函数,用于打印每一层的输出尺寸
def hook_fn(module, input, output):
# 过滤掉 nn.Sequential 层
if not isinstance(output, tuple):
print(f"{
module.__class__.__name__} : {
output.shape}")
# 定义一个简单的模型,具有多个输出
class MultiOutputModel(nn.Module):
def __init__(self):
super(MultiOutputModel, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3,


2万+

被折叠的 条评论
为什么被折叠?



