文本分类模型合集-详细注解--tf/pytorch双版本

·  阅读 464

Text-Similarity

代码仓库地址:github.com/DengBoCong/text-similarity

文本相似度(匹配)计算,提供 Baseline、训练、推理、指标分析等功能。代码包含 TensorFlow/PyTorch 双版本,特别适合新手上手。

Overview

  • Dataset: 中文/English 语料, ☞ 点这里
  • Paper: 相关论文详解, ☞ 点这里
  • The implemented methods are as follows:
    • TF-IDF
    • BM25
    • LSH
    • SIF/uSIF
    • FastText
    • RNN Base (Siamese RNN, Stack RNN)
    • CNN Base (Fast Text, Text CNN, Char CNN, VDCNN)
    • Bert Base
    • Albert
    • NEZHA
    • RoBERTa
    • SimCSE
    • Poly-Encoder
    • ColBERT
    • RE2(Simple-Effective-Text-Matching)

Usages

1:examples目录下有不同模型对应的 preprocess/train/evaluate 代码,可自行修改
2:如下示例从examples中引入actuator方法,准备好对应的模型配置文件即可执行
3:examples目录下的inference.py为训练好的模型推理代码
复制代码

TF-IDF

# Example
# Sklearn version
from examples.run_tfidf_sklearn import actuator
actuator("./corpus/chinese/breeno/train.tsv", query1="12 23 4160 276", query2="29 23 169 1495")

# Custom version
from examples.run_tfidf import actuator
actuator("./corpus/chinese/breeno/train.tsv", query1="12 23 4160 276", query2="29 23 169 1495")

# 工具调用
from sim.tf_idf import TFIdf

tokens_list = ["这是 一个 什么 样 的 工具", "..."]
query = ["非常 好用 的 工具"]

tf_idf = TFIdf(tokens_list, split=" ")
print(tf_idf.get_score(query, 0))  # score
print(tf_idf.get_score_list(query, 10))  # [(index, score), ...]
print(tf_idf.weight())  # list or numpy array
复制代码

BM25

# Example
from examples.run_bm25 import actuator
actuator("./corpus/chinese/breeno/train.tsv", query1="12 23 4160 276", query2="29 23 169 1495")

# 工具调用
from sim.bm25 import BM25

tokens_list = ["这是 一个 什么 样 的 工具", "..."]
query = ["非常 好用 的 工具"]

bm25 = BM25(tokens_list, split=" ")
print(bm25.get_score(query, 0))  # score
print(bm25.get_score_list(query, 10))  # [(index, score), ...]
print(bm25.weight())  # list or numpy array
复制代码

LSH

from sim.lsh import E2LSH
from sim.lsh import MinHash

e2lsh = E2LSH()
min_hash = MinHash()

candidates = [[3.6216, 8.6661, -2.8073, -0.44699, 0], ...]
query = [-2.7769, -5.6967, 5.9179, 0.37671, 1]
print(e2lsh.search(candidates, query))  # index in candidates
print(min_hash.search(candidates, query))  # index in candidates
复制代码

SIF

sentences = [["token1", "token2", "..."], ...]
vector = [[[1, 1, 1], [2, 2, 2], [...]], ...]
from sim.sif_usif import SIF
from sim.sif_usif import uSIF

sif = SIF(n_components=5, component_type="svd")
sif.fit(tokens_list=sentences, vector_list=vector)

usif = uSIF(n_components=5, n=1, component_type="svd")
usif.fit(tokens_list=sentences, vector_list=vector)
复制代码

FastText

# TensorFlow version
from examples.tensorflow.run_fast_text import actuator
actuator(execute_type="train", model_type="bert", model_dir="./data/chinese_wwm_L-12_H-768_A-12")

# Pytorch version
from examples.pytorch.run_fast_text import actuator
actuator(execute_type="train", model_type="bert", model_dir="./data/chinese_wwm_pytorch")
复制代码

RNN Base

# TensorFlow version
from examples.tensorflow.run_siamese_rnn import actuator
actuator("./data/config/siamse_rnn.json", execute_type="train")

# Pytorch version
from examples.pytorch.run_siamese_rnn import actuator
actuator("./data/config/siamse_rnn.json", execute_type="train")
复制代码

CNN Base

# TensorFlow version
from examples.tensorflow.run_cnn_base import actuator
actuator(execute_type="train", model_type="bert", model_dir="./data/chinese_wwm_L-12_H-768_A-12")

# Pytorch version
from examples.pytorch.run_cnn_base import actuator
actuator(execute_type="train", model_type="bert", model_dir="./data/chinese_wwm_pytorch")
复制代码

Bert Base

# TensorFlow version
from examples.tensorflow.run_basic_bert import actuator
actuator(model_dir="./data/chinese_wwm_L-12_H-768_A-12", execute_type="train")

# Pytorch version
from examples.pytorch.run_basic_bert import actuator
actuator(model_dir="./data/chinese_wwm_pytorch", execute_type="train")
复制代码

Albert

# TensorFlow version
from examples.tensorflow.run_albert import actuator
actuator(model_dir="./data/albert_small_zh_google", execute_type="train")

# Pytorch version
from examples.pytorch.run_albert import actuator
actuator(model_dir="./data/albert_chinese_small", execute_type="train")
复制代码

NEZHA

# TensorFlow version
from examples.tensorflow.run_nezha import actuator
actuator(model_dir="./data/NEZHA-Base-WWM", execute_type="train")

# Pytorch version
from examples.pytorch.run_nezha import actuator
actuator(model_dir="./data/nezha-base-wwm", execute_type="train")
复制代码

RoBERTa

# TensorFlow version
from examples.tensorflow.run_basic_bert import actuator
actuator(model_dir="./data/chinese_roberta_L-6_H-384_A-12", execute_type="train")

# Pytorch version
from examples.pytorch.run_basic_bert import actuator
actuator(model_dir="./data/chinese-roberta-wwm-ext", execute_type="train")
复制代码

SimCSE

# TensorFlow version
from examples.tensorflow.run_simcse import actuator
actuator(model_dir="./data/chinese_wwm_L-12_H-768_A-12", execute_type="train", model_type="bert")

# Pytorch version
from examples.pytorch.run_simcse import actuator
actuator(model_dir="./data/chinese_wwm_pytorch", execute_type="train", model_type="bert")
复制代码

Poly-Encoder

# TensorFlow version
from examples.tensorflow.run_poly_encoder import actuator
actuator(model_dir="./data/chinese_wwm_L-12_H-768_A-12", execute_type="train", model_type="bert")

# Pytorch version
from examples.pytorch.run_poly_encoder import actuator
actuator(model_dir="./data/chinese_wwm_pytorch", execute_type="train", model_type="bert")
复制代码

ColBERT

# TensorFlow version
from examples.tensorflow.run_colbert import actuator
actuator(model_dir="./data/chinese_wwm_L-12_H-768_A-12", execute_type="train", model_type="bert")

# Pytorch version
from examples.pytorch.run_colbert import actuator
actuator(model_dir="./data/chinese_wwm_pytorch", execute_type="train", model_type="bert")
复制代码

RE2

# TensorFlow version
from examples.tensorflow.run_re2 import actuator
actuator("./data/config/re2.json", execute_type="train")

# Pytorch version
from examples.pytorch.run_re2 import actuator
actuator("./data/config/re2.json", execute_type="train")
复制代码

Cite

@misc{text-similarity,
  title={text-similarity},
  author={Bocong Deng},
  year={2021},
  howpublished={\url{https://github.com/DengBoCong/text-similarity}},
}
复制代码

Reference

分类:
人工智能
标签:
分类:
人工智能
标签:
收藏成功!
已添加到「」, 点击更改