Building the Dataset
The dataset used here is DuIE2.0, the largest Chinese relation-extraction dataset in the industry. Its schema adds multi-valued, complex relation types on top of the traditional simple ones, and its corpus is drawn from Baidu Baike, Baidu Feed, and Baidu Tieba, covering both formal written and colloquial language, so it gives a realistic test of relation extraction in real business scenarios.
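Each line of the data files is one JSON record. Only the fields this section actually reads (text, spo_list, subject, predicate, object['@value']) are shown in the sketch below; real DuIE2.0 records may carry additional fields, and the values here simply reuse the sample sentence that appears later in this section:

{"text": "《邪少兵王》是冰火未央写的网络小说连载于旗峰天下",
 "spo_list": [{"subject": "邪少兵王",
               "predicate": "作者",
               "object": {"@value": "冰火未央"}}]}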
File Loading and Tokenization
- Add the configuration entries
# config.py
TRAIN_JSON_PATH = './data/input/duie/duie_train.json'
TEST_JSON_PATH = './data/input/duie/duie_test.json'
DEV_JSON_PATH = './data/input/duie/duie_dev.json'
BERT_MODEL_NAME = 'bert-base-chinese'
- Create a new file
# utils.py
import torch.utils.data as data
import pandas as pd
import random
from config import *
import json
from transformers import BertTokenizerFast
- Load the relation table
def get_rel():
    # REL_PATH (path of the relation-table CSV) is expected to be defined in config.py
    df = pd.read_csv(REL_PATH, names=['rel', 'id'])
    return df['rel'].tolist(), dict(df.values)


id2rel, rel2id = get_rel()
print(id2rel)  # the position in the list already acts as the id
print(rel2id)
exit()
['毕业院校', '嘉宾', '配音', '主题曲', '代言人', '所属专辑'.....
{'毕业院校': 0, '嘉宾': 1, '配音': 2, '主题曲': 3, '代言人': 4, .....
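The relation table itself is just a two-column CSV (relation name, id) sitting at REL_PATH; building it is not covered in this section. Judging from the output above and the column names passed to pd.read_csv, its first rows would look roughly like:

毕业院校,0
嘉宾,1
配音,2
主题曲,3
代言人,4
所属专辑,5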
- Initialize the Dataset
class Dataset(data.Dataset):
    def __init__(self, type='train'):  # type selects which split to load
        super().__init__()
        _, self.rel2id = get_rel()
        # load the file
        if type == 'train':
            file_path = TRAIN_JSON_PATH
        elif type == 'test':
            file_path = TEST_JSON_PATH
        elif type == 'dev':
            file_path = DEV_JSON_PATH
        with open(file_path) as f:
            self.lines = f.readlines()  # read line by line; the number of lines is the size of the split
        # load the BERT tokenizer
        self.tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        line = self.lines[index]
        info = json.loads(line)
        # first argument is the text to encode; return_offsets_mapping records each
        # token's character span, which keeps mixed Chinese/English text aligned
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']            # append to info
        info['offset_mapping'] = tokenized['offset_mapping']
        print(info)
        exit()
- Try loading the dataset
if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset)
    print(next(iter(loader)))  # advance the iterator once and take the first sample
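As an aside, return_offsets_mapping is only available on the fast tokenizer (BertTokenizerFast); each entry is the (start, end) character span of the corresponding token in the original string, with (0, 0) for special tokens such as [CLS] and [SEP]. A minimal standalone sketch of how the offsets map tokens back to the text:

# quick check of offset_mapping, reusing the same bert-base-chinese tokenizer
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
text = '《邪少兵王》是冰火未央写的网络小说'
out = tokenizer(text, return_offsets_mapping=True)
for token_id, (start, end) in zip(out['input_ids'], out['offset_mapping']):
    # special tokens ([CLS]/[SEP]) map to the empty span (0, 0)
    print(tokenizer.convert_ids_to_tokens(token_id), text[start:end])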
Parsing the JSON and Computing Entity Positions
Parse the basic information
def parse_json(self, info):
    text = info['text']
    input_ids = info['input_ids']
    dct = {
        'text': text,
        'input_ids': input_ids,
        'offset_mapping': info['offset_mapping'],
        'sub_head_ids': [],
        'sub_tail_ids': [],
        'triple_list': [],
        'triple_id_list': []
    }
    for spo in info['spo_list']:
        subject = spo['subject']
        object = spo['object']['@value']
        predicate = spo['predicate']
        dct['triple_list'].append((subject, predicate, object))
        # @todo: entity positions, filled in below
    exit(dct)
    return dct
Compute the entity positions
        # compute the subject entity position
        tokenized = self.tokenizer(subject, add_special_tokens=False)
        sub_token = tokenized['input_ids']
        sub_pos_id = self.get_pos_id(input_ids, sub_token)
        if not sub_pos_id:
            continue
        sub_head_id, sub_tail_id = sub_pos_id
        # compute the object entity position
        tokenized = self.tokenizer(object, add_special_tokens=False)
        obj_token = tokenized['input_ids']
        obj_pos_id = self.get_pos_id(input_ids, obj_token)
        if not obj_pos_id:
            continue
        obj_head_id, obj_tail_id = obj_pos_id
        # assemble the data
        dct['sub_head_ids'].append(sub_head_id)
        dct['sub_tail_ids'].append(sub_tail_id)
        dct['triple_id_list'].append((
            [sub_head_id, sub_tail_id],
            self.rel2id[predicate],
            [obj_head_id, obj_tail_id],
        ))
The position-matching method
Here source is the token id sequence of the full sentence, and elem is the token ids of the current entity. Starting from every position in source, we take a slice of len(elem) ids and compare it against elem (a sliding window); once the window runs past the end of source, the slice is shorter than elem and can never match. If no window matches, the loop falls through and the method implicitly returns None, which is what the `if not sub_pos_id: continue` checks above rely on: an entity whose tokens never appear contiguously in the sentence's tokens (for example because it tokenizes differently on its own than in context) causes the whole triple to be skipped.
def get_pos_id(self, source, elem):
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1
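A quick standalone sanity check of the same matching logic, with made-up token id lists:

def get_pos_id(source, elem):
    # same sliding-window comparison as above, without the class wrapper
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1

print(get_pos_id([101, 517, 6932, 2208, 1070, 4374, 518], [6932, 2208]))  # (2, 3)
print(get_pos_id([101, 517, 6932, 2208], [999]))                          # None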
Complete implementation:
def parse_json(self, info):
    text = info['text']
    input_ids = info['input_ids']
    dct = {
        'text': text,
        'input_ids': input_ids,
        'offset_mapping': info['offset_mapping'],
        'sub_head_ids': [],
        'sub_tail_ids': [],
        'triple_list': [],
        'triple_id_list': []
    }
    for spo in info['spo_list']:
        subject = spo['subject']
        object = spo['object']['@value']
        predicate = spo['predicate']
        dct['triple_list'].append((subject, predicate, object))
        # compute the subject entity position
        tokenized = self.tokenizer(subject, add_special_tokens=False)
        sub_token = tokenized['input_ids']
        sub_pos_id = self.get_pos_id(input_ids, sub_token)
        if not sub_pos_id:
            continue
        sub_head_id, sub_tail_id = sub_pos_id
        # compute the object entity position
        tokenized = self.tokenizer(object, add_special_tokens=False)
        obj_token = tokenized['input_ids']
        obj_pos_id = self.get_pos_id(input_ids, obj_token)
        if not obj_pos_id:
            continue
        obj_head_id, obj_tail_id = obj_pos_id
        # assemble the data
        dct['sub_head_ids'].append(sub_head_id)
        dct['sub_tail_ids'].append(sub_tail_id)
        dct['triple_id_list'].append((
            [sub_head_id, sub_tail_id],
            self.rel2id[predicate],
            [obj_head_id, obj_tail_id],
        ))
    exit(dct)
    return dct

def get_pos_id(self, source, elem):
    for head_id in range(len(source)):
        tail_id = head_id + len(elem)
        if source[head_id:tail_id] == elem:
            return head_id, tail_id - 1
{'text': '《邪少兵王》是冰火未央写的网络小说连载于旗峰天下', 'input_ids': [101, 517, 6932, 2208, 1070, 4374, 518, 3221, 1102, 4125, 3313, 1925, 1091, 4638, 5381, 5317, 2207, 6432, 6825, 6770, 754, 3186, 2292, 1921, 678, 102], 'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (0, 0)], 'sub_head_ids': [2], 'sub_tail_ids': [5], 'triple_list': [('邪少兵王', '作者', '冰火未央')], 'triple_id_list': [([2, 5], 7, [8, 11])]}
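The dict printed above is one fully parsed record. Once parse_json looks right, the debugging print/exit calls can be dropped and __getitem__ can simply return the parsed dict; a minimal sketch of that wiring inside the Dataset class (later sections may organize this differently):

    def __getitem__(self, index):
        line = self.lines[index]
        info = json.loads(line)
        tokenized = self.tokenizer(info['text'], return_offsets_mapping=True)
        info['input_ids'] = tokenized['input_ids']
        info['offset_mapping'] = tokenized['offset_mapping']
        # hand the enriched record to parse_json and return the training-ready dict
        return self.parse_json(info)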