This post is a code walkthrough of AdaInt. For an explanation of the paper itself, see the companion post on the AdaInt paper.
1. Paper overview
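AdaInt targets the uniformly sampled lattice used by 3D LUTs and proposes learning the sampling intervals adaptively. This reduces the interpolation error in strongly nonlinear regions of the color mapping and thereby improves results.
The overall AdaInt pipeline is as follows: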

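The framework is built on 3D-LUT. A lightweight CNN produces an image code. The code is fed to the Weights Predictor, which predicts the fusion weights, and the same code is fed to the AdaInt module, which predicts the sampling intervals $Q$. The weights fuse the basis LUTs into a color table $T$, while the intervals $Q$ are transformed into the actual non-uniform sampling-point positions $P$. Together, $T$ and $P$ define a non-uniformly sampled 3D LUT, and finally the AiLUT-Transform operator applies it to the input image to produce the enhanced image.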
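To make the motivation concrete, here is a minimal 1D sketch in pure Python (not code from the AdaInt repo). It approximates a gamma-like curve with a 5-vertex piecewise-linear LUT twice: once with uniform vertices and once with hand-picked vertices that are denser where the curve bends. The curve, the vertex positions, and the helper names `interp_1d` and `max_abs_error` are illustrative assumptions. AdaInt learns the vertices per image rather than picking them by hand.

```python
import bisect


def interp_1d(x, vertices, values):
    """Piecewise-linear lookup of x in a 1D LUT whose vertices may be non-uniform."""
    # index of the lower vertex, clamped so that i + 1 is always a valid index
    i = bisect.bisect_right(vertices, x) - 1
    i = max(0, min(i, len(vertices) - 2))
    x0, x1 = vertices[i], vertices[i + 1]
    t = (x - x0) / (x1 - x0)
    return (1 - t) * values[i] + t * values[i + 1]


def max_abs_error(f, vertices, n_test=10001):
    """Largest |f(x) - LUT(x)| over a dense grid on [0, 1]."""
    values = [f(v) for v in vertices]  # LUT entries = f sampled at the vertices
    xs = [k / (n_test - 1) for k in range(n_test)]
    return max(abs(f(x) - interp_1d(x, vertices, values)) for x in xs)


def gamma(x):
    return x ** (1 / 2.2)  # steep near 0, nearly flat near 1


uniform = [0.0, 0.25, 0.5, 0.75, 1.0]
adaptive = [0.0, 0.05, 0.2, 0.5, 1.0]  # hand-picked: denser near 0

err_uniform = max_abs_error(gamma, uniform)
err_adaptive = max_abs_error(gamma, adaptive)
print(f"uniform: {err_uniform:.3f}  adaptive: {err_adaptive:.3f}")
```

With the same number of vertices, concentrating them where the curve changes fastest gives a noticeably smaller maximum error. This is the effect AdaInt aims for, with the vertices learned per image.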
2. Code structure
The overall code structure is as follows:

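The code is built on the mmedit framework. MMEditing, part of the OpenMMLab project, is an open-source PyTorch toolbox for image and video editing. It currently covers common editing tasks such as image inpainting, image matting, super-resolution, and generative models. BasicSR is a similar framework. Both make development much easier and are highly recommended.
I won't cover the framework itself here, only the core code relevant to this paper. The core code lives in the adaint folder, as shown below: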

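The ailut_transform folder holds the C++ implementation of the final interpolation. model.py is the most important file: it contains the adaptive-interval module and the implementations of all the sub-networks.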
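3. Core code modules
The model.py file
This file implements the backbone, the Weights Predictor, and the AdaInt module described in the paper. It also generates the sampled 3D LUT and defines a single training iteration.
1. The TPAMIBackbone class
This is the lightweight CNN. The author's code offers two options: the TPAMIBackbone used in 3D-LUT and a ResNet-18 backbone. Only the 3D-LUT backbone is shown here.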
```python
import torch.nn as nn
import torch.nn.functional as F


class TPAMIBackbone(nn.Sequential):
    r"""The 5-layer CNN backbone module in [TPAMI 3D-LUT]
    (https://github.com/HuiZeng/Image-Adaptive-3DLUT).

    Args:
        pretrained (bool, optional): [ignored].
        input_resolution (int, optional): Resolution for pre-downsampling. Default: 256.
        extra_pooling (bool, optional): Whether to insert an extra pooling layer
            at the very end of the module to reduce the number of parameters of
            the subsequent module. Default: False.
    """

    def __init__(self, pretrained=False, input_resolution=256, extra_pooling=False):
        body = [
            BasicBlock(3, 16, stride=2, norm=True),
            BasicBlock(16, 32, stride=2, norm=True),
            BasicBlock(32, 64, stride=2, norm=True),
            BasicBlock(64, 128, stride=2, norm=True),
            BasicBlock(128, 128, stride=2),
            nn.Dropout(p=0.5),
        ]
        if extra_pooling:
            body.append(nn.AdaptiveAvgPool2d(2))
        super().__init__(*body)
        self.input_resolution = input_resolution
        self.out_channels = 128 * (4 if extra_pooling else 64)

    def forward(self, imgs):
        # resize every input to a fixed resolution before the CNN
        imgs = F.interpolate(imgs, size=(self.input_resolution,) * 2,
                             mode='bilinear', align_corners=False)
        # flatten (b, 128, h, w) into (b, 128 * h * w)
        return super().forward(imgs).view(imgs.shape[0], -1)
```
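The input is first bilinearly resized to 256×256 (input_resolution) and then passed through the convolutional blocks and Dropout. If extra_pooling is set, an AdaptiveAvgPool2d(2) follows. Finally, view flattens everything except the batch dimension, folding the spatial positions into the feature vector for later use. This is why out_channels is 128 × 64 (an 8×8 map after five stride-2 blocks) or 128 × 4 (a 2×2 map after the extra pooling).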
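The out_channels arithmetic can be checked with the standard Conv2d output-size formula. The sketch below is pure Python, not repo code. The helper names are hypothetical, and it assumes kernel 3, stride 2, and padding 1, as in the BasicBlocks above.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tpami_backbone_out_channels(input_resolution=256, extra_pooling=False):
    size = input_resolution
    for _ in range(5):          # five stride-2 BasicBlocks
        size = conv_out(size)   # 256 -> 128 -> 64 -> 32 -> 16 -> 8
    if extra_pooling:
        size = 2                # nn.AdaptiveAvgPool2d(2)
    return 128 * size * size    # 128 channels in the last block, flattened by view


print(tpami_backbone_out_channels())                    # 8192
print(tpami_backbone_out_channels(extra_pooling=True))  # 512
```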
BasicBlock is implemented as follows:
```python
import torch.nn as nn


class BasicBlock(nn.Sequential):
    r"""The basic block module (Conv+LeakyReLU[+InstanceNorm]).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm=False):
        body = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=1),
            nn.LeakyReLU(0.2)
        ]
        if norm:
            body.append(nn.InstanceNorm2d(out_channels, affine=True))
        super(BasicBlock, self).__init__(*body)
```
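It is simply a convolution followed by a LeakyReLU activation (negative slope 0.2), with an InstanceNorm appended when norm is True.
2. The LUTGenerator class
This class implements AdaInt's Weights Predictor and the basis LUTs: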
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LUTGenerator(nn.Module):
    r"""The LUT generator module (mapping h).

    Args:
        n_colors (int): Number of input color channels.
        n_vertices (int): Number of sampling points along each lattice dimension.
        n_feats (int): Dimension of the input image representation vector.
        n_ranks (int): Number of ranks in the mapping h (or the number of basis LUTs).
    """

    def __init__(self, n_colors, n_vertices, n_feats, n_ranks) -> None:
        super().__init__()
        # h0
        self.weights_generator = nn.Linear(n_feats, n_ranks)
        # h1
        self.basis_luts_bank = nn.Linear(
            n_ranks, n_colors * (n_vertices ** n_colors), bias=False)
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.n_feats = n_feats
        self.n_ranks = n_ranks

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
        [TPAMI 3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).
        """
        nn.init.ones_(self.weights_generator.bias)
        # basis LUT 0 = identity mapping, the remaining n_ranks - 1 basis LUTs = 0
        identity_lut = torch.stack([
            torch.stack(
                torch.meshgrid(*[torch.arange(self.n_vertices) for _ in range(self.n_colors)]),
                dim=0).div(self.n_vertices - 1).flip(0),
            *[torch.zeros(
                self.n_colors, *((self.n_vertices,) * self.n_colors)) for _ in range(self.n_ranks - 1)]
        ], dim=0).view(self.n_ranks, -1)
        self.basis_luts_bank.weight.data.copy_(identity_lut.t())

    def forward(self, x):
        weights = self.weights_generator(x)
        luts = self.basis_luts_bank(weights)
        luts = luts.view(x.shape[0], -1, *((self.n_vertices,) * self.n_colors))
        return weights, luts

    def regularizations(self, smoothness, monotonicity):
        basis_luts = self.basis_luts_bank.weight.t().view(
            self.n_ranks, self.n_colors, *((self.n_vertices,) * self.n_colors))
        tv, mn = 0, 0
        for i in range(2, basis_luts.ndimension()):
            diff = torch.diff(basis_luts.flip(i), dim=i)
            tv += torch.square(diff).sum(0).mean()
            mn += F.relu(diff).sum(0).mean()
        reg_smoothness = smoothness * tv
        reg_monotonicity = monotonicity * mn
        return reg_smoothness, reg_monotonicity
```
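This mirrors 3D-LUT, although the implementation differs somewhat. weights_generator is a Linear layer that maps the CNN feature to n_ranks weights, where n_ranks equals the number of basis LUTs (3 by default). The basis LUTs are stored as the weight of the bias-free Linear layer basis_luts_bank. PyTorch stores a Linear weight as (out_features, in_features), so this weight has shape (n_colors * n_vertices ** n_colors, n_ranks), and its transpose, of shape (n_ranks, n_colors * n_vertices ** n_colors), holds one flattened basis LUT per row. Calling basis_luts_bank(weights) therefore computes the weighted sum of the basis LUTs, which gives the same result as initializing and fusing the LUTs one by one in 3D-LUT. In init_weights, the first basis LUT is set to the identity mapping and the others to zero. The regularizations function implements the smoothness loss (squared differences between neighboring lattice entries) and the monotonicity loss (ReLU on decreasing steps along each lattice axis), the same as in 3D-LUT.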
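The sketch below mirrors in pure Python what init_weights and forward do for one sample. It is an illustration under stated assumptions, not the repo's code: LUTs are nested lists indexed lut[c][b][g][r], and the helper names `identity_lut`, `flatten`, and `fuse` are hypothetical. The indexing assumption follows from `meshgrid(...).flip(0)`: channel 0 (R) varies along the last axis, which also matches the offsets used by the C++ kernel later on. The fusion weights are given as constants here, whereas in the repo they come from weights_generator.

```python
def identity_lut(d):
    """Identity 3D LUT indexed lut[c][b][g][r] (R varies along the last axis)."""
    coords = [k / (d - 1) for k in range(d)]
    lut = [[[[0.0] * d for _ in range(d)] for _ in range(d)] for _ in range(3)]
    for b in range(d):
        for g in range(d):
            for r in range(d):
                lut[0][b][g][r] = coords[r]  # R output follows the r index
                lut[1][b][g][r] = coords[g]
                lut[2][b][g][r] = coords[b]
    return lut


def flatten(lut):
    """Flatten in (c, b, g, r) order: entry (c, b, g, r) lands at
    r + d*g + d**2*b + d**3*c."""
    return [v for ch in lut for plane in ch for row in plane for v in row]


def fuse(weights, basis_luts):
    """What basis_luts_bank(weights) computes for one sample:
    a weighted sum of the flattened basis LUTs (no bias)."""
    return [sum(w * lut[i] for w, lut in zip(weights, basis_luts))
            for i in range(len(basis_luts[0]))]


d = 3
identity = flatten(identity_lut(d))
zeros = [0.0] * len(identity)
# as in init_weights: basis LUT 0 is the identity, the other two are zero
fused = fuse([1.0, 0.3, -2.0], [identity, zeros, zeros])
print(fused == identity)  # True, because the other bases contribute nothing
```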
3. The AdaInt class
This class implements the adaptive intervals.
```python
import torch.nn as nn
import torch.nn.functional as F


class AdaInt(nn.Module):
    r"""The Adaptive Interval Learning (AdaInt) module (mapping g).

    It consists of a single fully-connected layer and some post-process operations.

    Args:
        n_colors (int): Number of input color channels.
        n_vertices (int): Number of sampling points along each lattice dimension.
        n_feats (int): Dimension of the input image representation vector.
        adaint_share (bool, optional): Whether to enable Share-AdaInt. Default: False.
    """

    def __init__(self, n_colors, n_vertices, n_feats, adaint_share=False) -> None:
        super().__init__()
        repeat_factor = n_colors if not adaint_share else 1
        self.intervals_generator = nn.Linear(
            n_feats, (n_vertices - 1) * repeat_factor)
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.adaint_share = adaint_share

    def init_weights(self):
        r"""Init weights for models.

        We use all-zero and all-one initializations for its weights and bias, respectively.
        """
        nn.init.zeros_(self.intervals_generator.weight)
        nn.init.ones_(self.intervals_generator.bias)

    def forward(self, x):
        r"""Forward function for AdaInt module.

        Args:
            x (tensor): Input image representation, shape (b, f).
        Returns:
            Tensor: Sampling coordinates along each lattice dimension, shape (b, c, d).
        """
        x = x.view(x.shape[0], -1)
        intervals = self.intervals_generator(x).view(
            x.shape[0], -1, self.n_vertices - 1)
        if self.adaint_share:
            intervals = intervals.repeat_interleave(self.n_colors, dim=1)
        intervals = intervals.softmax(-1)
        vertices = F.pad(intervals.cumsum(-1), (1, 0), 'constant', 0)
        return vertices
```
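From forward we can see that the code predicted by the CNN is mapped to n_vertices - 1 interval logits per color channel: 3 × (n_vertices - 1) values, or only n_vertices - 1 when adaint_share is True, in which case the shared intervals are copied to every channel with repeat_interleave. softmax is then applied along the last dimension, that is, over the intervals of each channel (not over the channels), so the intervals are positive and sum to 1. Finally cumsum accumulates them and F.pad prepends a 0, giving n_vertices strictly increasing sampling coordinates from 0 to 1 per channel, which matches the algorithm described earlier. Because the weight is initialized to zero and the bias to one, all logits are equal at the start of training and the initial lattice is uniform.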
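The post-processing in forward can be traced by hand for a single channel. Below is a minimal pure-Python sketch of the softmax, cumsum, and pad steps. The helper name `intervals_to_vertices` is hypothetical, and the logits are plain lists instead of tensors.

```python
import math


def intervals_to_vertices(logits):
    """Mimic AdaInt.forward for one color channel: softmax over the
    n_vertices - 1 interval logits, cumsum, then prepend a 0."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # subtract max for numerical stability
    total = sum(exps)
    intervals = [e / total for e in exps]     # positive, sum to 1
    vertices = [0.0]                          # the leading 0 added by F.pad
    for q in intervals:
        vertices.append(vertices[-1] + q)     # running sum = cumsum
    return vertices


# At initialization (weight = 0, bias = 1) every logit equals 1,
# so all intervals are equal and the lattice is uniform.
print(intervals_to_vertices([1.0] * 4))       # [0.0, 0.25, 0.5, 0.75, 1.0]
# A larger first logit widens the first interval and shifts vertices upward.
print(intervals_to_vertices([2.0, 0.0, 0.0, 0.0]))
```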
4. The AiLUT class
This class implements the overall inference pipeline. Only the core part of its code is shown here.
```python
@MODELS.register_module()
class AiLUT(BaseModel):
    r"""Adaptive-Interval 3D Lookup Table for real-time image enhancement.

    Args:
        n_ranks (int, optional): Number of ranks in the mapping h
            (or the number of basis LUTs). Default: 3.
        n_vertices (int, optional): Number of sampling points along
            each lattice dimension. Default: 33.
        en_adaint (bool, optional): Whether to enable AdaInt. Default: True.
        en_adaint_share (bool, optional): Whether to enable Share-AdaInt.
            Only used when `en_adaint` is True. Default: False.
        backbone (str, optional): Backbone architecture to use. Can be either 'tpami'
            or 'res18'. Default: 'tpami'.
        pretrained (bool, optional): Whether to use ImageNet-pretrained weights.
            Only used when `backbone` is 'res18'. Default: False.
        n_colors (int, optional): Number of input color channels. Default: 3.
        sparse_factor (float, optional): Loss weight for the sparse regularization term.
            Default: 0.0001.
        smooth_factor (float, optional): Loss weight for the smoothness regularization term.
            Default: 0.
        monotonicity_factor (float, optional): Loss weight for the monotonicity
            regularization term. Default: 10.0.
        recons_loss (dict, optional): Config for pixel-wise reconstruction loss.
        train_cfg (dict, optional): Config for training. Default: None.
        test_cfg (dict, optional): Config for testing. Default: None.
    """
    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}

    def __init__(self,
                 n_ranks=3,
                 n_vertices=33,
                 en_adaint=True,
                 en_adaint_share=False,
                 backbone='tpami',
                 pretrained=False,
                 n_colors=3,
                 sparse_factor=0.0001,
                 smooth_factor=0,
                 monotonicity_factor=10.0,
                 recons_loss=dict(type='L2Loss', loss_weight=1.0, reduction='mean'),
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        assert backbone.lower() in ['tpami', 'res18']
        # mapping f
        self.backbone = dict(
            tpami=TPAMIBackbone,
            res18=Res18Backbone)[backbone.lower()](pretrained, extra_pooling=en_adaint)
        # mapping h
        self.lut_generator = LUTGenerator(
            n_colors, n_vertices, self.backbone.out_channels, n_ranks)
        # mapping g
        if en_adaint:
            self.adaint = AdaInt(
                n_colors, n_vertices, self.backbone.out_channels, en_adaint_share)
        else:
            uniform_vertices = torch.arange(n_vertices).div(n_vertices - 1) \
                .repeat(n_colors, 1)
            self.register_buffer('uniform_vertices', uniform_vertices.unsqueeze(0))
        self.n_ranks = n_ranks
        self.n_colors = n_colors
        self.n_vertices = n_vertices
        self.en_adaint = en_adaint
        self.sparse_factor = sparse_factor
        self.smooth_factor = smooth_factor
        self.monotonicity_factor = monotonicity_factor
        self.backbone_name = backbone.lower()
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.init_weights()
        self.recons_loss = build_loss(recons_loss)
        # fix AdaInt for some steps
        self.n_fix_iters = train_cfg.get('n_fix_iters', 0) if train_cfg else 0
        self.adaint_fixed = False
        self.register_buffer('cnt_iters', torch.zeros(1))

    def init_weights(self):
        r"""Init weights for models.

        For the mapping f (`backbone`) and h (`lut_generator`), we follow the initialization in
        [TPAMI 3D-LUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT).
        For the mapping g (`adaint`), we use all-zero and all-one initializations for its weights
        and bias, respectively.
        """
        def special_initilization(m):
            classname = m.__class__.__name__
            if 'Conv' in classname:
                nn.init.xavier_normal_(m.weight.data)
            elif 'InstanceNorm' in classname:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0.0)
        if self.backbone_name not in ['res18']:
            self.apply(special_initilization)
        self.lut_generator.init_weights()
        if self.en_adaint:
            self.adaint.init_weights()

    def forward_dummy(self, imgs):
        r"""The real implementation of model forward.

        Args:
            imgs (Tensor): Input image, shape (b, c, h, w).
        Returns:
            tuple(Tensor, Tensor, Tensor):
                Output image, LUT weights, Sampling Coordinates.
        """
        # E: (b, f)
        codes = self.backbone(imgs)
        # (b, m), T: (b, c, d, d, d)
        weights, luts = self.lut_generator(codes)
        # \hat{P}: (b, c, d)
        if self.en_adaint:
            vertices = self.adaint(codes)
        else:
            vertices = self.uniform_vertices
        outs = ailut_transform(imgs, luts, vertices)
        return outs, weights, vertices
```
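All required modules are created in __init__. In the forward pass forward_dummy, the CNN features codes are extracted first. codes is then fed to lut_generator and adaint, which return luts and vertices respectively; when AdaInt is disabled, the fixed uniform_vertices buffer is used instead. Finally, ailut_transform produces the enhanced result.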
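The backbone is built with extra_pooling=en_adaint, and the TPAMIBackbone docstring says the pooling is there to reduce the parameters of the subsequent module. A rough count makes this concrete. The sketch below uses the defaults in __init__ and the standard nn.Linear parameter formula. The helper names are hypothetical, and the 8192-feature case is a hypothetical configuration, since the code never builds AdaInt without the pooling.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of parameters of nn.Linear(in_features, out_features, bias=bias)."""
    return in_features * out_features + (out_features if bias else 0)


# AiLUT defaults from __init__
n_colors, n_vertices, n_ranks = 3, 33, 3


def head_params(n_feats):
    """Parameters of the three Linear heads fed by a backbone with n_feats outputs."""
    weights_generator = linear_params(n_feats, n_ranks)
    basis_luts_bank = linear_params(
        n_ranks, n_colors * n_vertices ** n_colors, bias=False)
    intervals_generator = linear_params(n_feats, (n_vertices - 1) * n_colors)
    return weights_generator, basis_luts_bank, intervals_generator


print(head_params(512))   # (1539, 323433, 49248): en_adaint=True, with 2x2 pooling
print(head_params(8192))  # (24579, 323433, 786528): hypothetical, no pooling
```

The LUT bank has the same size either way, but without pooling the interval head alone would grow by roughly 16×.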
The ailut_transform_cpu.cpp file
This file implements ailut_transform. The CPU version is shown here; the file is located in adaint/ailut_transform/ailut/csrc. The entry point of the adaptive-interval transform is ailut_transform_cpu_forward:
```cpp
void ailut_transform_cpu_forward(
    const torch::Tensor &input,
    const torch::Tensor &lut,
    const torch::Tensor &vertices,
    torch::Tensor output) {

    /* retrieve some meta-information of the input tensors */
    int batch_size = input.size(0);
    int height = input.size(2);
    int width = input.size(3);

    int num_channels = lut.size(1);
    int stride_lut = lut.size(2);

    int num_kernels = height * width;
    for (int elt = 0; elt < batch_size; ++elt) {
        AT_DISPATCH_FLOATING_TYPES(
            input.scalar_type(), "ailut_transform_cpu_forward", ([&] {
                const scalar_t *data_inp = input[elt].data_ptr<scalar_t>();
                const scalar_t *data_lut = lut[elt].data_ptr<scalar_t>();
                const scalar_t *data_anc = vertices[elt].data_ptr<scalar_t>();
                scalar_t *data_col = output[elt].data_ptr<scalar_t>();
                ailut_transform_3d_cpu_forward_impl(
                    num_kernels, data_inp, data_lut, data_anc,
                    height, width, stride_lut, num_channels,
                    data_col);
            }));
    }
}
```
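The function loops over the batch and, for each sample, calls ailut_transform_3d_cpu_forward_impl. The arguments are fairly self-explanatory: num_kernels = height * width is the number of pixels per image, and stride_lut is the number of vertices per lattice dimension. The implementation is as follows: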
```cpp
template <typename scalar_t>
void ailut_transform_3d_cpu_forward_impl(
    const int n,
    const scalar_t* __restrict__ data_inp,
    const scalar_t* __restrict__ data_lut,
    const scalar_t* __restrict__ data_anc,
    const int height,
    const int width,
    const int stride_lut,
    const int num_channels,
    scalar_t* __restrict__ data_col) {

    const static scalar_t eps = 1e-10;

    for (int index = 0; index < n; ++index) {

        /* retrieve rgb value of the pixel */
        const scalar_t r = data_inp[index];
        const scalar_t g = data_inp[index + height * width];
        const scalar_t b = data_inp[index + height * width * 2];

        /* retrieve index of the interpolation vertices */
        const int32_t rid = lower_bound(data_anc, 0, stride_lut, r);
        const int32_t gid = lower_bound(data_anc, stride_lut, stride_lut * 2, g);
        const int32_t bid = lower_bound(data_anc, stride_lut * 2, stride_lut * 3, b);

        /* utility variables for indexing */
        const int stride_lut_2 = stride_lut * stride_lut;
        const int stride_lut_3 = stride_lut_2 * stride_lut;

        /* retrieve the interpolation vertices (number of 8 in case of trilinear interpolation) */
        const int id000 = (rid    ) + stride_lut * (gid    ) + stride_lut_2 * (bid    );
        const int id100 = (rid + 1) + stride_lut * (gid    ) + stride_lut_2 * (bid    );
        const int id010 = (rid    ) + stride_lut * (gid + 1) + stride_lut_2 * (bid    );
        const int id110 = (rid + 1) + stride_lut * (gid + 1) + stride_lut_2 * (bid    );
        const int id001 = (rid    ) + stride_lut * (gid    ) + stride_lut_2 * (bid + 1);
        const int id101 = (rid + 1) + stride_lut * (gid    ) + stride_lut_2 * (bid + 1);
        const int id011 = (rid    ) + stride_lut * (gid + 1) + stride_lut_2 * (bid + 1);
        const int id111 = (rid + 1) + stride_lut * (gid + 1) + stride_lut_2 * (bid + 1);

        /* compute interpolation weights */
        const scalar_t r0 = data_anc[rid];
        const scalar_t r1 = data_anc[rid + 1];
        const scalar_t g0 = data_anc[gid + stride_lut];
        const scalar_t g1 = data_anc[gid + stride_lut + 1];
        const scalar_t b0 = data_anc[bid + stride_lut * 2];
        const scalar_t b1 = data_anc[bid + stride_lut * 2 + 1];

        const scalar_t rd = (r - r0) / (r1 - r0 + eps);
        const scalar_t gd = (g - g0) / (g1 - g0 + eps);
        const scalar_t bd = (b - b0) / (b1 - b0 + eps);

        const scalar_t w000 = (1 - rd) * (1 - gd) * (1 - bd);
        const scalar_t w100 = (    rd) * (1 - gd) * (1 - bd);
        const scalar_t w010 = (1 - rd) * (    gd) * (1 - bd);
        const scalar_t w110 = (    rd) * (    gd) * (1 - bd);
        const scalar_t w001 = (1 - rd) * (1 - gd) * (    bd);
        const scalar_t w101 = (    rd) * (1 - gd) * (    bd);
        const scalar_t w011 = (1 - rd) * (    gd) * (    bd);
        const scalar_t w111 = (    rd) * (    gd) * (    bd);

        /* Execute the interpolation */
        for (int i = 0; i < num_channels; ++i) {
            data_col[index + height * width * i] =
                w000 * data_lut[id000 + stride_lut_3 * i] + w100 * data_lut[id100 + stride_lut_3 * i] +
                w010 * data_lut[id010 + stride_lut_3 * i] + w110 * data_lut[id110 + stride_lut_3 * i] +
                w001 * data_lut[id001 + stride_lut_3 * i] + w101 * data_lut[id101 + stride_lut_3 * i] +
                w011 * data_lut[id011 + stride_lut_3 * i] + w111 * data_lut[id111 + stride_lut_3 * i];
        }
    }
}
```
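The function iterates over the pixels. Taking one iteration as an example, it first reads the pixel's r, g, b values (the image is stored channel-planar, hence the height * width offsets). Then, for each channel, it finds the lower-bound vertex in that channel's segment of the vertex array data_anc (not in the LUT itself) using lower_bound, a binary search implemented as follows: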
```cpp
/* std::clamp is only available since c++17 */
template <typename scalar_t>
inline constexpr const scalar_t& clamp(
    const scalar_t& v, const scalar_t& lo, const scalar_t& hi)
{
    return (v < lo) ? lo : ((v > hi) ? hi : v);
}


/* binary search on a sorted array to find and clamp the lower bound */
template <typename scalar_t>
inline int32_t lower_bound(
        const scalar_t *data_ss,
        int32_t start,
        int32_t end,
        scalar_t val) {

    const int32_t ori_start = start;
    const int32_t upper_bound = end - start - 2;
    while (start < end) {
        int64_t mid = start + ((end - start) >> 1);
        if (!(data_ss[mid] >= val)) {
            start = mid + 1;
        }
        else {
            end = mid;
        }
    }
    return clamp(start - ori_start - 1, 0, upper_bound);
}
```
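Because the vertices of each channel are increasing (softmax outputs are positive, so their cumsum is strictly increasing), a binary search finds the vertex just below the r, g, or b value. The result is clamped to [0, stride_lut - 2] so that index + 1 is always a valid vertex. Back in the interpolation loop, once the lower vertex of each axis is known, adding 1 along each axis and combining gives all 8 corners of the enclosing cell. Trilinear interpolation with weights computed from the fractional positions rd, gd, bd then produces the output value.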
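To make the indexing concrete, here is a hedged pure-Python transcription of the per-pixel logic: lower_bound plus the 8-corner trilinear blend. It is a sketch for a single pixel, not the repo's operator. It assumes the LUT is stored as nested lists lut[c][b][g][r] (R fastest, as in the C++ offsets) and the vertices as three per-channel lists. `ailut_pixel`, `identity_on_lattice`, and the example lattice are hypothetical.

```python
import bisect


def lower_bound(vertices, val):
    """Index of the last vertex strictly below val, clamped to [0, len - 2].

    bisect_left returns the first index whose vertex is >= val, which is what
    the C++ binary search finds before subtracting 1 and clamping.
    """
    i = bisect.bisect_left(vertices, val) - 1
    return max(0, min(i, len(vertices) - 2))


def ailut_pixel(rgb, lut, vertices, eps=1e-10):
    """Trilinear lookup of one (r, g, b) pixel in lut[c][b][g][r], where
    vertices[k] holds the (possibly non-uniform) coordinates of axis k."""
    ids, fracs = [], []
    for k in range(3):  # k = 0, 1, 2 for the r, g, b axes
        v = vertices[k]
        i = lower_bound(v, rgb[k])
        ids.append(i)
        fracs.append((rgb[k] - v[i]) / (v[i + 1] - v[i] + eps))
    rid, gid, bid = ids
    rd, gd, bd = fracs
    out = []
    for ch in range(len(lut)):
        acc = 0.0
        # blend the 8 corners of the enclosing cell; (dr, dg, db) in {0, 1}^3
        for db in (0, 1):
            for dg in (0, 1):
                for dr in (0, 1):
                    w = ((rd if dr else 1 - rd)
                         * (gd if dg else 1 - gd)
                         * (bd if db else 1 - bd))
                    acc += w * lut[ch][bid + db][gid + dg][rid + dr]
        out.append(acc)
    return out


def identity_on_lattice(vertices):
    """LUT whose entries store their own vertex coordinates (an identity map)."""
    d = len(vertices[0])
    lut = [[[[0.0] * d for _ in range(d)] for _ in range(d)] for _ in range(3)]
    for b in range(d):
        for g in range(d):
            for r in range(d):
                lut[0][b][g][r] = vertices[0][r]
                lut[1][b][g][r] = vertices[1][g]
                lut[2][b][g][r] = vertices[2][b]
    return lut


verts = [[0.0, 0.1, 0.4, 1.0] for _ in range(3)]  # hypothetical learned lattice
lut = identity_on_lattice(verts)
print(ailut_pixel([0.25, 0.05, 0.9], lut, verts))  # close to [0.25, 0.05, 0.9]
```

Filling the LUT with its own vertex coordinates gives an identity transform, which is a handy sanity check. Trilinear interpolation reproduces any function that is linear along each axis, so the output matches the input up to the eps term.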
4. Summary
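That completes the walkthrough of the core implementation. Compared with 4D LUT, the open-source code of this paper is of much higher quality. One point is worth making, though: after reading the implementation you will find that AdaInt's adaptive intervals are better understood as first remapping the raw r, g, b values (one piecewise-linear curve per channel) and then performing an ordinary 3D LUT interpolation. Reading it as "a 3D LUT with adaptive intervals that reduce quantization error" does not correspond to what the code actually does.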
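The remapping view can be checked numerically in 1D. The sketch below is pure Python and illustrative, not repo code. It compares a lookup on a non-uniform lattice with a lookup on a uniform lattice after a per-channel piecewise-linear remap that sends vertex i to i / (d - 1). The lattice, the LUT values, and all helper names are hypothetical. The sketch only covers one axis, on the assumption that the 3D case follows because the C++ kernel's trilinear weights are products of independent per-axis fractions rd, gd, bd.

```python
import bisect


def lower_index(vertices, x):
    """Lower vertex index of x, clamped so that index + 1 is valid."""
    i = bisect.bisect_left(vertices, x) - 1
    return max(0, min(i, len(vertices) - 2))


def interp_nonuniform(x, vertices, values):
    """1D lookup on a non-uniform lattice: one axis of the AiLUT kernel."""
    i = lower_index(vertices, x)
    t = (x - vertices[i]) / (vertices[i + 1] - vertices[i])
    return (1 - t) * values[i] + t * values[i + 1]


def remap(x, vertices):
    """Piecewise-linear 1D curve sending vertices[i] to i / (d - 1)."""
    i = lower_index(vertices, x)
    t = (x - vertices[i]) / (vertices[i + 1] - vertices[i])
    return (i + t) / (len(vertices) - 1)


def interp_uniform(u, values):
    """1D lookup on a uniform lattice over [0, 1], i.e. an ordinary LUT."""
    d = len(values)
    i = max(0, min(int(u * (d - 1)), d - 2))
    t = u * (d - 1) - i
    return (1 - t) * values[i] + t * values[i + 1]


vertices = [0.0, 0.05, 0.2, 0.5, 1.0]  # hypothetical learned lattice
values = [0.0, 0.3, 0.45, 0.8, 1.0]    # hypothetical LUT entries
xs = [k / 100 for k in range(101)]
diffs = [abs(interp_nonuniform(x, vertices, values)
             - interp_uniform(remap(x, vertices), values)) for x in xs]
print(max(diffs))  # zero up to floating-point rounding
```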
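Thanks for reading. Comments and private messages are welcome; let's discuss.
If you found this helpful, please consider following the author. Thank you.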




