本文是基于成对低光图像学习自适应先验方案的代码讲解,文章讲解可看链接PairLLE。
1、原文概要
本文PairLIE 是一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,与一般方法的区别如下图所示:

常规的方法是(a),本文提出的方法是(b),可以看到本文的方法需要2张图来优化。
2、代码结构
代码整体结构如下:

核心代码模块包含模型结构、数据加载、训练流程3部分。
3 、核心代码模块
1. 模型结构
模型包含用于完成降噪恒等映射的N_net,以及预测反射图和光照图的R_net和L_net,代码在net/net.py中。
class L_net(nn.Module):
def __init__(self, num=64):
super(L_net, self).__init__()
self.L_net = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(3, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, 1, 3, 1, 0),
)
def forward(self, input):
return torch.sigmoid(self.L_net(input))
class R_net(nn.Module):
def __init__(self, num=64):
super(R_net, self).__init__()
self.R_net = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(3, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, 3, 3, 1, 0),
)
def forward(self, input):
return torch.sigmoid(self.R_net(input))
class N_net(nn.Module):
def __init__(self, num=64):
super(N_net, self).__init__()
self.N_net = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(3, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, num, 3, 1, 0),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(num, 3, 3, 1, 0),
)
def forward(self, input):
return torch.sigmoid(self.N_net(input))
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.L_net = L_net(num=64)
self.R_net = R_net(num=64)
self.N_net = N_net(num=64)
def forward(self, input):
x = self.N_net(input)
L = self.L_net(x)
R = self.R_net(x)
return L, R, x
模型结构比较简单是几个卷积+relu的组合,以上为训练模型结构。推理时,会对L图进行gamma增强后与反射图处理,结构如下所示:
L, R, X = model(input)
I = torch.pow(L,0.2) * R # default=0.2, LOL=0.14.
此与论文给出的流程图对应。
2. 数据加载
由于该篇论文选用的是多曝光的成对数据,因此它只需要加载某一个文件夹中的不同曝光数据即可完成训练,如dataset.py文件所示,:
class DatasetFromFolder(data.Dataset):
def __init__(self, data_dir, transform=None):
super(DatasetFromFolder, self).__init__()
self.data_dir = data_dir
self.transform = transform
def __getitem__(self, index):
index = index
data_filenames = [join(join(self.data_dir, str(index+1)), x) for x in listdir(join(self.data_dir, str(index+1))) if is_image_file(x)]
num = len(data_filenames)
index1 = random.randint(1,num)
index2 = random.randint(1,num)
while abs(index1 - index2) == 0:
index2 = random.randint(1,num)
im1 = load_img(data_filenames[index1-1])
im2 = load_img(data_filenames[index2-1])
_, file1 = os.path.split(data_filenames[index1-1])
_, file2 = os.path.split(data_filenames[index2-1])
seed = np.random.randint(123456789) # make a seed with numpy generator
if self.transform:
random.seed(seed) # apply this seed to img tranfsorms
torch.manual_seed(seed) # needed for torchvision 0.7
im1 = self.transform(im1)
random.seed(seed)
torch.manual_seed(seed)
im2 = self.transform(im2)
return im1, im2, file1, file2
def __len__(self):
return 324 # for custom datasets, please check the dataset size and modify this number
通过在提前准备好的文件夹中选出2个不同的文件,完成多曝光数据的准备。
3. 训练流程
位于main.py文件中,完成了R正则损失和Retinex损失的计算。
def train():
model.train()
loss_print = 0
for iteration, batch in enumerate(training_data_loader, 1):
im1, im2, file1, file2 = batch[0], batch[1], batch[2], batch[3]
im1 = im1.cuda()
im2 = im2.cuda()
L1, R1, X1 = model(im1)
L2, R2, X2 = model(im2)
loss1 = C_loss(R1, R2)
loss2 = R_loss(L1, R1, im1, X1)
loss3 = P_loss(im1, X1)
loss = loss1 * 1 + loss2 * 1 + loss3 * 500
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_print = loss_print + loss.item()
if iteration % 10 == 0:
print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={}.".format(epoch,
iteration, len(training_data_loader), loss_print, optimizer.param_groups[0]['lr']))
可以看到损失分为3个部分:
- loss1代表的是R正则,两个不同曝光的图像它们的R图一样。
- loss2代表的是Retinex假设损失,分解后的结果需要满足假设。
- loss3是一个降噪后的保真度损失。
其中所有损失的具体计算在util.py中。
def gradient(img):
height = img.size(2)
width = img.size(3)
gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()
gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()
return gradient_h, gradient_w
def tv_loss(illumination):
gradient_illu_h, gradient_illu_w = gradient(illumination)
loss_h = gradient_illu_h
loss_w = gradient_illu_w
loss = loss_h.mean() + loss_w.mean()
return loss
def C_loss(R1, R2):
loss = torch.nn.MSELoss()(R1, R2)
return loss
def R_loss(L1, R1, im1, X1):
max_rgb1, _ = torch.max(im1, 1)
max_rgb1 = max_rgb1.unsqueeze(1)
loss1 = torch.nn.MSELoss()(L1*R1, X1) + torch.nn.MSELoss()(R1, X1/L1.detach())
loss2 = torch.nn.MSELoss()(L1, max_rgb1) + tv_loss(L1)
return loss1 + loss2
def P_loss(im1, X1):
loss = torch.nn.MSELoss()(im1, X1)
return loss
与讲解中公式对应。
3、总结
代码实现核心的部分讲解完毕,本文实现了一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,在简单的网络结构下实际了降噪增强的效果。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

1529

被折叠的 条评论
为什么被折叠?



