本代码主要用于对训练好的NER模型的接口进行测试
使用流程:
1、模型接口:curl -F 'query= 一句话' 192.168.1.106:5662/ner
把训练好的NER模型用flask框架部署成上述接口,flask详解请见 https://zhuanlan.zhihu.com/p/137655320
2、准备测试所用数据:example.dev.txt(每行为一个字及其标签,中间用一个空格分隔;句子之间用空行分隔)
NER标签共4类:<name>、<position>、<organization>、<adress>,采用BMEO方式标注;如果你的数据标注格式不同,请参照自己的格式对代码的相应部分进行修改
3、测试代码(test.py)如下:
#!/usr/bin/env python
# coding: utf-8
import re
import os
import json
from tqdm import tqdm
def read_data(file_path, encoding='utf-8'):
    r"""Read a BMEO-annotated NER file and return its sentences and tags.

    The file is expected to hold one ``<char> <tag>`` pair per line (separated
    by a single space), with a blank line between sentences.

    Args:
        file_path: Path of the annotated data file.
        encoding: Text encoding of the file. Defaults to UTF-8 rather than the
            platform default so Chinese data is read the same way on every OS.

    Returns:
        A ``(sentence, label)`` tuple of two equally long lists.
        ``sentence[i]`` is the i-th sentence as one string; ``label[i]`` holds
        its tags, each followed by ``'\n'`` (e.g. ``'B-name\nE-name\nO\n'``).
    """
    sentence = []
    label = []
    chars = []
    tags = []

    def flush():
        # Close the current sentence; runs of blank lines produce nothing.
        if chars:
            sentence.append(''.join(chars))
            label.append(''.join(tag + '\n' for tag in tags))
            chars.clear()
            tags.clear()

    with open(file_path, 'r', encoding=encoding) as f:
        for line in f:
            # Strip the line terminator first so a '\r\n' ending or extra
            # trailing fields cannot leak into the tag.
            line = line.rstrip('\r\n')
            if not line.strip():
                flush()
                continue
            parts = line.split(' ')
            chars.append(parts[0])
            tags.append(parts[1])
    # Files that do not end with a blank line still have a pending sentence;
    # flush it so the last sentence is not silently dropped.
    flush()
    return sentence, label
def predict(sent, url='192.168.1.106:5662/ner', request=None):
    """Tag *sent* with the deployed NER service and return per-character tags.

    The service is queried like ``curl -F 'query= <sentence>' <url>``. Its JSON
    answer is read as ``{"label": {category: {<key>: [[start, end], ...]}}}``,
    where ``start``/``end`` are inclusive character offsets into the sentence.
    The inner keys are not used; presumably they are the entity text.

    Args:
        sent: Sentence to tag.
        url: Address of the NER endpoint; change it to test another service.
        request: Optional callable ``request(text, url)`` returning the raw
            JSON response (``str`` or ``bytes``). Defaults to running ``curl``.

    Returns:
        A list of ``len(sent)`` tags. Characters inside a predicted entity get
        ``'B-<cat>'``, ``'M-<cat>'`` or ``'E-<cat>'``, and all others get
        ``'O'``. The ``action`` category is skipped. A one-character entity
        gets only a ``B-`` tag.
    """
    import subprocess

    def _curl(text, endpoint):
        # Arguments are passed as a list (no shell), so quotes or shell
        # metacharacters in the sentence cannot break or inject into the
        # command. --form-string sends the value literally, without -F's
        # special handling of '@', '<' and ';type='.
        completed = subprocess.run(
            ['curl', '-s', '--form-string', 'query= ' + text, endpoint],
            stdout=subprocess.PIPE, check=True)
        return completed.stdout

    # Apostrophes are mapped to commas. This is a 1:1 substitution, so
    # offsets stay aligned with *sent*. It was needed for shell quoting and is
    # kept so the service keeps receiving the same text.
    sen = sent.replace("'", ',')
    res = json.loads((request or _curl)(sen, url))
    predict_label = ['O'] * len(sen)
    for key, entities in res['label'].items():
        # Skip categories that are not evaluated, and empty results.
        if key == 'action' or not entities:
            continue
        tag = str(key)
        for spans in entities.values():
            # Tag every occurrence of the entity, not only the first span;
            # otherwise repeated entities never count as predicted.
            for span in spans:
                start, end = span[0], span[1]
                if end < start:
                    continue
                predict_label[start] = 'B-' + tag
                for i in range(start + 1, end):
                    predict_label[i] = 'M-' + tag
                if end > start:
                    predict_label[end] = 'E-' + tag
    return predict_label
def split_entity(label_sequence, categories=None):
    """Extract entities from a BMEO tag sequence.

    Args:
        label_sequence: Iterable of tags such as ``'B-name'``, ``'M-name'``,
            ``'E-name'`` or ``'O'``. Any tag that does not start with B/M/E
            (including ``''``) ends the current entity.
        categories: Optional collection of category names to keep, e.g.
            ``{'position'}`` to score a single label. ``None`` (the default)
            keeps every category. Tags of other categories are treated like
            ``'O'``.

    Returns:
        A dict mapping ``(start_index, category)`` to the list of tags that
        make up that entity. An entity is closed by its E tag. M/E tags that do
        not continue an open entity of the same category are ignored.
    """
    entity_mark = {}
    entity_pointer = None
    for index, label in enumerate(label_sequence):
        prefix = label[:1]
        # partition() tolerates a malformed tag without '-' (category '').
        category = label.partition('-')[2]
        if prefix not in ('B', 'M', 'E') or (
                categories is not None and category not in categories):
            # 'O', '' or a category that is not being scored: no entity here.
            entity_pointer = None
        elif prefix == 'B':
            entity_pointer = (index, category)
            entity_mark[entity_pointer] = [label]
        elif entity_pointer is None or entity_pointer[1] != category:
            # M/E that does not continue an open entity of the same category.
            continue
        else:
            entity_mark[entity_pointer].append(label)
            if prefix == 'E':
                # E closes the entity so stray M/E tags cannot extend it.
                entity_pointer = None
    return entity_mark
def evaluate_one(real_label, predict_label):
    """Compare one sentence's gold and predicted tags at the entity level.

    Args:
        real_label: Gold tag sequence (list of strings such as ``'B-name'``).
        predict_label: Predicted tag sequence in the same format.

    Returns:
        ``(real_entity_num, predict_entity_num, true_entity_num)``. These are
        the number of gold entities, the number of predicted entities, and the
        number of predicted entities that exactly match a gold entity.
    """
    real_entity_mark = split_entity(real_label)
    predict_entity_mark = split_entity(predict_label)
    # A prediction is correct when it has the same (start, category) key and
    # the same tag list as a gold entity. ``&`` intersects the key views.
    shared_keys = real_entity_mark.keys() & predict_entity_mark.keys()
    true_entity_num = sum(
        1 for key in shared_keys
        if real_entity_mark[key] == predict_entity_mark[key]
    )
    return len(real_entity_mark), len(predict_entity_mark), true_entity_num
def evaluate(file_path):
    """Score the NER service on an annotated file with entity-level P/R/F1.

    Every sentence read from *file_path* is tagged through ``predict``. Its
    entities are then compared with the gold annotation.

    Args:
        file_path: Path of the BMEO-annotated evaluation data.

    Returns:
        ``(precision, recall, f1)`` as floats. A metric whose denominator is
        zero is reported as 0.0 instead of raising ``ZeroDivisionError``. That
        happens when there are no predicted entities, no gold entities, or
        P + R == 0.
    """
    sentence, label = read_data(file_path)
    real_entity_all = 0
    predict_entity_all = 0
    true_entity_all = 0
    for sent, tags in tqdm(zip(sentence, label), total=len(sentence)):
        predict_label = predict(sent)
        # Tags are newline-terminated, so this split yields a trailing ''.
        # split_entity treats that empty tag as a non-entity.
        real_label = tags.split('\n')
        real_num, predict_num, true_num = evaluate_one(real_label, predict_label)
        real_entity_all += real_num
        predict_entity_all += predict_num
        true_entity_all += true_num

    precision = true_entity_all / predict_entity_all if predict_entity_all else 0.0
    recall = true_entity_all / real_entity_all if real_entity_all else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
if __name__ == '__main__':
    import sys

    # Annotated NER data to evaluate on. An optional first command-line
    # argument overrides the default path.
    file_path = sys.argv[1] if len(sys.argv) > 1 else 'example.dev.txt'
    precision, recall, f1 = evaluate(file_path)
    print("precision, recall, and f1:")
    print("{:.3f}, {:.3f}, {:.3f}".format(precision, recall, f1))
4、将数据集和代码放在同一路径下
5、运行python test.py,输出precision、recall、f1:
Note:
1、代码中可以修改测试的数据所在路径
2、predict()函数中可以修改测试的接口,还可以修改相应的代码过滤掉不想测试的标签
3、split_entity()函数可以修改代码,用于只测试一个标签
参考文献: