比如说加载训练好的cnn模型,我把下面这段复制到新的模型中,再加载就没有报错了,不知道对不对。
class CNN(nn.Module):
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
tsne_feature = x
x = self.flatten(x)
return self.linear(x), tsne_feature
cnn=torch.load('model.pkl')
Net类继承nn.Module,super(Net, self).__init__()就是对继承自父类nn.Module的属性进行初始
化。而且是用nn.Module的初始化方法来初始化继承的属性。
另外记录保存模型和加载模型的代码:
torch.save(cnn, 'model.pkl') # 保存整个模型
torch.save(cnn.state_dict(), 'model_params.pth') # 只保存网络中的参数
new_model = torch.load('./data/model.pkl') # 加载模型
torch.load_state_dict(torch.load('model_params.pth'))#提取所有的参数
文章讨论了如何在PyTorch中加载和保存卷积神经网络(CNN)模型。通过创建一个名为CNN的类并定义forward函数,然后使用torch.load加载模型。同时,介绍了保存模型参数的方法,以及如何通过state_dict()加载这些参数。

2650

被折叠的 条评论
为什么被折叠?



