import time
import torch
import clip
import os
from PIL import Image
from elasticsearch import Elasticsearch
from datetime import datetime
es_client = Elasticsearch("http://127.0.0.1:9200/")
index_name = "autodb_image_demorn50x64"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model, preprocess = clip.load("RN50x64", device=device)

def get_img_feature(image_path):
    """Encode one image file into a CLIP feature vector."""
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    return image_features

def get_text_feature(text_msg):
    """Encode a text query into a CLIP feature vector."""
    text = clip.tokenize([text_msg]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)
    return text_features

def insert(desc, dense):
    """Index an image description together with its CLIP vector."""
    vector = dense.cpu().numpy()[0].tolist()  # move off the GPU before converting
    doc = {
        "desc": desc,
        "img_vector": vector,
        "timestamp": datetime.now(),
    }
    resp = es_client.index(index=index_name, document=doc)
    print(resp["result"])

def select(dense, k):
    """Run a kNN search over the stored image vectors and print the top-k hits."""
    vector = dense.cpu().numpy()[0].tolist()
    body = {
        "query": {
            "knn": {
                "field": "img_vector",
                "query_vector": vector,
                "k": k
            }
        },
        "sort": {"_score": {"order": "desc"}},
        "size": k
    }
    response = es_client.search(index=index_name, body=body)
    for hit in response['hits']['hits']:
        print(hit['_score'], hit['_source']['desc'])

def insert_sample():
    """Encode and index every image file in the local picture directory."""
    file_path = r".\picture"
    files = os.listdir(file_path)
    for file in files:
        f_path = os.path.join(file_path, file)
        if not os.path.isfile(f_path):
            continue
        print(f_path)
        image_features = get_img_feature(f_path)
        insert(f_path, image_features)

def query_sample():
    """Search the index with a local query image and time each step."""
    source_pic = r".\cam1519.jpg"
    print("query:", source_pic)
    start_time = time.time()
    image_features = get_img_feature(source_pic)
    print("encode time :", time.time() - start_time)
    start_time = time.time()
    select(image_features, 3)
    print("search time :", time.time() - start_time)

if __name__ == '__main__':
    query_sample()

# --- Second script: keyword search over text files (a separate module, hence
# the repeated imports and client setup) ---
import os
from datetime import datetime
from elasticsearch import Elasticsearch

es_client = Elasticsearch("http://127.0.0.1:9200/")
text_index_name = "autodb_mail_text"
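
# query_text() below relies on standard full-text scoring, so a plain "text"
# mapping is enough. A minimal setup sketch (hypothetical helper, assuming an
# 8.x elasticsearch-py client):
def create_text_index():
    if not es_client.indices.exists(index=text_index_name):
        es_client.indices.create(
            index=text_index_name,
            mappings={
                "properties": {
                    "file_path": {"type": "keyword"},
                    "text_context": {"type": "text"},
                    "timestamp": {"type": "date"},
                }
            },
        )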

def insert_text(file_path, text):
    """Index one text file's path and full contents."""
    doc = {
        "file_path": file_path,
        "text_context": text,
        "timestamp": datetime.now(),
    }
    resp = es_client.index(index=text_index_name, document=doc)
    print(resp["result"])

def query_text(query_keys):
    """Full-text search: every keyword in query_keys must match text_context."""
    must = []
    for q_k in query_keys:
        must.append({"match": {"text_context": q_k}})
    body = {
        "query": {
            "bool": {
                "must": must
            }
        },
        "sort": {"_score": {"order": "desc"}}
    }
    print(body)
    response = es_client.search(index=text_index_name, body=body)
    for hit in response['hits']['hits']:
        print(hit['_score'], hit['_source']['file_path'])

def insert_text_sample():
    """Read and index every text file in the sample directory."""
    file_path = r"\happy\Everyone"
    files = os.listdir(file_path)
    for file in files:
        f_path = os.path.join(file_path, file)
        if not os.path.isfile(f_path):
            continue
        print(f_path)
        with open(f_path, 'r', encoding='utf-8') as f:
            content = f.read()
        insert_text(f_path, content)

if __name__ == '__main__':
    query_text(["key1", "key2", "key3"])