Using CLIP to Extract Vectors and Store Them in Elasticsearch

import time
import torch
import clip
import os
from PIL import Image
from elasticsearch import Elasticsearch
from datetime import datetime

es_client = Elasticsearch("http://127.0.0.1:9200/")
index_name = "autodb_image_demorn50x64"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model, preprocess = clip.load("RN50x64", device=device)


def get_img_feature(image_path):
    # Load and preprocess the image into a (1, C, H, W) batch, then encode it with CLIP.
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        return image_features


def get_text_feature(text_msg):
    # Tokenize the text and encode it into the same embedding space as the images.
    text = clip.tokenize([text_msg]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)
        return text_features


def insert(desc, dense):
    # Move the feature tensor to CPU and flatten it to a plain Python list for ES
    # (avoid shadowing the built-in name "list").
    vector = dense.cpu().numpy()[0].tolist()
    doc = {
        "desc": desc,
        "img_vector": vector,
        "timestamp": datetime.now(),
    }
    resp = es_client.index(index=index_name, document=doc)
    print(resp["result"])


def select(dense, k):
    # Convert the query feature to a plain list and run a kNN search against img_vector.
    vector = dense.cpu().numpy()[0].tolist()
    body = {
        "query": {
            "knn": {
                "field": "img_vector",
                "query_vector": vector,
                "k": k
            }
        },
        "sort": {"_score": {"order": "desc"}},
        "size": k
    }
    # print(body)
    response = es_client.search(index=index_name, body=body)
    for hit in response['hits']['hits']:
        print(hit['_score'], hit['_source']['desc'])


def insert_sample():
    # Walk the picture directory and index a CLIP feature vector for every file.
    file_path = r".\picture"
    files = os.listdir(file_path)
    for file in files:
        f_path = os.path.join(file_path, file)
        if not os.path.isfile(f_path):
            continue
        print(f_path)
        image_features = get_img_feature(f_path)
        insert(f_path, image_features)


def query_sample():
    source_pic = r"'.\cam1519.jpg"
    print("query:", source_pic)
    start_time = time.time()
    image_features = get_img_feature(source_pic)
    print("time :", time.time() - start_time)

    start_time = time.time()
    select(image_features, 3)
    print("time :", time.time() - start_time)


if __name__ == '__main__':
    # insert_sample()
    query_sample()
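
The script above assumes the image index already exists with a dense_vector mapping for img_vector. A minimal setup sketch is shown below; the dims value of 1024 matches RN50x64's embedding size (ViT-B/32 would be 512), and cosine similarity is assumed so that the un-normalized CLIP features still rank correctly:

# create_index.py -- setup sketch, not part of the original scripts
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://127.0.0.1:9200/")
index_name = "autodb_image_demorn50x64"

mappings = {
    "properties": {
        "desc": {"type": "keyword"},
        "img_vector": {
            "type": "dense_vector",
            "dims": 1024,  # RN50x64 embedding size; adjust if you switch models
            "index": True,
            "similarity": "cosine"
        },
        "timestamp": {"type": "date"}
    }
}

es_client.indices.create(index=index_name, mappings=mappings)

The second, standalone script below stores plain-text files in a separate index and retrieves them with ordinary keyword match queries.
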
import os
from datetime import datetime
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://127.0.0.1:9200/")
text_index_name = "autodb_mail_text"


def insert_text(file_path, text):
    doc = {
        "file_path": file_path,
        "text_context": text,
        "timestamp": datetime.now(),
    }
    resp = es_client.index(index=text_index_name, document=doc)
    print(resp["result"])


def query_text(query_keys):
    must = []
    for q_k in query_keys:
        must.append({"match": {"text_context": q_k}})
    body = {
        "query": {
            "bool": {
                "must": must
            }
        },
        "sort": {"_score": {"order": "desc"}}
    }
    print(body)
    response = es_client.search(index=text_index_name, body=body)
    for hit in response['hits']['hits']:
        print(hit['_score'], hit['_source']['file_path'])


def insert_text_sample():
    # Walk the text directory and index the full content of every file.
    file_path = r"\happy\Everyone"
    files = os.listdir(file_path)
    for file in files:
        f_path = os.path.join(file_path, file)
        if not os.path.isfile(f_path):
            continue
        print(f_path)
        with open(f_path, 'r', encoding='utf-8') as f:
            content = f.read()
            insert_text(f_path, content)


if __name__ == '__main__':
    # insert_text_sample()
    query_text(["key1", "key2", "key3"])