AIGC之图片生成——基于clip内容检索

284 阅读6分钟

项目代码:github.com/liangwq/Cha…

背景:

古语云:“天下文章一大抄”。其实这就是创作生成的本质,人类的创作并非凭空生成,无源而生。无论是促景生情有感而发,还是看了某篇别人的文章而来了创作灵感,亦或是饱读诗书彻悟大道洋洋洒洒无拘无束的天马星空....无论是何种情况,我们都是在已有的信息、灵感之上做提取、整合、组合、编排、加工。所有的创作目前都未离开这个创作范式。
那么对于AIGC我们是否也可以更直接的提供检索的内容,基于相似的内容来做创作。当然你可以所AIGC的大模型本身就是对各种数据的提取、整合、加工把信息在高维度重构了,它就是一个集内容、方法于一体的先验人类知识的分布,何必再多此一举搞一个检索内容生成呢。
其实原因也很简单,AIGC的大模型虽说是对信息提取、整合、加工了。然而这个数据毕竟是有限的,并且数据加工是有偏好的,很可能模型并不能够按用户习惯的模式把需要的数据表达、描述。你们的话语体系是不一样的,然而检索是更底层和已经融入大部份人思维习惯的工具,所以能够相对更容易和准确统一表达话语体系。这就是AIGC这个阶段可以辅以内容检索来生成基本假设。

正文:

这部分介绍的是如何搭建一个基于内容的图检索系统,实现功能有三:
1.基于文本描述来检索图
2.基于图来检索相似图
3.支持中英文双语检索
说是系统,然得益于多模态技术的发展实则几百行代码,几个文件而已。然别看代码不多,系统还功能还是很强大的,支持百万的图片的秒级别检索,检索准确度也非常高。也就是说这个系统虽然简单,但是却是工业级别的系统。鲁棒性是够的,如果你要做更高QPS和更精准的分级检索大部份代码应该是分布式、工程系统层面代码改造。
图的检索,有一下几类:
1.基于meta data的检索,直白点就是把图人为打上各种标签的标签检索
2.图特征检索:色系、直方图、复杂度、尺寸比例......
3.基于内容的图检索,包括图-文(长、短)语义embbeding
我们这篇文章主要是基于clip来实现基于内容的图检索,也会简单介绍基于dinov2的基于内容检索。两个模型的不是这部分重点,所以不会做介绍,直接了当介绍如何实现。

特征抽取部分代码:

1.把图数据批量读入,构建dataset
2.构建特征抽取预处理和模型
3.并行化抽取特征

def compute_embeddings(img_dir, save_path, batch_size, num_workers):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    #2.构建预处理和模型链路
    model, preprocess = load_from_name("ViT-H-14", device=device, download_root='./')
    model.eval()
    #1.构建dataset
    dataset = ClipSearchDataset(img_dir = img_dir, preprocess = preprocess)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    #特征抽取
    img_path_list, embedding_list = [], []
    for img, img_path in tqdm(dataloader):
        with torch.no_grad():
            features = model.encode_image(img.to(device))
            features /= features.norm(dim=-1, keepdim=True)
            embedding_list.extend(features.detach().cpu().numpy())
            img_path_list.extend(img_path)

    result = {'img_path': img_path_list, 'embedding': embedding_list}
    with open(save_path, 'wb') as f:
        pickle.dump(result, f, protocol=4)

特征数据存入faiss

把clip抽取的图特征存入faiss向量数据库,方便后面快速便捷的检索。

def create_faiss_index(embeddings_path, save_path):
    with open(embeddings_path, 'rb') as f:
        results = pickle.load(f)

    embeddings = np.array(results['embedding'], dtype=np.float32)

    index = faiss.index_factory(embeddings.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
    index.add(embeddings)
    # save index
    faiss.write_index(index, save_path)

相似度计算

利用clip模型把传入的图或者文做为输入,抽取出图、文本embedding特征。然后利用embbding特征到存入faiss向量库里面捞出相似度做高的top数据。得到对应数据的图的文件保存路径。
底下代码介绍了,如何通过图或者文作为出入抽取embbding,并给出了简单相似度计算实现。

import torch 
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models())  
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']

device = "cuda" if torch.cuda.is_available() else "cpu"
#model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
model, preprocess = load_from_name("ViT-H-14", device=device, download_root='./')
model.eval()
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # 对特征进行归一化,请使用归一化后的图文特征用于下游任务
    image_features /= image_features.norm(dim=-1, keepdim=True) 
    text_features /= text_features.norm(dim=-1, keepdim=True)    

    logits_per_image, logits_per_text = model.get_similarity(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]

结果可视化

把检索回来的图用streamlit做可视化。
image.png
image.png

if search_mode == 'Image':
    # search by image
    img = Image.open(img_path).convert('RGB')
    st.image(img, caption=f'Query Image: {img_path}')
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(img_tensor.to(device))
elif search_mode == 'Upload Image':
    uploaded_file = st.file_uploader("Choose an image...", type=['jpg', 'jpeg', 'png'])
    #img = Image.open("../image/train/left/00000.jpg").convert('RGB')
    img = Image.open("examples/pokemon.jpeg").convert('RGB')
    if uploaded_file is not None:
        img = Image.open(uploaded_file).convert('RGB')
        st.image(img)
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(img_tensor.to(device))
else:
    # search by text
    query_text = st.text_input('Enter a search term:')
    with torch.no_grad():
        text = clip.tokenize([query_text]).to(device)
        features = model.encode_text(text)

dinov2特征抽取

基于dinov2的特征抽取,图文特征的抽取,出了用clip外,还可以用facebook家的dinov2来实现。下面给出的是单张图抽取方式,批量类似上面dataset并行化批量抽取方式可以直接到我git项目中找代码。

def compute_embeddings(img_dir, save_path, batch_size, num_workers):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")


    #Load DINOv2 model and processor
    processor_dino = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    model_dino = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
    #Retrieve all filenames
    images = []
    for root, dirs, files in os.walk(img_dir):
        for file in files:
            if file.endswith('jpg'):
                images.append(root + '/'+ file)

    index_dino = faiss.IndexFlatL2(768)

    #Iterate over the dataset to extract features X2 and store features in indexes
    for image_path in images:
        img = Image.open(image_path).convert('RGB')
        dino_features = extract_features_dino(img,processor_dino,model_dino,device)
        add_vector_to_index(dino_features,index_dino)

chinese-clip

本项目使用额是OFA开源的chinese-clip底模来实现的,如果私域数据比较多的开源底模特征抽取效果不好,可以利用自己的数据对模型做二次预训练,单卡就能调的动;巨头实现可以参看chinese-clip项目代码。
chinese-clip还给出了onnx、TensorRT的推理加速,需要的可以执行他们脚本,按步骤执行就行。个人觉得数据量不大情况,torch版本实现就已经可以满足要求,检索推理优化速度差异不大,对于图库特征抽取如faiss时间会有一定优化。
更多代码细节可以到我git上看:
github.com/liangwq/Cha…

小结

基于内容的图检索是基于检索的AIGC图创作的基础模块,文章介绍了基于内容的图检索。给出了两种多模态模型的实现方式:chinese-clip和dinov2实现。并给出了代码实现基本架构和一个实现代码例子。文章内容偏工程实现,所以没有太多理论介绍,具体看项目代码即可。