Reproduction Series #3: Build a Coarse-Grained Video Deduplication System in 5 Minutes


This installment again follows the Zilliz tech WeChat account and tries out something fun.

Video deduplication comes in two parts. This post covers the coarse-grained case, matching whole videos against each other; the next one covers the fine-grained case, locating the matching segments inside a video.

Before following the post, I briefly thought about how this problem would have been tackled ten years ago.

Back then, image work basically meant OpenCV. Deduplication starts with similarity, and reduced to image similarity you had algorithms like SIFT: detecting keypoints and matching them across scale, illumination, and rotation changes. My memory is fuzzy on the details, but I do remember the computational cost being very high.

Now, in the AI era, the nice thing is: when in doubt, throw a model at it.
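The model-based recipe reduces deduplication to two steps: embed each video into a vector, then compare embeddings with a cheap distance. A minimal sketch of the comparison half, using random stand-in vectors rather than real model outputs:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # near 1.0 -> likely duplicates; near 0 -> unrelated content
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
video_a = rng.standard_normal(1024)                    # stand-in for a model embedding
video_b = video_a + 0.01 * rng.standard_normal(1024)   # a near-duplicate of video_a
video_c = rng.standard_normal(1024)                    # an unrelated video

print(cosine_similarity(video_a, video_b))  # close to 1
print(cosine_similarity(video_a, video_c))  # close to 0
```

The dimension 1024 is chosen to match the embedding size used later in this post; everything else here is illustrative.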

Reproduction steps:

1. Install dependencies

pip3 install pymilvus towhee pillow pandas ipython -i https://pypi.tuna.tsinghua.edu.cn/simple

Installing from the Tsinghua PyPI mirror is much faster than the default index; the Aliyun mirror works just as well.

2. Prepare the dataset

Download all the data you'll need ahead of time:

curl -L https://github.com/towhee-io/examples/releases/download/data/VCDB_core_sample.zip -O  
unzip -q -o VCDB_core_sample.zip
axel -n 10 https://download.pytorch.org/models/resnet50-0676ba61.pth 
axel -n 10 http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth
axel -n 10 https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth
cp *.pth ~/.cache/torch/hub/checkpoints/

If you don't have the axel command, wget works too.

3. Browse the data

import random
from pathlib import Path
import torch
import pandas as pd
random.seed(6)

root_dir = './VCDB_core_sample'


min_sample_num = 5
sample_folder_num = 20

all_video_path_lists = []
all_video_path_list = []

df = pd.DataFrame(columns=('path','event','id'))
query_df = pd.DataFrame(columns=('path','event','id'))

video_idx = 0
for i, mid_dir_path in enumerate(Path(root_dir).iterdir()):
    if i >= sample_folder_num:
        break
    if mid_dir_path.is_dir():
        path_videos = list(Path(mid_dir_path).iterdir())
        if len(path_videos) < min_sample_num:
            print('len(path_videos) < min_sample_num, continue.')
            continue
        sample_video_path_list = random.sample(path_videos, min_sample_num)
        all_video_path_lists.append(sample_video_path_list)
        all_video_path_list += [str(path) for path in sample_video_path_list]
        for j, path in enumerate(sample_video_path_list):
            video_idx += 1
            if j == 0:
                # DataFrame.append was removed in pandas 2.0; use pd.concat instead
                query_df = pd.concat([query_df, pd.DataFrame({'path': [str(path)], 'event': [path.parent.stem], 'id': [video_idx]})], ignore_index=True)
            df = pd.concat([df, pd.DataFrame({'path': [str(path)], 'event': [path.parent.stem], 'id': [video_idx]})], ignore_index=True)

all_sample_video_dicts = []
for i, sample_video_path_list in enumerate(all_video_path_lists):
    anchor_video = sample_video_path_list[0]
    pos_video_path_list = sample_video_path_list[1:]
    neg_video_path_lists = all_video_path_lists[:i] + all_video_path_lists[i + 1:]
    neg_video_path_list = [neg_video_path_list[0] for neg_video_path_list in neg_video_path_lists]
    all_sample_video_dicts.append({
        'anchor_video': anchor_video,
        'pos_video_path_list': pos_video_path_list,
        'neg_video_path_list': neg_video_path_list
    })

id2event = df.set_index(['id'])['event'].to_dict()
id2path = df.set_index(['id'])['path'].to_dict()

df_csv_path = 'video_info.csv'
query_df_csv_path = 'query_video_info.csv'
df.to_csv(df_csv_path)
query_df.to_csv(query_df_csv_path)
df

If df comes back with one row per sampled video, with path, event, and id columns, everything matches expectations and you can move on to the next step.
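Rather than eyeballing the output, the expectations can be checked programmatically. A sketch (the column names and the `min_sample_num` semantics follow the script above):

```python
import pandas as pd

def check_video_info(df: pd.DataFrame, min_sample_num: int) -> None:
    # every video id is unique, and every event kept by the sampling
    # loop contributed exactly min_sample_num videos
    assert df['id'].is_unique, 'duplicate video ids'
    counts = df.groupby('event')['path'].count()
    assert (counts == min_sample_num).all(), f'unexpected counts:\n{counts}'

# toy frame shaped like the script's output
toy = pd.DataFrame({
    'path': ['a/1.mp4', 'a/2.mp4', 'b/1.mp4', 'b/2.mp4'],
    'event': ['a', 'a', 'b', 'b'],
    'id': [1, 2, 3, 4],
})
check_video_info(toy, 2)
print('video_info checks passed')
```

Running `check_video_info(df, min_sample_num)` on the real frame would raise an AssertionError if the sampling loop misbehaved.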


4. Render as GIFs

from IPython import display
from pathlib import Path
import towhee
from PIL import Image

def display_gif(video_path_list, text_list):
    html = ''
    for video_path, text in zip(video_path_list, text_list):
        html_line = '<img src=\"{}\"> {} <br/><br/>'.format(video_path, text)
        html += html_line
    return display.HTML(html)

    
def convert_video2gif(video_path, output_gif_path, num_samples=16):
    frames = (
        towhee.glob(video_path)
              .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 5})
              .to_list()[0]
    )
    imgs = [Image.fromarray(frame) for frame in frames]
    imgs[0].save(fp=output_gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)


def display_gifs_from_video(video_path_list, text_list, tmpdirname = './tmp_gifs'):
    Path(tmpdirname).mkdir(exist_ok=True)
    gif_path_list = []
    for video_path in video_path_list:
        video_name = str(Path(video_path).name).split('.')[0]
        gif_path = Path(tmpdirname) / (video_name + '.gif')
        convert_video2gif(video_path, gif_path)
        gif_path_list.append(gif_path)
    return display_gif(gif_path_list, text_list)

random_video_pair = random.sample(all_sample_video_dicts, 1)[0]
neg_sample_num = min(5, sample_folder_num)
anchor_video = random_video_pair['anchor_video']
anchor_video_event = anchor_video.parent.stem
pos_video_list = random_video_pair['pos_video_path_list']
pos_video_list_events = [path.parent.stem for path in pos_video_list]
neg_video_list = random_video_pair['neg_video_path_list'][:neg_sample_num]
neg_video_list_events = [path.parent.stem for path in neg_video_list]

show_video_list = [str(anchor_video)] + [str(path) for path in pos_video_list] + [str(path) for path in neg_video_list]
# print(show_video_list)
caption_list = ['anchor video: ' + anchor_video_event] + ['positive video ' + str(i + 1) for i in range(len(pos_video_list))] + ['negative video ' + str(i + 1) + ': ' + neg_video_list_events[i] for i in range(len(neg_video_list))]
print(caption_list)
tmpdirname = './tmp_gifs'
display_gifs_from_video(show_video_list, caption_list, tmpdirname=tmpdirname)

On an M1 Mac, ffmpeg emits some warnings here; they can be safely ignored. On an x86 CPU you most likely won't see them. Not a problem either way.

As the anchor, I used the video of Tracy McGrady's legendary 13-points-in-35-seconds moment, with the positive and negative candidate videos listed alongside it.


5. Create the Milvus collection

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='127.0.0.1', port='19530')

def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video deduplication')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',  # or 'IP' for inner-product similarity
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
    
collection = create_milvus_collection('video_deduplication', 1024)
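The IVF_FLAT index created above partitions the stored vectors into nlist buckets and, at query time, scans only the few buckets whose centroids are closest to the query (controlled by a search-time nprobe parameter). A toy numpy sketch of that idea — not Milvus's actual implementation, and using randomly picked centroids where real IVF training runs k-means:

```python
import numpy as np

rng = np.random.default_rng(0)
vectors = rng.standard_normal((1000, 8)).astype(np.float32)

nlist, nprobe = 16, 4
# "training": pick nlist centroids at random (real IVF uses k-means)
centroids = vectors[rng.choice(len(vectors), nlist, replace=False)]

# assign every vector to its nearest centroid -> the inverted lists
assign = np.linalg.norm(vectors[:, None] - centroids[None], axis=-1).argmin(1)

def ivf_search(query, k=5):
    # scan only the nprobe buckets whose centroids are closest to the query
    order = np.linalg.norm(centroids - query, axis=1).argsort()[:nprobe]
    cand = np.flatnonzero(np.isin(assign, order))
    d = np.linalg.norm(vectors[cand] - query, axis=1)
    return cand[d.argsort()[:k]]

print(ivf_search(vectors[0]))  # the first hit is vector 0 itself
```

A larger nlist means smaller buckets (faster scans, coarser recall per bucket); raising nprobe trades speed back for recall.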

With the collection created, we come to the most time-consuming step. Since I'm running this in Jupyter, I use the %%time magic to measure elapsed time; time.time() would of course work too.
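Outside Jupyter, where %%time isn't available, the same measurement takes a few lines of stdlib Python (the loop below is just a stand-in for the real embedding pipeline):

```python
import time

start = time.perf_counter()
total = sum(i * i for i in range(100_000))  # stand-in for the real pipeline
elapsed = time.perf_counter() - start
print(f'pipeline took {elapsed:.2f}s')
```

time.perf_counter() is preferable to time.time() for intervals, since it's monotonic and unaffected by system clock adjustments.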

%%time
import towhee

device = 'cpu'

dc = (
    towhee.read_csv(df_csv_path).unstream()
        .runas_op['id', 'id'](func=lambda x: int(x))
        .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=60.0, sample_type='time_step_sample', args={'time_step': 1})
        .runas_op['frames', 'frames'](func=lambda x: [y for y in x])
        .distill_and_select['frames', 'vec'](model_name='cg_student', device=device)
        .to_milvus['id', 'vec'](collection=collection, batch=30)
)

As expected, this step is extremely slow: on CPU it took over 12 minutes in total. And if you haven't pre-downloaded the model files, the download will stall for a long while too.


But if you have an NVIDIA GPU, changing to device='cuda' makes this dramatically faster.
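Rather than hard-coding the device, you can pick it at runtime; the try/except also keeps the snippet working in environments where torch isn't installed:

```python
try:
    import torch
    # use the GPU when one is visible to PyTorch, else fall back to CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
except ImportError:
    device = 'cpu'
print('using device:', device)
```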

6. Query

%%time
dc = (
    towhee.read_csv(query_df_csv_path).unstream()
      .runas_op['event', 'ground_truth_event'](func=lambda x: [x])
      .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=60.0, sample_type='time_step_sample', args={'time_step': 1})
      .runas_op['frames', 'frames'](func=lambda x: [y for y in x])
      .distill_and_select['frames', 'vec'](model_name='cg_student', device=device)
      .milvus_search['vec', 'topk_raw_res'](collection=collection, limit=min_sample_num)
      .runas_op['topk_raw_res', 'topk_events'](func=lambda res: [id2event[x.id] for x in res])
      .runas_op['topk_raw_res', 'topk_path'](func=lambda res: [id2path[x.id] for x in res])
)
dc.select['id', 'ground_truth_event', 'topk_raw_res', 'topk_events', 'topk_path']().show()


dc_list = dc.to_list()
# random_idx = random.randint(0, len(dc_list) - 1)
sample_num = 3
sample_idxs = random.sample(range(len(dc_list)), sample_num)
def get_query_and_predict_videos(idx):
    query_video = id2path[int(dc_list[idx].id)]
    print('query_video =', query_video)
    # skip the first hit, which is the query video itself
    predict_topk_video_list = dc_list[idx].topk_path[1:]
    print('predict_topk_video_list =', predict_topk_video_list)
    return query_video, predict_topk_video_list
dsp_res_list = []
for idx in sample_idxs:
    query_video, predict_topk_video_list = get_query_and_predict_videos(idx)
    show_video_list = [query_video] + predict_topk_video_list
    caption_list = ['query video: ' + Path(query_video).parent.stem] + ['result{0} video'.format(i) for i in range(len(predict_topk_video_list))]
    dsp_res_list.append(display_gifs_from_video(show_video_list, caption_list, tmpdirname=tmpdirname))

I randomly picked 3 indices to see what the similar videos look like.


7. Benchmark evaluation

benchmark = (
    dc.with_metrics(['mean_average_precision'])
        .evaluate['ground_truth_event', 'topk_events'](name='map_at_k')
        .report()
)

The benchmark reports an extremely high score, which honestly came as a bit of a surprise.
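For context on that score: mean average precision averages, over all queries, the precision at each rank where a relevant result appears. A minimal reference implementation (my own sketch for sanity-checking such numbers, not Towhee's map_at_k code):

```python
def average_precision(ground_truth, predictions):
    # AP: average of precision@i over the ranks i where the hit is relevant
    hits, score = 0, 0.0
    for i, p in enumerate(predictions, start=1):
        if p in ground_truth:
            hits += 1
            score += hits / i
    return score / max(hits, 1)

def mean_average_precision(gt_list, pred_list):
    aps = [average_precision(g, p) for g, p in zip(gt_list, pred_list)]
    return sum(aps) / len(aps)

# when every top-k hit belongs to the query's event, mAP is a perfect 1.0,
# which is what the benchmark above is reporting
print(mean_average_precision([{'e1'}], [['e1', 'e1', 'e1', 'e1', 'e1']]))
```

With only 5 videos per event and whole-video embeddings, a near-perfect mAP is plausible, but it is worth remembering the sample is tiny and the negatives come from visually distinct events.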