Test Code for Measuring the Accuracy of an NER Model API


This code tests the HTTP interface of a trained NER model.

Usage:

1. Model interface: curl -F 'query= a sentence' 192.168.1.106:5662/ner

Deploy the trained NER model with the Flask framework. For a detailed Flask walkthrough, see zhuanlan.zhihu.com/p/137655320
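The test script in step 3 parses the response as JSON and reads a top-level `label` object that maps each category to entity strings and their `[start, end]` spans, with `end` inclusive. That shape is inferred from how `predict()` indexes the response, not from any interface documentation, so the sample response and the helper name `spans_to_tags` below are hypothetical. A minimal sketch of turning such a response into per-character tags without calling the service:

```python
import json

# Hypothetical response body; the shape is an assumption inferred from predict().
raw = '{"label": {"name": {"张三": [[0, 1]]}, "organization": {"北京大学": [[3, 6]]}, "action": {}}}'

def spans_to_tags(length, label_dict, skip=("action",)):
    """Convert {category: {text: [[start, end], ...]}} into B/M/E/O tags."""
    tags = ["O"] * length
    for category, entities in label_dict.items():
        # Skip categories that are excluded from the test, and empty ones.
        if category in skip or not entities:
            continue
        for spans in entities.values():
            for start, end in spans:
                for i in range(start, end + 1):
                    if i == start:
                        tags[i] = "B-" + category
                    elif i == end:
                        tags[i] = "E-" + category
                    else:
                        tags[i] = "M-" + category
    return tags

sentence = "张三在北京大学"
tags = spans_to_tags(len(sentence), json.loads(raw)["label"])
print(tags)
```

Unlike `predict()` in step 3, this sketch walks every span of an entity string rather than only the first one; whether the real service returns more than one span per string is not known from the article.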

2. Prepare the test data: example.dev.txt

[Image: 1648806307.png, sample of the test data]

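There are four NER labels: <name>, <position>, <organization>, and <adress>, annotated with the BMEO scheme. If your data uses a different annotation format, adjust the corresponding parts of the code.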
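For orientation, here is a small sketch of the file layout that `read_data()` in step 3 appears to expect: one character and its tag per line, separated by a single space, with a blank line ending each sentence. The layout is inferred from the parsing code, and the sample text and the helper name `parse_sentences` are made up for illustration.

```python
# Hypothetical excerpt in the assumed "character<space>tag" layout,
# with a blank line ending each sentence.
sample = (
    "张 B-name\n"
    "三 E-name\n"
    "在 O\n"
    "北 B-organization\n"
    "京 M-organization\n"
    "大 M-organization\n"
    "学 E-organization\n"
    "\n"
)

def parse_sentences(text):
    """Return a list of (sentence, tags) pairs from BMEO-annotated text."""
    sentences, chars, tags = [], [], []
    for line in text.split("\n"):
        if line.strip():
            char, tag = line.split(" ")
            chars.append(char)
            tags.append(tag)
        elif chars:
            # A blank line closes the current sentence.
            sentences.append(("".join(chars), tags))
            chars, tags = [], []
    if chars:
        # Keep the last sentence if the text does not end with a blank line.
        sentences.append(("".join(chars), tags))
    return sentences

parsed = parse_sentences(sample)
print(parsed)
```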

3. Test code: test.py

#!/usr/bin/env python
# coding: utf-8

import re
import os
import json
from tqdm import tqdm

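# Read the annotated NER data and collect each sentence with its tags.
def read_data(file_path):
    sentence = []
    label = []
    sen = ''
    tag = ''
    # Each non-blank line is "character tag"; a blank line ends a sentence.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                parts = line.rstrip('\n').split(' ')
                sen += parts[0]
                tag += parts[1] + '\n'
            elif sen:
                sentence.append(sen)
                label.append(tag)
                sen = ''
                tag = ''
    # Keep the last sentence when the file does not end with a blank line.
    if sen:
        sentence.append(sen)
        label.append(tag)
    return sentence, label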

# Call the model interface on a sentence and return the predicted tags.
def predict(sent):
    # Single quotes would break the shell quoting below, so swap them out.
    sen = sent.replace("'", ',')
    # Change the endpoint here to test a different interface.
    res = os.popen(f"curl -F 'query= {sen}' 192.168.1.106:5662/ner").read()
    res = json.loads(res)
    predict_label = ['O'] * len(sen)
    for key in res['label'].keys():
        # Skip labels that are not being tested.
        if key == 'action':
            continue
        if res['label'][key]:
            for v in res['label'][key].values():
                # v[0] holds the entity's [start, end] span, end inclusive.
                start = v[0][0]
                end = v[0][1]
                for i in range(start, end + 1):
                    if i == start:
                        predict_label[i] = 'B-' + str(key)
                    elif i == end:
                        predict_label[i] = 'E-' + str(key)
                    else:
                        predict_label[i] = 'M-' + str(key)
    return predict_label

# Extract entities; used for both the predicted and the annotated tag sequences.
def split_entity(label_sequence):
    entity_mark = dict()
    entity_pointer = None
    for index, label in enumerate(label_sequence):
        # To evaluate a single label only, uncomment the category checks below.
        if label.startswith('B'):
#             if(label.split('-')[1] != 'position'):
#                 continue
            category = label.split('-')[1]
            entity_pointer = (index, category)
            entity_mark.setdefault(entity_pointer, [label])
        elif label.startswith('M'):
#             if(label.split('-')[1] != 'position'):
#                 continue
            if entity_pointer is None: continue
            if entity_pointer[1] != label.split('-')[1]: continue
            entity_mark[entity_pointer].append(label)
        elif label.startswith('E'):
#             if(label.split('-')[1] != 'position'):
#                 continue
            if entity_pointer is None: continue
            if entity_pointer[1] != label.split('-')[1]: continue
            entity_mark[entity_pointer].append(label)
        else:
            entity_pointer = None
    return entity_mark

# Compare the real and predicted tags of one sentence; return the number of
# real entities, predicted entities, and correctly predicted entities.
def evaluate_one(real_label, predict_label):
    real_entity_mark = split_entity(real_label)
    predict_entity_mark = split_entity(predict_label)
    true_entity_mark = dict()
    # Only entities with the same start index and category in both can match.
    key_set = real_entity_mark.keys() & predict_entity_mark.keys()
    for key in key_set:
        real_entity = real_entity_mark.get(key)
        predict_entity = predict_entity_mark.get(key)
        if real_entity is not None and predict_entity is not None:
            if tuple(real_entity) == tuple(predict_entity):
                true_entity_mark.setdefault(key, real_entity)
    real_entity_num = len(real_entity_mark)
    predict_entity_num = len(predict_entity_mark)
    true_entity_num = len(true_entity_mark)

    return real_entity_num, predict_entity_num, true_entity_num

# Compute precision, recall, and F1 over the whole dataset.
def evaluate(file_path):
    sentence, label = read_data(file_path)
    real_entity_all = 0
    predict_entity_all = 0
    true_entity_all = 0 
    for i in tqdm(range(len(sentence))):
        predict_label = predict(sentence[i])
        real_label = label[i].split('\n')
        real_entity_num, predict_entity_num, true_entity_num = evaluate_one(real_label, predict_label)
        real_entity_all += real_entity_num
        predict_entity_all += predict_entity_num
        true_entity_all += true_entity_num
        #print("{:.3f}, {:.3f}, {:.3f}".format(real_entity_all,predict_entity_all,true_entity_all))
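    # Guard against division by zero when nothing is predicted, annotated, or matched.
    precision = true_entity_all / predict_entity_all if predict_entity_all else 0.0
    recall = true_entity_all / real_entity_all if real_entity_all else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0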
    return precision, recall, f1

if __name__ == '__main__':
    # Path to the annotated NER dataset.
    file_path = 'example.dev.txt'
    precision, recall, f1 = evaluate(file_path)
    print("precision, recall, and f1:")
    print("{:.3f}, {:.3f}, {:.3f}".format(precision,recall,f1))

4. Put the dataset and the code in the same directory.

5. Run python test.py; it prints precision, recall, and F1.
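The three numbers are computed from entity counts summed over all sentences: precision is correct entities divided by predicted entities, recall is correct entities divided by annotated entities, and F1 is their harmonic mean. The sketch below works through that arithmetic with made-up counts; the helper name `prf` is hypothetical, and the zero guards are an addition for the case where a count is zero.

```python
def prf(true_num, predict_num, real_num):
    """Precision, recall, and F1 from entity counts, returning 0.0 on empty denominators."""
    precision = true_num / predict_num if predict_num else 0.0
    recall = true_num / real_num if real_num else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1

# Made-up counts: 8 correct entities, 10 predicted, 16 annotated.
p, r, f = prf(true_num=8, predict_num=10, real_num=16)
print("{:.3f}, {:.3f}, {:.3f}".format(p, r, f))  # prints 0.800, 0.500, 0.615
```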

Note:

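1. The path to the test data can be changed in the code.

2. In predict(), you can change the interface under test and edit the filter to skip labels you do not want to test.

3. split_entity() can be modified to test a single label only.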
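As an alternative to editing split_entity() for a single-label run, the tag sequences can be filtered before they are compared. This is only a sketch of that idea, not the article's method: `keep_category` is a hypothetical helper that replaces every tag outside the chosen category with `O`.

```python
def keep_category(label_sequence, category):
    """Keep tags of one category and replace everything else with 'O'."""
    kept = []
    for label in label_sequence:
        # "B-position" -> name "position"; "O" or "" -> name "".
        _, _, name = label.partition("-")
        kept.append(label if name == category else "O")
    return kept

tags = ["B-name", "E-name", "O", "B-position", "M-position", "E-position"]
print(keep_category(tags, "position"))
```

Applying the same filter to both the annotated and the predicted sequence before calling evaluate_one() would give scores for that category alone.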

References:

zhuanlan.zhihu.com/p/56582082

zhuanlan.zhihu.com/p/137655320