目录
1.项目起因
通过爬虫爬取了一些网站的gif图片,结果等全部爬完后发现图片太多了,根本没工夫一张一张去看。于是打算把这些图片做成一个本地图库,方便需要的时候按照内容检索(主要是有时候找不到来源)。
2.总体思路
- 使用神经网络对图片进行特征提取,打算直接采用预训练的ResNet作为特征提取器;
- 搜索的时候,通过计算特征向量之间的余弦相似度来确定两张gif之间的相似度。
3.用到的环境和包
- Python 3.9;
- Pytorch 1.8;
- Opencv;
- Numpy
4.具体实现
1.Gif预处理
由于gif图片可以看成是由很多张图片叠在一起的,所以在特征提取的时候我的想法是先按照等间距取一定帧数的图片(我的数据里帧数最少的只有5帧,因此我每张图片都等间隔的取了5帧,特殊情况的用最后一帧补齐5帧),再把这若干张图片叠成一个batch输入到网络中,最后的输出看作是整个gif的特征。
在实际操作中,发现不能直接通过opencv的imread来读取gif图片,要使用VideoCapture像读取视频文件一样按帧读取。而且我在使用中还发现,使用 CAP_PROP_FRAME_COUNT 不能返回gif的帧数,会返回负数(不知道是我环境的原因还是其他的)。
def gif_split_to(gpath: str, fstep: int) -> list:
'''
分割gif图片并挑选特定帧数
:param gpath: gif路径
:param fstep: 目标帧数
:return: 一个列表,包含fstep个帧数
'''
# 获取所有帧数
frames = []
cap = cv2.VideoCapture(gpath)
ret, frame = cap.read()
while ret:
frames.append(frame)
ret, frame = cap.read()
cap.release()
fnum = len(frames)
step_frame = math.ceil(fnum / fstep)
# 防止步长大于帧数总数
if step_frame <= 0:
step_frame = 1
ret = list()
# 等间距取帧
for idx in range(0, len(frames)):
if idx % step_frame == 0 and ret:
frame = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB)
ret.append(frame)
# 重复最后一帧补齐到目标帧数
while len(ret) < g_gif_need:
ret.append(frames[len(frames) - 1])
frames = None
return ret
再把图片列表转为tensor,就可以输入到网络里面了:
def gif_to_tensor(gpath: str, fstep: int) -> torch.FloatTensor:
'''
gif转目标维度tensor
:param gpath:gif路径
:param fstep:取得目标帧数
:return:一个tensor
'''
list_img = gif_split_to(gpath, fstep)
gif_np = [np.array(x).transpose((2, 0, 1)) for x in list_img]
gif_tensor = torch.FloatTensor(np.array(gif_np))
return gif_tensor
2.网络调整
前文提到,使用ResNet作为特征提取器,具体我采用的是ResNet18。
ResNet最后有个全局平均池化和全连接层,是用于分类的,这里只是用它提取特征,所以可以直接去掉。并且我也具体实验了有后三层和没后三层的效果,发现差别还是有的,具体如下:
去掉全连接层和全局平均池化层,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距是很大的:
[['1.gif', 1.0], ['2.gif', 0.6843166351318359], ['3.gif', 0.6693203002214432]]
直接使用原网络,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距不是很明显:
[['1.gif', 1.0], ['2.gif', 0.9977700412273407], ['3.gif', 0.9970085620880127]]
要调整网络,一种方法是加载网络后,使用下面这种语句修改使用层数:
net = torchvision.models.resnet50(pretrained=True)
net = nn.Sequential(*list(resnet_50_s.children())[:-2])
我使用的方法是直接在源代码上修改,具体是找到torchvison里的ResNet实现,单独放到本地项目里,再在forward里面注释掉:
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# 注释掉
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
3.下载预训练参数并加载
预训练参数下载地址可以参考:
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
使用方法可以参考下面的语句,这里要注意的是下载的参数要用对应的模型去加载,否则会报错
net = resnet18(pretrained=False).to(device)
net.load_state_dict(torch.load(g_pth_path)) # g_pth_path 为下载的参数保存地址
4.提取特征并保存
前面提到把从gif图片中提取的图片数组作为一个batch输入到网络中(一般net的输入shape要是[batch, channel, h, w]),再把获取到的特征转为numpy,使用np.savez_compressed函数保存,这样保存的数组是经过压缩的,可以节省硬盘空间。
inputs = gif_to_tensor(gif_path, g_gif_need).to(device) # g_gif_need 为每个gif要取的帧数
out = net(inputs)
out_np = out.cpu().numpy()
save_name = str(gif_path).split("\\")[-1].split('.')[0]
np.savez_compressed(f"./tmp/{save_name}", a=out_np)
5.gif搜索
要搜索的gif图片也是要经过前面的预处理到特征提取过程,然后加载本地的特征库,一一比对余弦相似度。余弦相似度的比较函数是参考的,具体作者是谁就不知道了:
def mtx_similar1(arr1: np.ndarray, arr2: np.ndarray) -> float:
'''
计算矩阵相似度的一种方法。将矩阵展平成向量,计算向量的乘积除以模长。
注意有展平操作。
:param arr1:矩阵1
:param arr2:矩阵2
:return:实际是夹角的余弦值,ret = (cos+1)/2
'''
farr1 = arr1.ravel()
len1 = len(farr1)
len2 = len(arr2)
if len1 > len2:
farr1 = farr1[:len2]
else:
arr2 = arr2[:len1]
numer = np.sum(farr1 * arr2)
denom = np.sqrt(np.sum(farr1 ** 2) * np.sum(arr2 ** 2))
similar = numer / denom
return (similar + 1) / 2
5.结束语
程序总体上是粗糙的,还有很多改进的地方,比如:
- gif关键帧的提取;
- gif的尺寸不是一样的,是否要归一化处理;
- 除了余弦相似度,还有其他计算方法吗;
- 更合适的网络模型。
完整代码等整理后考虑放到github上。
6.完整代码【更新】
GitHub - ashortname/localGifSearcher: Build a local GIF feature library for search by image.
本文介绍了一个利用ResNet18特征提取进行本地GIF图像检索的项目。首先,通过预处理将GIF图片按帧抽取并转换为Tensor输入网络;接着,调整ResNet模型,移除最后的全局平均池化和全连接层,以提取特征而非进行分类;然后,下载预训练权重并加载;最后,计算GIF特征并保存,使用余弦相似度进行搜索。项目展示了预处理、特征提取和搜索的过程,并探讨了网络调整的影响。

9341

被折叠的 条评论
为什么被折叠?



