这是我参与「第三届青训营 -后端场」笔记创作活动的第3篇笔记
图搜图采用wukong50k的数据集
为了便于图片的训练环节,因此写了一个多线程爬虫爬取了所有的图片,其中有部分图片的链接已经失效,因此将这些图片从数据集中去除了。
phash方案
第一版的方案为每个图像生成对应的pHash,之后利用对应的pHash来进行相似图检索,不过该方案准确率不足,故采用第二版方案
resnet方案
第二版方案为本次项目的方案,基于经典CNN网络:resnet18网络结构来提取图像特征
项目基于python3.8、pytorch1.10
首先我们基于torchvision.transforms对图像进行预处理,将传入的图像尺寸修改为(512, 512),再转换为tensor,之后逐channel的对图像进行标准化,这样就完成了图像预处理环节。
接下来我们基于torch.utils.data.DataLoader自定义我们图像的datalodaer,其中我们设置batch_size为40,这个具体的参数可以由自己根据自己的电脑配置和实际环境来设置。
之后我们定义我们的embeddingNet,也就是我们本次做图搜图的重点,其实这只是将每个图片基于resnet18提取的特征转换为高纬度的向量表示(也就是把图像转换为向量)
现在我们定义好我们的网络了就可以开始训练啦!推荐将数据加载到GPU上去进行训练(cpu训练会很慢)
训练完成后我们采用L2归一化,将所有数据变成0-1之间,这样训练阶段的工作我们就做完了(我们将模型保存到本地,模型为.npy文件(保留所有训练图片的高纬特征向量),下次我们需要使用的时候直接导入训练好的模型就可以了)
接下来我们开始相似图搜索阶段,我们把想要搜索的照片传到我们的服务器上,之后同训练步骤将其转换为高纬度的向量后,将其与我们存为.npy的模型做矩阵乘,这时候我们得到了该图片和所有图片的高纬度空间相似值(越靠近1说明越相似),为了便于后面运算,我们直接筛选出所有>0.9的图片,将其排序后进行输出。
补充说明:由于图片爬取的过程中有部分图片失效,但是我们的embedding矩阵行序号是连续的,因此这里需要留意不要出错。