背景:
古语云:“天下文章一大抄”。其实这就是创作生成的本质,人类的创作并非凭空生成,无源而生。无论是促景生情有感而发,还是看了某篇别人的文章而来了创作灵感,亦或是饱读诗书彻悟大道洋洋洒洒无拘无束的天马星空....无论是何种情况,我们都是在已有的信息、灵感之上做提取、整合、组合、编排、加工。所有的创作目前都未离开这个创作范式。
那么对于AIGC我们是否也可以更直接的提供检索的内容,基于相似的内容来做创作。当然你可以所AIGC的大模型本身就是对各种数据的提取、整合、加工把信息在高维度重构了,它就是一个集内容、方法于一体的先验人类知识的分布,何必再多此一举搞一个检索内容生成呢。
其实原因也很简单,AIGC的大模型虽说是对信息提取、整合、加工了。然而这个数据毕竟是有限的,并且数据加工是有偏好的,很可能模型并不能够按用户习惯的模式把需要的数据表达、描述。你们的话语体系是不一样的,然而检索是更底层和已经融入大部份人思维习惯的工具,所以能够相对更容易和准确统一表达话语体系。这就是AIGC这个阶段可以辅以内容检索来生成基本假设。
正文:
这部分介绍的是如何搭建一个基于内容的图检索系统,实现功能有三:
1.基于文本描述来检索图
2.基于图来检索相似图
3.支持中英文双语检索
说是系统,然得益于多模态技术的发展实则几百行代码,几个文件而已。然别看代码不多,系统还功能还是很强大的,支持百万的图片的秒级别检索,检索准确度也非常高。也就是说这个系统虽然简单,但是却是工业级别的系统。鲁棒性是够的,如果你要做更高QPS和更精准的分级检索大部份代码应该是分布式、工程系统层面代码改造。
图的检索,有一下几类:
1.基于meta data的检索,直白点就是把图人为打上各种标签的标签检索
2.图特征检索:色系、直方图、复杂度、尺寸比例......
3.基于内容的图检索,包括图-文(长、短)语义embbeding
我们这篇文章主要是基于clip来实现基于内容的图检索,也会简单介绍基于dinov2的基于内容检索。两个模型的不是这部分重点,所以不会做介绍,直接了当介绍如何实现。
特征抽取部分代码:
1.把图数据批量读入,构建dataset
2.构建特征抽取预处理和模型
3.并行化抽取特征
def compute_embeddings(img_dir, save_path, batch_size, num_workers):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
device = "cuda" if torch.cuda.is_available() else "cpu"
#2.构建预处理和模型链路
model, preprocess = load_from_name("ViT-H-14", device=device, download_root='./')
model.eval()
#1.构建dataset
dataset = ClipSearchDataset(img_dir = img_dir, preprocess = preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
#特征抽取
img_path_list, embedding_list = [], []
for img, img_path in tqdm(dataloader):
with torch.no_grad():
features = model.encode_image(img.to(device))
features /= features.norm(dim=-1, keepdim=True)
embedding_list.extend(features.detach().cpu().numpy())
img_path_list.extend(img_path)
result = {'img_path': img_path_list, 'embedding': embedding_list}
with open(save_path, 'wb') as f:
pickle.dump(result, f, protocol=4)
特征数据存入faiss
把clip抽取的图特征存入faiss向量数据库,方便后面快速便捷的检索。
def create_faiss_index(embeddings_path, save_path):
with open(embeddings_path, 'rb') as f:
results = pickle.load(f)
embeddings = np.array(results['embedding'], dtype=np.float32)
index = faiss.index_factory(embeddings.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
index.add(embeddings)
# save index
faiss.write_index(index, save_path)
相似度计算
利用clip模型把传入的图或者文做为输入,抽取出图、文本embedding特征。然后利用embbding特征到存入faiss向量库里面捞出相似度做高的top数据。得到对应数据的图的文件保存路径。
底下代码介绍了,如何通过图或者文作为出入抽取embbding,并给出了简单相似度计算实现。
import torch
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models())
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
device = "cuda" if torch.cuda.is_available() else "cpu"
#model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
model, preprocess = load_from_name("ViT-H-14", device=device, download_root='./')
model.eval()
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
# 对特征进行归一化,请使用归一化后的图文特征用于下游任务
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
logits_per_image, logits_per_text = model.get_similarity(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs) # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]
结果可视化
把检索回来的图用streamlit做可视化。

if search_mode == 'Image':
# search by image
img = Image.open(img_path).convert('RGB')
st.image(img, caption=f'Query Image: {img_path}')
img_tensor = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
features = model.encode_image(img_tensor.to(device))
elif search_mode == 'Upload Image':
uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
#img = Image.open("../image/train/left/00000.jpg").convert('RGB')
img = Image.open("examples/pokemon.jpeg").convert('RGB')
if uploaded_file is not None:
img = Image.open(uploaded_file).convert('RGB')
st.image(img)
img_tensor = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
features = model.encode_image(img_tensor.to(device))
else:
# search by text
query_text = st.text_input('Enter a search term:')
with torch.no_grad():
text = clip.tokenize([query_text]).to(device)
features = model.encode_text(text)
dinov2特征抽取
基于dinov2的特征抽取,图文特征的抽取,出了用clip外,还可以用facebook家的dinov2来实现。下面给出的是单张图抽取方式,批量类似上面dataset并行化批量抽取方式可以直接到我git项目中找代码。
def compute_embeddings(img_dir, save_path, batch_size, num_workers):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
#Load DINOv2 model and processor
processor_dino = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model_dino = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
#Retrieve all filenames
images = []
for root, dirs, files in os.walk(img_dir):
for file in files:
if file.endswith('jpg'):
images.append(root + '/'+ file)
index_dino = faiss.IndexFlatL2(768)
#Iterate over the dataset to extract features X2 and store features in indexes
for image_path in images:
img = Image.open(image_path).convert('RGB')
dino_features = extract_features_dino(img,processor_dino,model_dino,device)
add_vector_to_index(dino_features,index_dino)
chinese-clip
本项目使用额是OFA开源的chinese-clip底模来实现的,如果私域数据比较多的开源底模特征抽取效果不好,可以利用自己的数据对模型做二次预训练,单卡就能调的动;巨头实现可以参看chinese-clip项目代码。
chinese-clip还给出了onnx、TensorRT的推理加速,需要的可以执行他们脚本,按步骤执行就行。个人觉得数据量不大情况,torch版本实现就已经可以满足要求,检索推理优化速度差异不大,对于图库特征抽取如faiss时间会有一定优化。
更多代码细节可以到我git上看:
github.com/liangwq/Cha…
小结
基于内容的图检索是基于检索的AIGC图创作的基础模块,文章介绍了基于内容的图检索。给出了两种多模态模型的实现方式:chinese-clip和dinov2实现。并给出了代码实现基本架构和一个实现代码例子。文章内容偏工程实现,所以没有太多理论介绍,具体看项目代码即可。