fastText + Milvus：构建多段式文本序列相似度搜索引擎

748 阅读 · 阅读时长约 2 分钟

重启之战。我在一台 Windows 笔记本上通过 Docker 安装了 Milvus 1.1.1（CPU 版）。第一步：安装 Milvus。

docker run -d --name milvus_cpu_1.1.1 --platform linux/amd64 \
-p 19530:19530 \
-p 19121:19121 \
milvusdb/milvus:1.1.1-cpu-d061621-330cc6

安装好 Milvus 之后，我们先学习一下 Milvus 的官方 Python 示例：

# create a vector collection,
# insert 10 vectors,
# and execute a vector similarity search.

import random

from milvus import Milvus, IndexType, MetricType, Status

# Address of the Milvus server used by the example; change _HOST / _PORT
# to match your deployment. 19530 is the default port; the default HTTP
# port is 19121 (use _PORT = '19121' for HTTP).
_HOST, _PORT = '127.0.0.1', '19530'

# Number of float elements in every vector stored in the collection.
_DIM = 8

# Passed to create_collection as 'index_file_size': max file size of a
# stored index (Milvus 1.x documents the unit as MB).
_INDEX_FILE_SIZE = 32


def main():
    """Run the Milvus 1.x quick-start: create, insert, index, search, drop.

    Connects to ``_HOST:_PORT``, creates ``example_collection_`` if it does
    not exist, inserts 10 random ``_DIM``-dimensional vectors, builds an
    IVF_FLAT index, searches using the inserted vectors as queries, and
    finally drops the collection and closes the client -- also when a step
    fails part-way (early return or exception).
    """
    # The client instance maintains a connection pool; the optional
    # `pool_size` argument sets the max number of connections.
    milvus = Milvus(_HOST, _PORT)
    collection_name = 'example_collection_'
    try:
        # Create the collection if it doesn't exist yet.
        status, ok = milvus.has_collection(collection_name)
        if not ok:
            param = {
                'collection_name': collection_name,
                'dimension': _DIM,
                'index_file_size': _INDEX_FILE_SIZE,  # optional
                'metric_type': MetricType.L2  # optional
            }
            status = milvus.create_collection(param)
            if not status.OK():
                print("Create collection failed: {}".format(status))
                return

        # Show collections in the Milvus server.
        _, collections = milvus.list_collections()
        print(collections)

        # Describe the collection.
        _, collection = milvus.get_collection_info(collection_name)
        print(collection)

        # 10 vectors with _DIM float elements each; records must be a 2-D array.
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
        print(vectors)
        # numpy works as well:
        #   vectors = np.random.rand(10, _DIM).astype(np.float32)

        # Insert the vectors; returns the status and the list of new ids.
        status, ids = milvus.insert(collection_name=collection_name, records=vectors)
        if not status.OK():
            print("Insert failed: {}".format(status))
            return  # `ids` is unusable, so none of the steps below can run

        # Flush the collection's inserted data to disk.
        milvus.flush([collection_name])

        # Row count and statistics of the collection.
        _, count = milvus.count_entities(collection_name)
        print(count)
        _, info = milvus.get_collection_stats(collection_name)
        print(info)

        # Obtain raw vectors by providing vector ids.
        _, result_vectors = milvus.get_entity_by_id(collection_name, ids[:10])
        print(result_vectors)

        # Create an IVF_FLAT index. Searching also works without an index;
        # the index only makes search faster.
        index_param = {
            'nlist': 2048
        }
        print("Creating index: {}".format(index_param))
        status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)
        if not status.OK():
            print("Create index failed: {}".format(status))

        # Describe the index.
        _, index = milvus.get_index_info(collection_name)
        print(index)

        # Query with the inserted vectors themselves, so each should find itself.
        query_vectors = vectors[0:10]
        search_param = {
            "nprobe": 16
        }

        print("Searching ... ")

        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'params': search_param,
        }

        status, results = milvus.search(**param)
        if status.OK():
            # results[0][0] is the best hit for the first query; equivalently:
            #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            # print results
            print(results)
        else:
            print("Search failed. ", status)
    finally:
        # Always drop the demo collection and release the client, even after
        # an early return or an exception above.
        milvus.drop_collection(collection_name)
        milvus.close()


if __name__ == '__main__':
    main()

接下来改为读取自己的数据集，用 fastText 训练一个文本向量模型，再把句向量批量写入 Milvus。完整代码如下：

# -*- coding: utf-8 -*-
import os

import fasttext
import jieba
import numpy as np
import pandas as pd
from milvus import Milvus, MetricType

# Directory that contains this script, and that directory's parent.
base_path = os.path.split(os.path.abspath(__file__))[0]
database_path = os.path.split(base_path)[0]


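# 加载 jieba 分词词典 (load a jieba segmentation dictionary).
# TODO: nothing is loaded here yet, so jieba uses its built-in default
# dictionary; call jieba.load_userdict(path) to add a domain-specific one.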


def get_data(csv_path='./data/kol_profile_union.csv', out_path="finance_news_cut.txt"):
    """
    Build the fastText training corpus from the KOL profile CSV.

    For every CSV row, the string cells from the third column onward are
    joined with spaces, segmented with jieba, and written to *out_path* as
    one space-separated line per row.

    :param csv_path: CSV file to read (defaults to the original hard-coded path)
    :param out_path: text file to (over)write with the segmented corpus
    """
    # Read the CSV first so a read error does not leave an empty/truncated corpus.
    rows = pd.read_csv(csv_path).values.tolist()
    with open(out_path, "w", encoding='utf-8') as f:
        for row in rows:
            # Columns 0 and 1 are skipped; non-string cells (e.g. NaN) are dropped.
            sentence = " ".join(cell for cell in row[2:] if isinstance(cell, str))
            # Skip rows with no text at all, and rows that are one single run of
            # letters/digits (note: str.isalnum() is also True for CJK characters).
            if not sentence.strip() or sentence.isalnum():
                continue

            seg_sentence = jieba.cut(sentence.replace("\t", " ").replace("\n", " "))

            # One document per line: fastText treats a newline as end of
            # sentence, so context windows no longer run across documents.
            f.write(" ".join(seg_sentence) + "\n")


def train_model(input_path='finance_news_cut.txt', model_path="news_fasttext.model.bin",
                **train_kwargs):
    """
    Train an unsupervised fastText model on the segmented corpus and save it.

    :param input_path: segmented text corpus to train on
    :param model_path: where to save the trained binary model
    :param train_kwargs: extra hyper-parameters forwarded to
        fasttext.train_unsupervised (e.g. dim=..., epoch=...); fastText's
        defaults apply otherwise. If ``dim`` changes, any Milvus collection
        built from this model must use the same dimension.
    :return: the trained fastText model
    """
    model = fasttext.train_unsupervised(input_path, **train_kwargs)
    model.save_model(model_path)
    return model


def get_word_vector(word, model):
    """
    Look up the embedding of *word* in a loaded fastText *model*.

    Thin wrapper: delegates to ``model.get_word_vector`` and returns its
    result unchanged.
    """
    return model.get_word_vector(word)


def cos_sim(vector_a, vector_b):
    """
    Compute the cosine similarity of two vectors, rescaled to [0, 1].

    The raw cosine ``cos`` in [-1, 1] is mapped to ``0.5 + 0.5 * cos``, so
    opposite vectors give 0.0, orthogonal ones 0.5 and parallel ones 1.0.

    :param vector_a: vector a (any array-like; flattened before use)
    :param vector_b: vector b, with the same number of elements as vector_a
    :return: sim as a float; 0.5 if either vector has zero norm (the cosine
        is undefined there and is treated as 0)
    :raises ValueError: if the vectors have different numbers of elements
    """
    # np.mat was removed in NumPy 2.0 and float() on a 1x1 matrix is
    # deprecated; flat 1-D arrays with np.dot give the same dot product.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    if a.shape != b.shape:
        raise ValueError(
            "vectors must have the same length, got {} and {}".format(a.size, b.size))
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Avoid 0/0 -> nan (and its RuntimeWarning) for zero vectors.
        return 0.5
    # Clip so rounding error cannot push the cosine outside [-1, 1].
    cos = float(np.clip(np.dot(a, b) / denom, -1.0, 1.0))
    return 0.5 + 0.5 * cos


def _load_unique_sentences(csv_path='./data/kol_profile_union.csv', min_tokens=10):
    """
    Read the KOL CSV and return de-duplicated, jieba-segmented sentences.

    String cells from the third column onward are joined with spaces. Rows
    that are one single run of letters/digits are skipped, and only sentences
    with more than *min_tokens* jieba tokens are kept. The result keeps the
    order of first occurrence.
    """
    sentences = []
    for row in pd.read_csv(csv_path).values.tolist():
        sentence = " ".join(cell for cell in row[2:] if isinstance(cell, str))
        if sentence.isalnum():
            continue
        tokens = jieba.lcut(sentence.replace("\t", " ").replace("\n", " "))
        if len(tokens) > min_tokens:
            sentences.append(" ".join(tokens))
    # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
    return list(dict.fromkeys(sentences))


def _insert_sentences(milvus, collection_name, model, sentences):
    """
    Embed *sentences* with the fastText *model* and insert them into Milvus.

    :return: list of (milvus_id, sentence) pairs; empty if the insert failed
    """
    records = np.array([model.get_sentence_vector(s) for s in sentences])
    try:
        status, ids = milvus.insert(collection_name=collection_name, records=records)
    except Exception as err:  # client/connection errors: report and keep going
        print("Insert raised {!r} for batch starting with {!r}".format(err, sentences[0]))
        return []
    print(status)
    if not status.OK():
        return []
    # ids come back in the same order as the inserted records.
    return list(zip(ids, sentences))


if __name__ == "__main__":
    from tqdm import tqdm

    # Run these once to build the segmented corpus and train the model:
    # get_data()
    # train_model()
    all_out_line = _load_unique_sentences()
    model = fasttext.load_model('news_fasttext/news_fasttext.model.bin')
    collection_name = "kol_sentence_vector"
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    _INDEX_FILE_SIZE = 32
    # The collection dimension must equal the fastText vector size, so read
    # it from the model instead of hard-coding it (fastText's default is 100).
    _DIM = model.get_dimension()
    _BATCH_SIZE = 10

    milvus = Milvus(_HOST, _PORT)
    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2  # optional
        }
        status = milvus.create_collection(param)
        if not status.OK():
            raise SystemExit("Create collection failed: {}".format(status))

    # (milvus id, sentence) pairs, so search hits can be mapped back to text.
    out_result = []
    batch = []
    for sentence in tqdm(all_out_line):
        batch.append(sentence)
        if len(batch) == _BATCH_SIZE:
            out_result.extend(_insert_sentences(milvus, collection_name, model, batch))
            # Reset even when the insert failed; otherwise the batch would grow
            # past _BATCH_SIZE and no later sentence would ever be inserted.
            batch = []
    if batch:  # the final, partial batch
        out_result.extend(_insert_sentences(milvus, collection_name, model, batch))

    # Flush the inserted vectors to disk.
    milvus.flush([collection_name])
    print("Inserted {} of {} sentences".format(len(out_result), len(all_out_line)))