【LLIE技术专题】基于成对低光图像学习自适应先验方案代码讲解

该文章已生成可运行项目,

本文是基于成对低光图像学习自适应先验方案的代码讲解,文章讲解可看链接PairLLE

1、原文概要

本文PairLIE 是一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,与一般方法的区别如下图所示:
在这里插入图片描述
常规的方法是(a),本文提出的方法是(b),可以看到本文的方法需要2张图来优化。

2、代码结构

代码整体结构如下
在这里插入图片描述

核心代码模块包含模型结构、数据加载、训练流程3部分。

3 、核心代码模块

1. 模型结构

模型包含用于完成降噪恒等映射的N_net,以及预测反射图和光照图的R_net和L_net,代码在net/net.py中。


class L_net(nn.Module):
    def __init__(self, num=64):
        super(L_net, self).__init__()
        self.L_net = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(), 
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),   
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, 1, 3, 1, 0),
        )

    def forward(self, input):
        return torch.sigmoid(self.L_net(input))


class R_net(nn.Module):
    def __init__(self, num=64):
        super(R_net, self).__init__()

        self.R_net = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, num, 3, 1, 0),
            nn.ReLU(), 
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),            
            nn.ReLU(),   
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, 3, 3, 1, 0),
        )

    def forward(self, input):
        return torch.sigmoid(self.R_net(input))

class N_net(nn.Module):
    def __init__(self, num=64):
        super(N_net, self).__init__()
        self.N_net = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, num, 3, 1, 0),
            nn.ReLU(), 
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),
            nn.ReLU(),               
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, num, 3, 1, 0),            
            nn.ReLU(),   
            nn.ReflectionPad2d(1),
            nn.Conv2d(num, 3, 3, 1, 0),
        )

    def forward(self, input):
        return torch.sigmoid(self.N_net(input))


class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()        
        self.L_net = L_net(num=64)
        self.R_net = R_net(num=64)
        self.N_net = N_net(num=64)        

    def forward(self, input):
        x = self.N_net(input)
        L = self.L_net(x)
        R = self.R_net(x)
        return L, R, x

模型结构比较简单是几个卷积+relu的组合,以上为训练模型结构。推理时,会对L图进行gamma增强后与反射图处理,结构如下所示:

L, R, X = model(input)    
I = torch.pow(L,0.2) * R  # default=0.2, LOL=0.14.

此与论文给出的流程图对应。

2. 数据加载

由于该篇论文选用的是多曝光的成对数据,因此它只需要加载某一个文件夹中的不同曝光数据即可完成训练,如dataset.py文件所示,:

class DatasetFromFolder(data.Dataset):
    def __init__(self, data_dir, transform=None):
        super(DatasetFromFolder, self).__init__()
        self.data_dir = data_dir
        self.transform = transform

    def __getitem__(self, index):
        index = index
        data_filenames = [join(join(self.data_dir, str(index+1)), x) for x in listdir(join(self.data_dir, str(index+1))) if is_image_file(x)]
        num = len(data_filenames)
        index1 = random.randint(1,num)
        index2 = random.randint(1,num)
        while abs(index1 - index2) == 0:
            index2 = random.randint(1,num)

        im1 = load_img(data_filenames[index1-1])
        im2 = load_img(data_filenames[index2-1])

        _, file1 = os.path.split(data_filenames[index1-1])
        _, file2 = os.path.split(data_filenames[index2-1])

        seed = np.random.randint(123456789) # make a seed with numpy generator 
        if self.transform:
            random.seed(seed) # apply this seed to img tranfsorms
            torch.manual_seed(seed) # needed for torchvision 0.7
            im1 = self.transform(im1)
            random.seed(seed)
            torch.manual_seed(seed)         
            im2 = self.transform(im2)        
        return im1, im2, file1, file2

    def __len__(self):
        return 324 # for custom datasets, please check the dataset size and modify this number

通过在提前准备好的文件夹中选出2个不同的文件,完成多曝光数据的准备。

3. 训练流程

位于main.py文件中,完成了R正则损失和Retinex损失的计算。

def train():
    model.train()
    loss_print = 0
    for iteration, batch in enumerate(training_data_loader, 1):

        im1, im2, file1, file2 = batch[0], batch[1], batch[2], batch[3]
        im1 = im1.cuda()
        im2 = im2.cuda()
        L1, R1, X1 = model(im1)
        L2, R2, X2 = model(im2)   
        loss1 = C_loss(R1, R2)
        loss2 = R_loss(L1, R1, im1, X1)
        loss3 = P_loss(im1, X1)
        loss =  loss1 * 1 + loss2 * 1 + loss3 * 500

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_print = loss_print + loss.item()
        if iteration % 10 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={}.".format(epoch,
                iteration, len(training_data_loader), loss_print, optimizer.param_groups[0]['lr']))

可以看到损失分为3个部分:

  1. loss1代表的是R正则,两个不同曝光的图像它们的R图一样。
  2. loss2代表的是Retinex假设损失,分解后的结果需要满足假设。
  3. loss3是一个降噪后的保真度损失。

其中所有损失的具体计算在util.py中。

def gradient(img):
    height = img.size(2)
    width = img.size(3)
    gradient_h = (img[:,:,2:,:]-img[:,:,:height-2,:]).abs()
    gradient_w = (img[:, :, :, 2:] - img[:, :, :, :width-2]).abs()
    return gradient_h, gradient_w

def tv_loss(illumination):
    gradient_illu_h, gradient_illu_w = gradient(illumination)
    loss_h = gradient_illu_h
    loss_w = gradient_illu_w
    loss = loss_h.mean() + loss_w.mean()
    return loss

def C_loss(R1, R2):
    loss = torch.nn.MSELoss()(R1, R2) 
    return loss

def R_loss(L1, R1, im1, X1):
    max_rgb1, _ = torch.max(im1, 1)
    max_rgb1 = max_rgb1.unsqueeze(1) 
    loss1 = torch.nn.MSELoss()(L1*R1, X1) + torch.nn.MSELoss()(R1, X1/L1.detach())
    loss2 = torch.nn.MSELoss()(L1, max_rgb1) + tv_loss(L1)
    return loss1 + loss2

def P_loss(im1, X1):
    loss = torch.nn.MSELoss()(im1, X1)
    return loss

与讲解中公式对应。

3、总结

代码实现核心的部分讲解完毕,本文实现了一种无监督低光图像增强方法,核心是利用成对低光图像(同场景、不同光照) 学习自适应先验,减少人工先验依赖,在简单的网络结构下实际了降噪增强的效果。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

本文章已经生成可运行项目
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值