This post again follows the Zilliz tech WeChat account and tries something interesting.
Video deduplication will take two posts: this one is coarse-grained, matching duplicates at the whole-video level; the next will be fine-grained, locating the matching segments inside a video.
Before reproducing the post, I briefly thought about how this task would have been tackled ten years ago.
Back then, image work meant OpenCV. Deduplication starts with similarity, so the problem reduces to image similarity; at the time you would reach for something like SIFT, computing keypoint correspondences robust to changes in scale, lighting, and rotation. My memory of the details is fuzzy, but the computational cost was clearly high.
Now, in the AI era, the nice thing is: when in doubt, throw a model at it.
Now for the reproduction steps:
1. Install the dependencies
pip3 install pymilvus towhee pillow pandas ipython -i https://pypi.tuna.tsinghua.edu.cn/simple
Installing from the Tsinghua PyPI mirror is much faster than the default source; the Aliyun mirror works just as well.
2. Prepare the dataset
Download everything we will need up front:
curl -L https://github.com/towhee-io/examples/releases/download/data/VCDB_core_sample.zip -O
unzip -q -o VCDB_core_sample.zip
axel -n 10 https://download.pytorch.org/models/resnet50-0676ba61.pth
axel -n 10 http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth
axel -n 10 https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth
cp *.pth ~/.cache/torch/hub/checkpoints/
If the axel command is not available, wget works too.
3. Take a look at the data
import random
from pathlib import Path

import pandas as pd

random.seed(6)

root_dir = './VCDB_core_sample'
min_sample_num = 5
sample_folder_num = 20

all_video_path_lists = []
all_video_path_list = []
df = pd.DataFrame(columns=('path', 'event', 'id'))
query_df = pd.DataFrame(columns=('path', 'event', 'id'))
video_idx = 0
for i, mid_dir_path in enumerate(Path(root_dir).iterdir()):
    if i >= sample_folder_num:
        break
    if mid_dir_path.is_dir():
        path_videos = list(Path(mid_dir_path).iterdir())
        if len(path_videos) < min_sample_num:
            print('len(path_videos) < min_sample_num, continue.')
            continue
        sample_video_path_list = random.sample(path_videos, min_sample_num)
        all_video_path_lists.append(sample_video_path_list)
        all_video_path_list += [str(path) for path in sample_video_path_list]
        for j, path in enumerate(sample_video_path_list):
            video_idx += 1
            row = pd.DataFrame({'path': [str(path)], 'event': [path.parent.stem], 'id': [video_idx]})
            # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
            if j == 0:
                query_df = pd.concat([query_df, row], ignore_index=True)
            df = pd.concat([df, row], ignore_index=True)

all_sample_video_dicts = []
for i, sample_video_path_list in enumerate(all_video_path_lists):
    anchor_video = sample_video_path_list[0]
    pos_video_path_list = sample_video_path_list[1:]
    neg_video_path_lists = all_video_path_lists[:i] + all_video_path_lists[i + 1:]
    neg_video_path_list = [neg_video_path_list[0] for neg_video_path_list in neg_video_path_lists]
    all_sample_video_dicts.append({
        'anchor_video': anchor_video,
        'pos_video_path_list': pos_video_path_list,
        'neg_video_path_list': neg_video_path_list
    })

id2event = df.set_index(['id'])['event'].to_dict()
id2path = df.set_index(['id'])['path'].to_dict()

df_csv_path = 'video_info.csv'
query_df_csv_path = 'query_video_info.csv'
df.to_csv(df_csv_path)
query_df.to_csv(query_df_csv_path)
df
If df looks like this, everything is as expected and we can move on to the next step.
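As a quick sanity check, the frame should have exactly the columns path/event/id, unique ids, and min_sample_num rows per sampled event. A standalone sketch with made-up paths (not the real dataset) illustrating the expected shape:

```python
import pandas as pd

# Hypothetical miniature of the expected layout: 2 events, 2 sampled videos each.
rows = [
    {'path': 'VCDB_core_sample/event_a/v1.flv', 'event': 'event_a', 'id': 1},
    {'path': 'VCDB_core_sample/event_a/v2.flv', 'event': 'event_a', 'id': 2},
    {'path': 'VCDB_core_sample/event_b/v3.flv', 'event': 'event_b', 'id': 3},
    {'path': 'VCDB_core_sample/event_b/v4.flv', 'event': 'event_b', 'id': 4},
]
df = pd.DataFrame(rows)

# Every event contributes the same number of sampled videos; ids are unique.
assert list(df.columns) == ['path', 'event', 'id']
assert df['id'].is_unique
print(df['event'].value_counts().to_dict())
```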
4. Render the videos as GIFs
from IPython import display
from pathlib import Path
import towhee
from PIL import Image

def display_gif(video_path_list, text_list):
    html = ''
    for video_path, text in zip(video_path_list, text_list):
        html_line = '<img src="{}"> {} <br/><br/>'.format(video_path, text)
        html += html_line
    return display.HTML(html)

def convert_video2gif(video_path, output_gif_path, num_samples=16):
    frames = (
        towhee.glob(video_path)
              .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 5})
              .to_list()[0]
    )
    imgs = [Image.fromarray(frame) for frame in frames]
    imgs[0].save(fp=output_gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)

def display_gifs_from_video(video_path_list, text_list, tmpdirname='./tmp_gifs'):
    Path(tmpdirname).mkdir(exist_ok=True)
    gif_path_list = []
    for video_path in video_path_list:
        video_name = Path(video_path).stem
        gif_path = Path(tmpdirname) / (video_name + '.gif')
        convert_video2gif(video_path, gif_path)
        gif_path_list.append(gif_path)
    return display_gif(gif_path_list, text_list)
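The trick in convert_video2gif is Pillow's animated-GIF support: save the first frame and hand the remaining frames over via append_images. A self-contained example with synthetic solid-color frames (no video decoding involved):

```python
import os
import tempfile

from PIL import Image

# Three solid-color frames stand in for decoded video frames.
frames = [Image.new('RGB', (32, 32), color) for color in ('red', 'green', 'blue')]
gif_path = os.path.join(tempfile.mkdtemp(), 'demo.gif')

# First frame carries the file; the rest are appended as animation frames.
frames[0].save(fp=gif_path, format='GIF', append_images=frames[1:], save_all=True, loop=0)

with Image.open(gif_path) as gif:
    print(gif.n_frames)  # 3
```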
random_video_pair = random.sample(all_sample_video_dicts, 1)[0]
neg_sample_num = min(5, sample_folder_num)
anchor_video = random_video_pair['anchor_video']
anchor_video_event = anchor_video.parent.stem
pos_video_list = random_video_pair['pos_video_path_list']
pos_video_list_events = [path.parent.stem for path in pos_video_list]
neg_video_list = random_video_pair['neg_video_path_list'][:neg_sample_num]
neg_video_list_events = [path.parent.stem for path in neg_video_list]
show_video_list = [str(anchor_video)] + [str(path) for path in pos_video_list] + [str(path) for path in neg_video_list]
# print(show_video_list)
caption_list = ['anchor video: ' + anchor_video_event] + ['positive video ' + str(i + 1) for i in range(len(pos_video_list))] + ['negative video ' + str(i + 1) + ': ' + neg_video_list_events[i] for i in range(len(neg_video_list))]
print(caption_list)
tmpdirname = './tmp_gifs'
display_gifs_from_video(show_video_list, caption_list, tmpdirname=tmpdirname)
On an M1 Mac, ffmpeg emits some warnings here that can be safely ignored; on an x86 CPU you most likely will not see them. Nothing to worry about.
As the anchor video I picked Tracy McGrady's legendary 13-points-in-35-seconds moment, with the positive and negative candidate videos listed alongside.
5. Create the Milvus collection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='127.0.0.1', port='19530')

def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video deduplication')
    collection = Collection(name=collection_name, schema=schema)
    # Create an IVF_FLAT index for the collection.
    index_params = {
        'metric_type': 'L2',  # or 'IP'
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 2048}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection
collection = create_milvus_collection('video_deduplication', 1024)
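The index uses metric_type L2, with IP (inner product) noted in the comment as the alternative. As a quick standalone reminder of how the two behave (a numpy sketch, not a Milvus call): with L2, smaller means more similar; with IP, larger does.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])

l2 = float(np.linalg.norm(a - b))  # Euclidean distance: smaller = more similar
ip = float(np.dot(a, b))           # inner product: larger = more similar
print(round(l2, 3), ip)  # 1.414 12.0
```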
With the collection created, we come to the most time-consuming step. I am running this in Jupyter, so I use Jupyter's %%time magic to measure it; time.time() would work just as well.
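Outside Jupyter, where %%time is unavailable, the same measurement is a plain time.time() pair (with a stand-in workload here in place of the towhee pipeline):

```python
import time

start = time.time()
_ = sum(i * i for i in range(100000))  # stand-in for the embedding pipeline
elapsed = time.time() - start
print(f'elapsed: {elapsed:.2f}s')
```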
%%time
import towhee

device = 'cpu'
dc = (
    towhee.read_csv(df_csv_path).unstream()
    .runas_op['id', 'id'](func=lambda x: int(x))
    .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=60.0, sample_type='time_step_sample', args={'time_step': 1})
    .runas_op['frames', 'frames'](func=lambda x: [y for y in x])
    .distill_and_select['frames', 'vec'](model_name='cg_student', device=device)
    .to_milvus['id', 'vec'](collection=collection, batch=30)
)
As expected, this step is very slow: on CPU it took a bit over 12 minutes in total. And if the model files were not downloaded beforehand, the download itself also stalls for quite a while.
But if you have an NVIDIA GPU, change it to device='cuda' and it becomes dramatically faster.
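Rather than editing the string by hand, the device can be picked automatically (assuming torch is importable, which towhee's model already requires):

```python
import torch

# Use the GPU when CUDA is visible, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
```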
6. Query
%%time
dc = (
    towhee.read_csv(query_df_csv_path).unstream()
    .runas_op['event', 'ground_truth_event'](func=lambda x: [x])
    .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=60.0, sample_type='time_step_sample', args={'time_step': 1})
    .runas_op['frames', 'frames'](func=lambda x: [y for y in x])
    .distill_and_select['frames', 'vec'](model_name='cg_student', device=device)
    .milvus_search['vec', 'topk_raw_res'](collection=collection, limit=min_sample_num)
    .runas_op['topk_raw_res', 'topk_events'](func=lambda res: [id2event[x.id] for x in res])
    .runas_op['topk_raw_res', 'topk_path'](func=lambda res: [id2path[x.id] for x in res])
)
dc.select['id', 'ground_truth_event', 'topk_raw_res', 'topk_events', 'topk_path']().show()
dc_list = dc.to_list()
# random_idx = random.randint(0, len(dc_list) - 1)
sample_num = 3
sample_idxs = random.sample(range(len(dc_list)), sample_num)
def get_query_and_predict_videos(idx):
    query_video = id2path[int(dc_list[idx].id)]
    print('query_video =', query_video)
    predict_topk_video_list = dc_list[idx].topk_path[1:]
    print('predict_topk_video_list =', predict_topk_video_list)
    return query_video, predict_topk_video_list

dsp_res_list = []
for idx in sample_idxs:
    query_video, predict_topk_video_list = get_query_and_predict_videos(idx)
    show_video_list = [query_video] + predict_topk_video_list
    caption_list = ['query video: ' + Path(query_video).parent.stem] + ['result{0} video'.format(i) for i in range(len(predict_topk_video_list))]
    dsp_res_list.append(display_gifs_from_video(show_video_list, caption_list, tmpdirname=tmpdirname))
Here I randomly picked three indexes to see what the similar videos look like.
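For clarity: the two runas_op lambdas at the end of the query pipeline only translate each Milvus hit id back through the id2event/id2path dicts built in step 3. A standalone sketch with a stand-in Hit type (the real objects come from milvus_search):

```python
from collections import namedtuple

Hit = namedtuple('Hit', ['id', 'distance'])  # stand-in for a Milvus search hit

id2event = {1: 'event_a', 2: 'event_b', 3: 'event_a'}
res = [Hit(2, 0.12), Hit(3, 0.35), Hit(1, 0.48)]  # already sorted by distance

# Same shape as the runas_op lambda: map each hit id to its event label.
topk_events = [id2event[h.id] for h in res]
print(topk_events)  # ['event_b', 'event_a', 'event_a']
```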
7. Benchmark evaluation
benchmark = (
    dc.with_metrics(['mean_average_precision'])
      .evaluate['ground_truth_event', 'topk_events'](name='map_at_k')
      .report()
)
It reports a strikingly high score, which honestly surprised me a bit.
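For reference, here is what that metric computes: mean_average_precision averages, over all queries, the average precision of each ranked result list. One common formulation of AP, as a sketch (not Towhee's exact implementation):

```python
def average_precision(relevant, ranked):
    """Precision averaged over the ranks at which a relevant item appears."""
    hits = 0
    precisions = []
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / max(len(precisions), 1)

# One query whose ground-truth event is 'e1', with 5 ranked events returned:
ap = average_precision({'e1'}, ['e1', 'e1', 'e2', 'e1', 'e3'])
print(round(ap, 4))  # (1/1 + 2/2 + 3/4) / 3 = 0.9167
```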