nn.Module是在pytorch使用非常广泛的类,搭建网络基本都需要用到这个。
当我们搭建自己的网络时,可以继承官方写好的nn.Module模块,为什么要用这个呢?好处如下:
nn.Module作用
1.可以提供一些现成的基本模块比如:
Linear、ReLU、Sigmoid、Conv2d、Dropout
不用自己一个一个的写这些函数了,这也是为什么我们用框架的原因之一吧。
2. 容器
比如我们经常用到的 nn.Sequential(),顾名思义,将网络模块封装在一个容器中,可以方面网络搭建
如下面一个例子:
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(1*14*14, 10))
def forward(self, x):
return self.net(x)
3.参数管理
参数名字可以自动生成(想想如果自己去命名,百万参数的网络没法搭建),然后这些参数都可以传到优化器里面去优化
4. 所有modules的节点 孩子节点都是直系的
class BasicNet(nn.Module):
def __init__(self):
super(BasicNet, self).__init__()
self.net = nn.Linear(4, 3)
def forward(self, x):
return self.net(x)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(BasicNet(),
nn.ReLU(),
nn.Linear(3, 2))
def forward(self, x):
return self.net(x)
比如上面的代码,我们可以看出Net网络中有5个孩子节点:nn.Sequential,BasicNet, nn.ReLU,nn.Linear,BasicNet里面的nn.Linear
5.to(device)
nn.Module还有一个功能是将某个网络所有成员、函数、操作都搬移到GPU上面。
采用代码如下:
device = torch.device('cuda')
net = Net()
net.to(device)
上面device代表当前的设备是GPU还是CPU,需要注意的是为什么我们不写
net = net.to(device)
其实效果是一样的,采用nn.Module模块,net加上.to(device),还是net。如果是变量则不是一样的,即如果对于tensor bias,那么bias和bias.to(device)不是一样的,则需要重新命名。
6.保存和加载模型
可以方面我们保存和加载模型
加载模型:
net.load_state_dict(torch.load('ckpt.mdl'))
保存模型:
torch.save(net.state_dict(), 'ckpt.mdl')
7.训练/测试
方便训练和测试进行切换,为什么?因为网络中Dropout和BN在训练和测试是不一样的,需要切换
如果不切换效果就会很差,这个是容易犯的一个错误。
net.train()
net.eval()
8.实现自己的类
官方给的模块还是基础操作的,如果自己要搭建复杂的操作也容易实现,一个典型的例子就是可以自己设计一个新的损失函数。
下面给出将tensor压平的例子(nn.Module没有这个操作):
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, input):
return input.view(input.size(0), -1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2, 2),
Flatten(), #自己定义的
nn.Linear(1*14*14, 10))
def forward(self, x):
return self.net(x)
Flatten压平的操作则是我们自己构建的类,可以方便后续BasicNet类使用,注意nn.Sequential里面必须是类。
且在上面例子中Flatten不需要接任何参数。
8.1举一个自己写的线性层的例子
class MyLinear(nn.Module):
def __init__(self, inp, outp):
super(MyLinear, self).__init__()
# requires_grad = True
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
def forward(self, x):
x = x @ self.w.t() + self.b
return x
在上面自己写的线性层
y
=
w
x
+
b
y=wx+b
y=wx+b,可以看出
w
w
w和
b
b
b必须要使用nn.Parameter这个模块。原因是只用加上了nn.Parameter后,
w
w
w和
b
b
b才可以用优化器SGD等进行优化。
如果不写nn.Parameter那么则需要写requires_grad = True,还要自己写优化器,就很麻烦。用了Parameter可以方便我们优化网络:
model = MyLinear.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()

本文详细介绍了PyTorch中nn.Module类的重要作用,包括提供现成的模块,如Linear和Conv2d,作为容器组织网络,管理参数,方便在GPU上运行,以及如何保存和加载模型。nn.Module还允许在训练和测试模式间切换,并支持自定义模块,如创建自定义线性层,便于网络优化和复杂操作的实现。


被折叠的 条评论
为什么被折叠?



