【LUT技术专题】AdaInt代码讲解

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

本文是对AdaInt技术的代码解读,原文解读请看AdaInt文章讲解

1、原文概要

AdaInt针对 3DLUT均匀采样的LUT模块提出了一个采样间隔自适应的方法,可以减少非线性段LUT插值的误差,从而达到更好的效果。
AdaInt的整体流程如下所示:
在这里插入图片描述
可以看到整体是基于3D-LUT的框架,通过一个轻量的CNN结构得到输出code,输出code给到Weights Predictor预测weight,同样的输出code给到AdaInt模块得到采样间隔 Q Q Q,weight将基础LUT进行融合得到颜色表 T T T,采样间隔Q进行变换得到实际的非均匀采样点位置 P P P T T T P P P进行非均匀3DLUT的渲染得到采样的3DLUT,最后使用AiLUT-Transform算子对输入图像进行一个变换,得到增强图像。

2、代码结构

代码整体结构如下
在这里插入图片描述
代码基于mmedit框架构建,MMEditing 来自 OpenMMLab 项目,是基于 PyTorch 的图像和视频编辑开源工具箱。它目前包含了常见的编辑任务,比如图像修复,图像抠图,超分辨率和生成模型。与其类似的框架还有basicsr,在开发中使用事半功倍,强烈推荐。
这里就不讲解框架相关内容,主要讲解跟本文相关的核心代码。核心代码位于adaint文件夹中,如下所示
在这里插入图片描述
ailut_transform中放着跟最终插值相关的cpp代码实现,model.py中是最核心的部分,包含了自适应间隔和各个子网络模块的实现。

3 、核心代码模块

model.py 文件

这个文件包含了AdaInt文章中关于backbone、Weights Predictor、AdaInt模块的实现,另外还有生成采样3DLUT和一次迭代的过程。

1. TPAMIBackbone类

此为轻量CNN结构的实现,作者代码中放了2个不同的选项,分别是3DLUT中使用的TPAMIBackbone,另一个是Resnet18的实现,这里只放3DLUT的backbone。

class TPAMIBackbone(nn.Sequential):
    r"""The 5-layer CNN backbone module in [TPAMI 3D-LUT]
        (https://github.com/HuiZeng/Image-Adaptive-3DLUT).

    Args:
        pretrained (bool, optional): [ignored].
        input_resolution (int, optional): Resolution for pre-downsampling. Default: 256.
        extra_pooling (bool, optional): Whether to insert an extra pooling layer
            at the very end of the module to reduce the number of parameters of
            the subsequent module. Default: False.
    """

    def __init__(self, pretrained=False, input_resolution=256, extra_pooling=False):
        body = [
            BasicBlock(3, 16, stride=2, norm=True),
            BasicBlock(16, 32, stride=2, norm=True),
            BasicBlock(32, 64, stride=2, norm=True),
            BasicBlock(64, 128, stride=2, norm=True),
            BasicBlock(128, 128, stride=2),
            nn.Dropout(p=0.5),
        ]
        if extra_pooling:
            body.append(nn.AdaptiveAvgPool2d(2))
        super().__init__(*body)
        self.input_resolution = input_resolution
        self.out_channels = 128 * (4 if extra_pooling else 64)

    def forward(self, imgs):
        imgs = F.interpolate(imgs, size=(self.input_resolution,) * 2,
            mode='bilinear', align_corners=False)
        return super().forward(imgs).view(imgs.shape[0], -1)

可以看到,输入首先进行一个resize到256尺寸,然后经过一系列网络结构,最后根据选项选择是否进行avgpool,但最后会进行view,将空间全部放到通道上,方便后续使用。

其中的BasicBlock实现如下:

class BasicBlock(nn.Sequential):
    r"""The basic block module (Conv+LeakyReLU[+InstanceNorm]).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm=False):
        body = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=1),
            nn.LeakyReLU(0.2)
        ]
        if norm:
            body.append(nn.InstanceNorm2d(out_channels, affine=True))
        super(BasicBlock, self).__init__(*body)

其实就是一个简单的卷积,搭配了一个激活函数,根据normalization选项的不同插入InstanceNorm。

2. LUTGenerator类

该类实现了AdaInt中的weights和basicLUT:

class LUTGenerator(nn.Module):
    r"""The LUT generator module (mapping h).

    Args:
        n_colors (int): Number of input color channels.
        n_vertices (int): Number of sampling points along each lattice dimension.
        n_feats (int): Dimension of the input image representation vector.
        n_ranks (int): Number of ranks in the mapping h (or the number of basis LUTs).
    """

    def __init__(self, n_colors, n_vertices, n_feats, n_ranks) -> None:
        super().__init__()

        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * (n_vertices ** n_colors), bias=False)

        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
            [TPAMI 3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).

        """
        nn.init.ones_(self.weights_generator.bias)
        identity_lut = torch.stack([
            torch.stack(
                torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(self.n_colors)]),
                dim=0).div(self.n_vertices - 1).flip(0),
            *[torch.zeros(
                self.n_colors, *((self.n_vertices,) * self.n_colors)) for _ in range(self.n_ranks - 1)]
            ], dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, x):
        weights = self.weights_generator(x)
        luts = self.basis_luts_bank(weights)
        luts = luts.view(x.shape[0], -1, *((self.n_vertices,) * self.n_colors))
        return weights, luts

    def regularizations(self, smoothness, monotonicity):
        basis_luts = self.basis_luts_bank.weight.t().view(
            self.n_ranks, self.n_colors, *((self.n_vertices,) * self.n_colors))
        tv, mn = 0, 0
        for i in range(2, basis_luts.ndimension()):
            diff = torch.diff(basis_luts.flip(i), dim=i)
            tv += torch.square(diff).sum(0).mean()
            mn += F.relu(diff).sum(0).mean()
        reg_smoothness = smoothness * tv
        reg_monotonicity = monotonicity * mn
        return reg_smoothness, reg_monotonicity

此跟3DLUT中一样,不过实现方式有一些不一样,weights_generator是一个Linear将CNN提取的feat转换为n_ranks个weight,n_ranks跟使用的LUT个数一样,这里默认是3,basicLUT的实现使用了Linear层的weight来进行保存,这里可以看到这个线性层的weight的shape是(n_ranks,n_colors * (n_vertices ** n_colors)),跟3DLUT里面一个一个初始化LUT是一样的结果。regularizations函数实现了平滑和单调损失,这个跟3DLUT的实现一样。

3. AdaInt类

该类实现自适应间隔功能。

class AdaInt(nn.Module):
    r"""The Adaptive Interval Learning (AdaInt) module (mapping g).

    It consists of a single fully-connected layer and some post-process operations.

    Args:
        n_colors (int): Number of input color channels.
        n_vertices (int): Number of sampling points along each lattice dimension.
        n_feats (int): Dimension of the input image representation vector.
        adaint_share (bool, optional): Whether to enable Share-AdaInt. Default: False.
    """

    def __init__(self, n_colors, n_vertices, n_feats, adaint_share=False) -> None:
        super().__init__()
        repeat_factor = n_colors if not adaint_share else 1
        self.intervals_generator = nn.Linear(
            n_feats, (n_vertices - 1) * repeat_factor)

        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.adaint_share = adaint_share

    def init_weights(self):
        r"""Init weights for models.

        We use all-zero and all-one initializations for its weights and bias, respectively.
        """
        nn.init.zeros_(self.intervals_generator.weight)
        nn.init.ones_(self.intervals_generator.bias)

    def forward(self, x):
        r"""Forward function for AdaInt module.

        Args:
            x (tensor): Input image representation, shape (b, f).
        Returns:
            Tensor: Sampling coordinates along each lattice dimension, shape (b, c, d).
        """
        x = x.view(x.shape[0], -1)
        intervals = self.intervals_generator(x).view(
            x.shape[0], -1, self.n_vertices - 1)
        if self.adaint_share:
            intervals = intervals.repeat_interleave(self.n_colors, dim=1)
        intervals = intervals.softmax(-1)
        vertices = F.pad(intervals.cumsum(-1), (1, 0), 'constant', 0)
        return vertices

从前向可以看到,输入CNN预测的code预测1或3(根据adaint_share变量而定) * (n_vertices-1)个间隔,自然adaint_share为true时,需要repeat通道次,后续在通道维度上做softmax做归一化,最后进行cumsum进行累加,得到实际的采样点,这跟前面讲解的算法流程是一样的。

4. AiLUT

此类实现的是整体的推理流程,这里只放了核心的一部分代码。

@MODELS.register_module()
class AiLUT(BaseModel):
    r"""Adaptive-Interval 3D Lookup Table for real-time image enhancement.

    Args:
        n_ranks (int, optional): Number of ranks in the mapping h
            (or the number of basis LUTs). Default: 3.
        n_vertices (int, optional): Number of sampling points along
            each lattice dimension. Default: 33.
        en_adaint (bool, optional): Whether to enable AdaInt. Default: True.
        en_adaint_share (bool, optional): Whether to enable Share-AdaInt.
            Only used when `en_adaint` is True. Default: False.
        backbone (str, optional): Backbone architecture to use. Can be either 'tpami'
            or 'res18'. Default: 'tpami'.
        pretrained (bool, optional): Whether to use ImageNet-pretrained weights.
            Only used when `backbone` is 'res18'. Default: None.
        n_colors (int, optional): Number of input color channels. Default: 3.
        sparse_factor (float, optional): Loss weight for the sparse regularization term.
            Default: 0.0001.
        smooth_factor (float, optional): Loss weight for the smoothness regularization term.
            Default: 0.
        monotonicity_factor (float, optional): Loss weight for the monotonicaity
            regularization term. Default: 10.0.
        recons_loss (dict, optional): Config for pixel-wise reconstruction loss.
        train_cfg (dict, optional): Config for training. Default: None.
        test_cfg (dict, optional): Config for testing. Default: None.
    """

    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}

    def __init__(self,
        n_ranks=3,
        n_vertices=33,
        en_adaint=True,
        en_adaint_share=False,
        backbone='tpami',
        pretrained=False,
        n_colors=3,
        sparse_factor=0.0001,
        smooth_factor=0,
        monotonicity_factor=10.0,
        recons_loss=dict(type='L2Loss', loss_weight=1.0, reduction='mean'),
        train_cfg=None,
        test_cfg=None):

        super().__init__()

        assert backbone.lower() in ['tpami', 'res18']

        # mapping f
        self.backbone = dict(
            tpami=TPAMIBackbone,
            res18=Res18Backbone)[backbone.lower()](pretrained, extra_pooling=en_adaint)

        # mapping h
        self.lut_generator = LUTGenerator(
            n_colors, n_vertices, self.backbone.out_channels, n_ranks)

        # mapping g
        if en_adaint:
            self.adaint = AdaInt(
                n_colors, n_vertices, self.backbone.out_channels, en_adaint_share)
        else:
            uniform_vertices = torch.arange(n_vertices).div(n_vertices - 1) \
                                    .repeat(n_colors, 1)
            self.register_buffer('uniform_vertices', uniform_vertices.unsqueeze(0))

        self.n_ranks = n_ranks
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.en_adaint = en_adaint
        self.sparse_factor = sparse_factor
        self.smooth_factor = smooth_factor
        self.monotonicity_factor = monotonicity_factor
        self.backbone_name = backbone.lower()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.fp16_enabled = False

        self.init_weights()

        self.recons_loss = build_loss(recons_loss)

        # fix AdaInt for some steps
        self.n_fix_iters = train_cfg.get('n_fix_iters', 0) if train_cfg else 0
        self.adaint_fixed = False
        self.register_buffer('cnt_iters', torch.zeros(1))

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
            [TPAMI 3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).
        For the mapping g (`adaint`), we use all-zero and all-one initializations for its weights
        and bias, respectively.
        """
        def special_initilization(m):
            classname = m.__class__.__name__
            if 'Conv' in classname:
                nn.init.xavier_normal_(m.weight.data)
            elif 'InstanceNorm' in classname:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0.0)
        if self.backbone_name not in ['res18']:
            self.apply(special_initilization)
        self.lut_generator.init_weights()
        if self.en_adaint:
            self.adaint.init_weights()

    def forward_dummy(self, imgs):
        r"""The real implementation of model forward.

        Args:
            img (Tensor): Input image, shape (b, c, h, w).
        Returns:
            tuple(Tensor, Tensor, Tensor):
                Output image, LUT weights, Sampling Coordinates.
        """
        # E: (b, f)
        codes = self.backbone(imgs)
        # (b, m), T: (b, c, d, d, d)
        weights, luts = self.lut_generator(codes)
        # \hat{P}: (b, c, d)
        if self.en_adaint:
            vertices = self.adaint(codes)
        else:
            vertices = self.uniform_vertices

        outs = ailut_transform(imgs, luts, vertices)

        return outs, weights, vertices

首先初始化了所有我们需要的模块,这样在forward_dummy前向中,可以看到首先提取CNN特征codes,codes送到lut_generator和adaint分别得到luts和vertices,最后进行一个ailut_transform得到增强的结果即可。

ailut_transform_cpu.cpp 文件

ailut_transform的实现代码,这里展示的cpu的版本,文件在adaint/ailut_transform/ailut/csrc中。首先带自适应间隔的启动函数是ailut_transform_cpu_forward,如下所示。

void ailut_transform_cpu_forward(
    const torch::Tensor &input,
    const torch::Tensor &lut,
    const torch::Tensor &vertices,
    torch::Tensor output) {

    /* retrieve some meta-information of the input tensors */
    int batch_size = input.size(0);
    int height     = input.size(2);
    int width      = input.size(3);

    int num_channels = lut.size(1);
    int stride_lut   = lut.size(2);

    int num_kernels = height * width;

    for (int elt = 0; elt < batch_size; ++elt) {
        AT_DISPATCH_FLOATING_TYPES(
            input.scalar_type(), "ailut_transform_cpu_forward", ([&] {
                const scalar_t *data_inp = input[elt].data_ptr<scalar_t>();
                const scalar_t *data_lut = lut[elt].data_ptr<scalar_t>();
                const scalar_t *data_anc = vertices[elt].data_ptr<scalar_t>();
                scalar_t *data_col = output[elt].data_ptr<scalar_t>();

                ailut_transform_3d_cpu_forward_impl(
                    num_kernels, data_inp, data_lut, data_anc,
                    height, width, stride_lut, num_channels,
                    data_col);
            }));
    }
}

这里可以看到其在batch维度调用的是ailut_transform_3d_cpu_forward_impl函数,送的参数含义都是较明确的,其的实现如下:

template <typename scalar_t>
void ailut_transform_3d_cpu_forward_impl(
        const int n,
        const scalar_t* __restrict__ data_inp,
        const scalar_t* __restrict__ data_lut,
        const scalar_t* __restrict__ data_anc,
        const int height,
        const int width,
        const int stride_lut,
        const int num_channels,
        scalar_t* __restrict__ data_col) {

    const static scalar_t eps = 1e-10;

    for (int index = 0; index < n; ++index) {

        /* retrieve rgb value of the pixel */
        const scalar_t r = data_inp[index];
        const scalar_t g = data_inp[index + height * width];
        const scalar_t b = data_inp[index + height * width * 2];

        /* retrieve index of the interpolation verticess */
        const int32_t rid = lower_bound(data_anc, 0, stride_lut, r);
        const int32_t gid = lower_bound(data_anc, stride_lut, stride_lut * 2, g);
        const int32_t bid = lower_bound(data_anc, stride_lut * 2, stride_lut * 3, b);

        /* utility variables for indexing */
        const int stride_lut_2 = stride_lut * stride_lut;
        const int stride_lut_3 = stride_lut_2 * stride_lut;

        /* retrieve the interpolation verticess (number of 8 in case of trilinear interpolation) */
        const int id000 = (rid    ) + stride_lut * (gid    ) + stride_lut_2 * (bid    );
        const int id100 = (rid + 1) + stride_lut * (gid    ) + stride_lut_2 * (bid    );
        const int id010 = (rid    ) + stride_lut * (gid + 1) + stride_lut_2 * (bid    );
        const int id110 = (rid + 1) + stride_lut * (gid + 1) + stride_lut_2 * (bid    );
        const int id001 = (rid    ) + stride_lut * (gid    ) + stride_lut_2 * (bid + 1);
        const int id101 = (rid + 1) + stride_lut * (gid    ) + stride_lut_2 * (bid + 1);
        const int id011 = (rid    ) + stride_lut * (gid + 1) + stride_lut_2 * (bid + 1);
        const int id111 = (rid + 1) + stride_lut * (gid + 1) + stride_lut_2 * (bid + 1);

        /* compute interpolation weights */
        const scalar_t r0 = data_anc[rid];
        const scalar_t r1 = data_anc[rid + 1];
        const scalar_t g0 = data_anc[gid + stride_lut];
        const scalar_t g1 = data_anc[gid + stride_lut + 1];
        const scalar_t b0 = data_anc[bid + stride_lut * 2];
        const scalar_t b1 = data_anc[bid + stride_lut * 2 + 1];

        const scalar_t rd = (r - r0) / (r1 - r0 + eps);
        const scalar_t gd = (g - g0) / (g1 - g0 + eps);
        const scalar_t bd = (b - b0) / (b1 - b0 + eps);

        const scalar_t w000 = (1 - rd) * (1 - gd) * (1 - bd);
        const scalar_t w100 = (    rd) * (1 - gd) * (1 - bd);
        const scalar_t w010 = (1 - rd) * (    gd) * (1 - bd);
        const scalar_t w110 = (    rd) * (    gd) * (1 - bd);
        const scalar_t w001 = (1 - rd) * (1 - gd) * (    bd);
        const scalar_t w101 = (    rd) * (1 - gd) * (    bd);
        const scalar_t w011 = (1 - rd) * (    gd) * (    bd);
        const scalar_t w111 = (    rd) * (    gd) * (    bd);

        /* Execute the interpolation */
        for (int i = 0; i < num_channels; ++i) {
            data_col[index + height * width * i] =
                w000 * data_lut[id000 + stride_lut_3 * i] + w100 * data_lut[id100 + stride_lut_3 * i] +
                w010 * data_lut[id010 + stride_lut_3 * i] + w110 * data_lut[id110 + stride_lut_3 * i] +
                w001 * data_lut[id001 + stride_lut_3 * i] + w101 * data_lut[id101 + stride_lut_3 * i] +
                w011 * data_lut[id011 + stride_lut_3 * i] + w111 * data_lut[id111 + stride_lut_3 * i];
        }
    }
}

可以看到其在空间上进行迭代,以一次迭代为例,首先去到r、g、b的值,然后在lut中查询的下界,使用的是lower_bound,二分查找,实现如下所示。

/* std::clamp is only available since c++17 */
template <typename scalar_t>
inline constexpr const scalar_t& clamp(
    const scalar_t& v, const scalar_t& lo, const scalar_t& hi)
{
    return (v < lo) ? lo : ((v > hi) ? hi : v);
}


/* binary search on a sorted array to find and clamp the lower bound */
template <typename scalar_t>
inline int32_t lower_bound(
        const scalar_t *data_ss,
        int32_t start,
        int32_t end,
        scalar_t val) {

    const int32_t ori_start = start;
    const int32_t upper_bound = end - start - 2;
    while (start < end) {
        int64_t mid = start + ((end - start) >> 1);
        if (!(data_ss[mid] >= val)) {
            start = mid + 1;
        }
        else {
            end = mid;
        }
    }
    return clamp(start - ori_start - 1, 0, upper_bound);
}

因为lut是递增的,因此可以根据当前r、g、b的值二分找到下界对应的采样点。回到插值过程,当我们找到下界采样点后,就可以通过加1和组合找到所有的8个点了,之后进行插值就可以得到实际结果。

3、总结

代码实现核心的部分讲解完毕,该篇论文相对来说开源的质量就比4DLUT会高很多。不过这里需要指出一个点:大家看了实现后就会发现,该篇论文的自适应间隔,应该理解为先对原始r、g、b做一次映射,然后再进行3DLUT插值,如果将其理解为3DLUT自适应间隔减小量化误差,就跟实际实现不对应。


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

您可能感兴趣的与本文相关的镜像

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值