```python
import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.backend import sparse_multilabel_categorical_crossentropy
from bert4keras.tokenizers import Tokenizer
from bert4keras.layers import EfficientGlobalPointer as GlobalPointer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from tqdm import tqdm
# Maximum number of tokens per input text, and samples per training batch.
maxlen, batch_size = 128, 16
# Pretrained Chinese RoBERTa-wwm-ext (base) checkpoint files.
config_path, checkpoint_path, dict_path = (
    '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json',
    '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt',
    '../chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt',
)
def load_data(filename):
    """Load a relation-extraction dataset and build the predicate vocabulary.

    Each non-blank line of ``filename`` is a JSON object with:

    * ``text``: the sentence;
    * ``entities``: items with ``id``, ``start_offset``, ``end_offset``
      (character offsets into ``text``, end exclusive);
    * ``relations``: items with ``from_id``, ``to_id``, ``type``.

    Blank lines (e.g. a trailing newline) are skipped.

    Returns:
        D: list of ``{'text': text, 'spo_list': [(s, p, o), ...]}``, where
            ``s`` and ``o`` are the entity surface strings sliced from
            ``text`` and ``p`` is the relation type.
        id2predicate: dict mapping integer id -> relation type, assigned in
            order of first appearance.
        predicate2id: the inverse mapping of ``id2predicate``.

    Raises:
        KeyError: if a relation references an entity id that is not listed
            in the same record.
    """
    D = []
    id2predicate = {}
    predicate2id = {}
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # json.loads('') would raise; tolerate blank/trailing lines.
                continue
            record = json.loads(line)
            text = record['text']
            entity_text = {
                entity['id']: text[entity['start_offset']:entity['end_offset']]
                for entity in record['entities']
            }
            D.append({
                'text': text,
                'spo_list': [
                    (entity_text[rel['from_id']], rel['type'],
                     entity_text[rel['to_id']])
                    for rel in record['relations']
                ],
            })
            for rel in record['relations']:
                predicate = rel['type']
                if predicate not in predicate2id:
                    new_id = len(predicate2id)
                    id2predicate[new_id] = predicate
                    predicate2id[predicate] = new_id
    return D, id2predicate, predicate2id
all_data, id2predicate, predicate2id = load_data('untitled.txt')
# Deterministic 80/20 train/validation split in file order (no shuffling).
train_data, valid_data = (
    all_data[:int(len(all_data) * 0.8)],
    all_data[int(len(all_data) * 0.8):],
)
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def search(pattern, sequence):
    """Find the first occurrence of ``pattern`` as a contiguous run in ``sequence``.

    Both arguments are sliceable sequences (token-id lists in this file).
    Returns the start index of the first match, or -1 if there is none.
    """
    width = len(pattern)
    return next(
        (start for start in range(len(sequence))
         if sequence[start:start + width] == pattern),
        -1,
    )
class data_generator(DataGenerator):
    """Batch generator for GPLinker-style relation extraction.

    Each sample is a dict with ``'text'`` and ``'spo_list'`` (as produced by
    ``load_data``). Every yielded batch is
    ``([token_ids, segment_ids], [entity_labels, head_labels, tail_labels])``:

    * ``entity_labels``: 2 heads of (start, end) token spans; head 0 holds
      subject spans, head 1 holds object spans.
    * ``head_labels``: one head per predicate id, holding
      (subject_start, object_start) token pairs.
    * ``tail_labels``: one head per predicate id, holding
      (subject_end, object_end) token pairs.
    """

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(d['text'], maxlen=maxlen)
            # Map each (s, p, o) string triple to token positions. Only the
            # first occurrence of each entity is used (search() returns the
            # first match), and triples whose entities are not found in the
            # (possibly truncated) token sequence are dropped.
            spoes = set()
            for s, p, o in d['spo_list']:
                # Drop the first/last ids, presumably the tokenizer's
                # [CLS]/[SEP] special tokens.
                s = tokenizer.encode(s)[0][1:-1]
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                if not s or not o:
                    # An entity that tokenizes to nothing would "match" at
                    # index 0 via search() and produce an invalid (0, -1)
                    # span, so skip the triple.
                    continue
                sh = search(s, token_ids)
                oh = search(o, token_ids)
                if sh != -1 and oh != -1:
                    spoes.add((sh, sh + len(s) - 1, p, oh, oh + len(o) - 1))
            entity_labels = [set() for _ in range(2)]
            head_labels = [set() for _ in range(len(predicate2id))]
            tail_labels = [set() for _ in range(len(predicate2id))]
            for sh, st, p, oh, ot in spoes:
                entity_labels[0].add((sh, st))
                entity_labels[1].add((oh, ot))
                head_labels[p].add((sh, oh))
                tail_labels[p].add((st, ot))
            # Give every head at least one entry so sequence_padding receives
            # a non-empty list. (0, 0) is presumably ignored by the loss via
            # the `True` flag passed to sparse_multilabel_categorical_crossentropy
            # in globalpointer_crossentropy — confirm.
            for label in entity_labels + head_labels + tail_labels:
                if not label:
                    label.add((0, 0))
            entity_labels = sequence_padding([list(l) for l in entity_labels])
            head_labels = sequence_padding([list(l) for l in head_labels])
            tail_labels = sequence_padding([list(l) for l in tail_labels])
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_entity_labels.append(entity_labels)
            batch_head_labels.append(head_labels)
            batch_tail_labels.append(tail_labels)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                # seq_dims=2 pads both the heads axis and the pairs axis.
                batch_entity_labels = sequence_padding(
                    batch_entity_labels, seq_dims=2
                )
                batch_head_labels = sequence_padding(
                    batch_head_labels, seq_dims=2
                )
                batch_tail_labels = sequence_padding(
                    batch_tail_labels, seq_dims=2
                )
                yield [batch_token_ids, batch_segment_ids], [
                    batch_entity_labels, batch_head_labels, batch_tail_labels
                ]
                batch_token_ids, batch_segment_ids = [], []
                batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
def globalpointer_crossentropy(y_true, y_pred):
    """Cross-entropy loss for GlobalPointer outputs.

    The last axis of ``y_true`` holds (start, end) index pairs. Each pair is
    flattened to ``start * seq_len + end`` so it indexes the flattened
    ``seq_len * seq_len`` score grid of ``y_pred``.

    Assumes ``y_pred`` is (batch, heads, seq_len, seq_len) and that the loss
    helper returns one value per (batch, head) — TODO confirm. Per-head losses
    are summed, then averaged over the batch.
    """
    pred_shape = K.shape(y_pred)
    seq_len = K.cast(pred_shape[2], K.floatx())
    flat_true = y_true[..., 0] * seq_len + y_true[..., 1]
    flat_pred = K.reshape(
        y_pred, (pred_shape[0], -1, K.prod(pred_shape[2:]))
    )
    per_head_loss = sparse_multilabel_categorical_crossentropy(
        flat_true, flat_pred, True
    )
    return K.mean(K.sum(per_head_loss, axis=1))
# Pretrained encoder. With return_keras_model=False the wrapper object is
# returned; its Keras model is accessed as `base.model` below.
base = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False
)
# NOTE: do not reorder the three GlobalPointer layers. Their creation order
# determines the auto-generated names (efficient_global_pointer_1/2/3) that
# appear in the training logs and in the saved weights.
# Entity scorer with 2 heads: head 0 = subject spans, head 1 = object spans
# (the layout data_generator and extract_spoes use).
entity_output = GlobalPointer(heads=2, head_size=64)(base.model.output)
# One head per predicate, scoring (subject_start, object_start) token pairs.
# No tril mask, presumably because the object may precede the subject.
head_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
# One head per predicate, scoring (subject_end, object_end) token pairs.
tail_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
outputs = [entity_output, head_output, tail_output]
# Adam extended with an exponential moving average of the weights; Evaluator
# swaps the EMA weights in for validation and checkpointing.
AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(learning_rate=1e-5)
model = keras.models.Model(base.model.inputs, outputs)
# A single loss is given, so Keras applies it to each of the three outputs.
model.compile(loss=globalpointer_crossentropy, optimizer=optimizer)
def extract_spoes(text, threshold=0):
    """Extract the (subject, predicate, object) triples contained in ``text``.

    Args:
        text: raw input string; only its first ``maxlen`` tokens are scored.
        threshold: minimum score for a span or relation pair to be accepted.

    Returns:
        list of unique ``(subject_str, predicate_str, object_str)`` tuples,
        in arbitrary order.
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    # For each token, the list of character indices it covers in `text`.
    mapping = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    entity_scores, head_scores, tail_scores = [
        batch_scores[0]
        for batch_scores in model.predict([token_ids, segment_ids])
    ]

    # Spans may neither start nor end on the first/last position.
    entity_scores[:, [0, -1]] -= np.inf
    entity_scores[:, :, [0, -1]] -= np.inf

    # Entity head 0 scores subject spans, head 1 object spans.
    subjects, objects = set(), set()
    for head, start, end in zip(*np.where(entity_scores > threshold)):
        target = subjects if head == 0 else objects
        target.add((start, end))

    def span_text(start, end):
        # Text from the first char of token `start` to the last char of `end`.
        return text[mapping[start][0]:mapping[end][-1] + 1]

    triples = set()
    for sh, st in subjects:
        for oh, ot in objects:
            # Keep a predicate only when both the head-pair scorer and the
            # tail-pair scorer select it.
            by_head = set(np.where(head_scores[:, sh, oh] > threshold)[0])
            by_tail = set(np.where(tail_scores[:, st, ot] > threshold)[0])
            for p in by_head & by_tail:
                triples.add(
                    (span_text(sh, st), id2predicate[p], span_text(oh, ot))
                )
    return list(triples)
class SPO(tuple):
    """Triple container with tokenizer-normalized equality.

    Behaves like the ``(subject, predicate, object)`` tuple it wraps; as a
    tuple subclass it JSON-serializes as a list. Hashing and equality use the
    tokenized subject/object instead of the raw strings. Comparison therefore
    tolerates differences the tokenizer normalizes away, e.g. letter case
    when the tokenizer lowercases.
    """

    def __init__(self, spo):
        # tuple.__new__ already stored the raw values; cache the normalized key.
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return hash(self.spox)

    def __eq__(self, spo):
        if not isinstance(spo, SPO):
            # Let Python fall back to its default handling instead of raising
            # AttributeError on `spo.spox`.
            return NotImplemented
        return self.spox == spo.spox

    def __ne__(self, spo):
        # tuple defines its own __ne__, which would otherwise compare the raw
        # strings and could disagree with __eq__; keep them consistent.
        result = self.__eq__(spo)
        return result if result is NotImplemented else not result
def evaluate(data):
    """Evaluate the model on labelled data; compute F1, precision and recall.

    Args:
        data: iterable of dicts with ``'text'`` and ``'spo_list'`` keys (the
            format produced by ``load_data``). Plain strings are rejected.

    Side effects:
        Writes one pretty-printed JSON record per sample (gold, predicted,
        extra ``new`` and missing ``lack`` triples) to ``dev_pred.json`` and
        shows a tqdm progress bar.

    Returns:
        ``(f1, precision, recall)``, micro-averaged over triples. Counts start
        at 1e-10 to avoid division by zero, so empty ``data`` yields
        ``(1.0, 1.0, 1.0)``.

    Raises:
        TypeError: if an element of ``data`` is a plain string.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    # Defined before the loop so empty `data` does not raise UnboundLocalError.
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    # `with` guarantees the file and progress bar are closed even if
    # extraction raises mid-way.
    with open('dev_pred.json', 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            if isinstance(d, str):
                raise TypeError(
                    "evaluate() expects dicts with 'text' and 'spo_list' "
                    "keys, got a plain string (no gold triples to compare)"
                )
            R = set([SPO(spo) for spo in extract_spoes(d['text'])])
            T = set([SPO(spo) for spo in d['spo_list']])
            X += len(R & T)
            Y += len(R)
            Z += len(T)
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
            pbar.update()
            pbar.set_description(
                'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
            )
            s = json.dumps({
                'text': d['text'],
                'spo_list': list(T),
                'spo_list_pred': list(R),
                'new': list(R - T),
                'lack': list(T - R),
            },
                ensure_ascii=False,
                indent=4)
            f.write(s + '\n')
    return f1, precision, recall
class Evaluator(keras.callbacks.Callback):
    """Evaluate with EMA weights after each epoch and keep the best checkpoint.

    ``best_val_f1`` persists on the instance, so reusing the same Evaluator
    across ``fit`` calls only overwrites ``best_model.weights`` when the
    validation F1 matches or beats every earlier epoch.
    """

    def __init__(self):
        # Run the base Callback initializer so Keras-managed attributes exist.
        super().__init__()
        self.best_val_f1 = 0.

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate and save using the EMA weights. Always restore the raw
        # training weights afterwards, even if evaluation or saving raises;
        # otherwise training would silently continue on the EMA weights.
        optimizer.apply_ema_weights()
        try:
            f1, precision, recall = evaluate(valid_data)
            if f1 >= self.best_val_f1:
                self.best_val_f1 = f1
                model.save_weights('best_model.weights')
        finally:
            optimizer.reset_old_weights()
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )
```
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
np_resource = np.dtype([("resource", np.ubyte, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
/opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.
WARNING:tensorflow:From /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:2403: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
```python
if __name__ != '__main__':
    # Imported as a module: only restore the best checkpoint.
    model.load_weights('best_model.weights')
else:
    # Run as a script (or in a notebook kernel): train for 20 epochs,
    # validating with EMA weights and checkpointing after each epoch.
    train_generator = data_generator(train_data, batch_size)
    evaluator = Evaluator()
    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=20,
        callbacks=[evaluator],
    )
```
```python
import json
# A list of raw text strings (no gold triples): pass it to evaluate_data(),
# not evaluate(), which expects dicts with 'text'/'spo_list'.
# Read as UTF-8 like the other data files; `with` closes the handle.
with open("word_list_10000.json", "r", encoding="utf-8") as fh:
    word_list_10000 = json.load(fh)
```
```python
```
0it [00:00, ?it/s]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-6edc72d74b13> in <module>
----> 1 evaluate(word_list_10000)
<ipython-input-1-f4de40003f6b> in evaluate(data)
225 pbar = tqdm()
226 for d in data:
--> 227 R = set([SPO(spo) for spo in extract_spoes(d['text'])])
228 T = set([SPO(spo) for spo in d['spo_list']])
229 X += len(R & T)
TypeError: string indices must be integers
```python
# json.load parses the whole file as one JSON document, despite the .jsonl
# suffix. Read as UTF-8; `with` closes the handle.
with open("word_set_list_50000.jsonl", "r", encoding="utf-8") as fh:
    word_list_50000 = json.load(fh)
model.load_weights('best_model.weights')
def evaluate_data(data, output_path='dev_pred_50000.jsonl'):
    """Run triple extraction over unlabelled texts and dump the predictions.

    Args:
        data: iterable of raw text strings.
        output_path: JSON-lines file to write. It gets one
            ``{'text': ..., 'spo_list': [...]}`` record per text that yields
            at least one triple; texts without triples are not written.

    Returns:
        ``(f1, precision, recall)``, kept for backward compatibility. There
        are no gold triples here, so these values are degenerate: precision
        is ~0 and recall is 1.0 whenever anything is predicted, and all three
        are 1.0 when nothing is. Use the written file, not these numbers.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10  # smoothing, as in evaluate()
    n_hit_texts, n_triples = 0, 0
    with open(output_path, 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            R = set([SPO(spo) for spo in extract_spoes(d)])
            # Count every processed text, not only those with triples.
            pbar.update()
            if not R:
                continue
            n_hit_texts += 1
            n_triples += len(R)
            # No gold set: |R & T| and |T| are 0, so only Y changes.
            Y += len(R)
            pbar.set_description(
                'texts with triples: %d, triples: %d' % (n_hit_texts, n_triples)
            )
            s = json.dumps({
                'text': d,
                'spo_list': list(R),
            },
                ensure_ascii=False)
            f.write(s + '\n')
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    return f1, precision, recall
# Extract triples from every text in word_list_50000 and write the texts that
# yield at least one triple to the evaluate_data output file.
evaluate_data(word_list_50000)
```
0it [00:00, ?it/s]
WARNING:tensorflow:From /root/.local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
f1: 0.00000, precision: 0.00000, recall: 1.00000: : 8521it [26:50, 5.29it/s]
(1.263823064770916e-14, 6.3191153238546206e-15, 1.0)
```python
# NOTE: despite the variable name, this holds word_set_list.json, not the
# 50k file loaded earlier. Read as UTF-8; `with` closes the handle.
with open("word_set_list.json", "r", encoding="utf-8") as fh:
    word_list_50000 = json.load(fh)
model.load_weights('best_model.weights')
def evaluate_data(data, output_path='dev_pred_all.jsonl'):
    """Run triple extraction over unlabelled texts and dump the predictions.

    Args:
        data: iterable of raw text strings.
        output_path: JSON-lines file to write. It gets one
            ``{'text': ..., 'spo_list': [...]}`` record per text that yields
            at least one triple; texts without triples are not written.

    Returns:
        ``(f1, precision, recall)``, kept for backward compatibility. There
        are no gold triples here, so these values are degenerate: precision
        is ~0 and recall is 1.0 whenever anything is predicted, and all three
        are 1.0 when nothing is. Use the written file, not these numbers.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10  # smoothing, as in evaluate()
    n_hit_texts, n_triples = 0, 0
    with open(output_path, 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            R = set([SPO(spo) for spo in extract_spoes(d)])
            # Count every processed text, not only those with triples.
            pbar.update()
            if not R:
                continue
            n_hit_texts += 1
            n_triples += len(R)
            # No gold set: |R & T| and |T| are 0, so only Y changes.
            Y += len(R)
            pbar.set_description(
                'texts with triples: %d, triples: %d' % (n_hit_texts, n_triples)
            )
            s = json.dumps({
                'text': d,
                'spo_list': list(R),
            },
                ensure_ascii=False)
            f.write(s + '\n')
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    return f1, precision, recall
# Extract triples from every text in word_list_50000 (in this cell it holds
# the contents of word_set_list.json) and write the texts that yield at least
# one triple to the evaluate_data output file.
evaluate_data(word_list_50000)
```
f1: 0.00000, precision: 0.00000, recall: 1.00000: : 83367it [4:16:07, 6.11it/s]
```python
# Restore the best checkpoint saved by Evaluator before running inference.
model.load_weights('best_model.weights')
def evaluate_data(data, output_path='dev_pred.jsonl'):
    """Run triple extraction over unlabelled texts and dump the predictions.

    Args:
        data: iterable of raw text strings.
        output_path: JSON-lines file to write. It gets one
            ``{'text': ..., 'spo_list': [...]}`` record per text that yields
            at least one triple; texts without triples are not written.

    Returns:
        ``(f1, precision, recall)``, kept for backward compatibility. There
        are no gold triples here, so these values are degenerate: precision
        is ~0 and recall is 1.0 whenever anything is predicted, and all three
        are 1.0 when nothing is. Use the written file, not these numbers.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10  # smoothing, as in evaluate()
    n_hit_texts, n_triples = 0, 0
    with open(output_path, 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            R = set([SPO(spo) for spo in extract_spoes(d)])
            # Count every processed text, not only those with triples.
            pbar.update()
            if not R:
                continue
            n_hit_texts += 1
            n_triples += len(R)
            # No gold set: |R & T| and |T| are 0, so only Y changes.
            Y += len(R)
            pbar.set_description(
                'texts with triples: %d, triples: %d' % (n_hit_texts, n_triples)
            )
            s = json.dumps({
                'text': d,
                'spo_list': list(R),
            },
                ensure_ascii=False)
            f.write(s + '\n')
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    return f1, precision, recall
# Extract triples from the raw texts in word_list_10000 and write the texts
# that yield at least one triple to the evaluate_data output file.
evaluate_data(word_list_10000)
```
```python
train_generator = data_generator(train_data, batch_size)
# Continue training with the existing `evaluator`: it keeps its best_val_f1
# from earlier runs, so best_model.weights is only overwritten when the
# validation F1 matches or beats that value.
model.fit(
    train_generator.forfit(),
    epochs=20,
    steps_per_epoch=len(train_generator),
    callbacks=[evaluator],
)
```
Epoch 1/20
125/125 [==============================] - 67s 538ms/step - loss: 1.6920 - efficient_global_pointer_1_loss: 0.5173 - efficient_global_pointer_2_loss: 0.6092 - efficient_global_pointer_3_loss: 0.5656
f1: 0.49655, precision: 0.48770, recall: 0.50572: : 500it [00:24, 20.01it/s]
f1: 0.49655, precision: 0.48770, recall: 0.50572, best f1: 0.49871
Epoch 2/20
125/125 [==============================] - 68s 541ms/step - loss: 1.5598 - efficient_global_pointer_1_loss: 0.4498 - efficient_global_pointer_2_loss: 0.5779 - efficient_global_pointer_3_loss: 0.5321
f1: 0.50217, precision: 0.49530, recall: 0.50923: : 500it [00:25, 19.56it/s]
f1: 0.50217, precision: 0.49530, recall: 0.50923, best f1: 0.50217
Epoch 3/20
125/125 [==============================] - 67s 540ms/step - loss: 1.4841 - efficient_global_pointer_1_loss: 0.4538 - efficient_global_pointer_2_loss: 0.5419 - efficient_global_pointer_3_loss: 0.4883
f1: 0.50766, precision: 0.50523, recall: 0.51011: : 500it [00:25, 19.76it/s]
f1: 0.50766, precision: 0.50523, recall: 0.51011, best f1: 0.50766
Epoch 4/20
125/125 [==============================] - 66s 532ms/step - loss: 1.3896 - efficient_global_pointer_1_loss: 0.4203 - efficient_global_pointer_2_loss: 0.5053 - efficient_global_pointer_3_loss: 0.4639
f1: 0.51189, precision: 0.51280, recall: 0.51099: : 500it [00:25, 19.70it/s]
f1: 0.51189, precision: 0.51280, recall: 0.51099, best f1: 0.51189
Epoch 5/20
125/125 [==============================] - 68s 543ms/step - loss: 1.2606 - efficient_global_pointer_1_loss: 0.3801 - efficient_global_pointer_2_loss: 0.4629 - efficient_global_pointer_3_loss: 0.4177
f1: 0.51218, precision: 0.51607, recall: 0.50836: : 500it [00:25, 19.84it/s]
f1: 0.51218, precision: 0.51607, recall: 0.50836, best f1: 0.51218
Epoch 6/20
125/125 [==============================] - 67s 538ms/step - loss: 1.1541 - efficient_global_pointer_1_loss: 0.3397 - efficient_global_pointer_2_loss: 0.4406 - efficient_global_pointer_3_loss: 0.3738
f1: 0.51337, precision: 0.52033, recall: 0.50660: : 500it [00:24, 20.07it/s]
f1: 0.51337, precision: 0.52033, recall: 0.50660, best f1: 0.51337
Epoch 7/20
125/125 [==============================] - 68s 541ms/step - loss: 1.0183 - efficient_global_pointer_1_loss: 0.2941 - efficient_global_pointer_2_loss: 0.3768 - efficient_global_pointer_3_loss: 0.3474
f1: 0.51544, precision: 0.52459, recall: 0.50660: : 500it [00:25, 19.86it/s]
f1: 0.51544, precision: 0.52459, recall: 0.50660, best f1: 0.51544
Epoch 8/20
125/125 [==============================] - 68s 544ms/step - loss: 1.0351 - efficient_global_pointer_1_loss: 0.3146 - efficient_global_pointer_2_loss: 0.3885 - efficient_global_pointer_3_loss: 0.3320
f1: 0.51959, precision: 0.53229, recall: 0.50748: : 500it [00:25, 19.69it/s]
f1: 0.51959, precision: 0.53229, recall: 0.50748, best f1: 0.51959
Epoch 9/20
125/125 [==============================] - 67s 538ms/step - loss: 0.9599 - efficient_global_pointer_1_loss: 0.2925 - efficient_global_pointer_2_loss: 0.3578 - efficient_global_pointer_3_loss: 0.3095
f1: 0.52022, precision: 0.53168, recall: 0.50923: : 500it [00:25, 19.90it/s]
f1: 0.52022, precision: 0.53168, recall: 0.50923, best f1: 0.52022
Epoch 10/20
125/125 [==============================] - 67s 533ms/step - loss: 0.9140 - efficient_global_pointer_1_loss: 0.2787 - efficient_global_pointer_2_loss: 0.3511 - efficient_global_pointer_3_loss: 0.2843
f1: 0.52299, precision: 0.53654, recall: 0.51011: : 500it [00:25, 19.65it/s]
f1: 0.52299, precision: 0.53654, recall: 0.51011, best f1: 0.52299
Epoch 11/20
125/125 [==============================] - 67s 533ms/step - loss: 0.8556 - efficient_global_pointer_1_loss: 0.2696 - efficient_global_pointer_2_loss: 0.3154 - efficient_global_pointer_3_loss: 0.2706
f1: 0.52385, precision: 0.53641, recall: 0.51187: : 500it [00:25, 19.71it/s]
f1: 0.52385, precision: 0.53641, recall: 0.51187, best f1: 0.52385
Epoch 12/20
125/125 [==============================] - 67s 533ms/step - loss: 0.7335 - efficient_global_pointer_1_loss: 0.2199 - efficient_global_pointer_2_loss: 0.2733 - efficient_global_pointer_3_loss: 0.2403
f1: 0.52351, precision: 0.53860, recall: 0.50923: : 500it [00:25, 19.82it/s]
f1: 0.52351, precision: 0.53860, recall: 0.50923, best f1: 0.52385
Epoch 13/20
125/125 [==============================] - 67s 538ms/step - loss: 0.7204 - efficient_global_pointer_1_loss: 0.2177 - efficient_global_pointer_2_loss: 0.2536 - efficient_global_pointer_3_loss: 0.2491
f1: 0.52407, precision: 0.54178, recall: 0.50748: : 500it [00:25, 19.85it/s]
f1: 0.52407, precision: 0.54178, recall: 0.50748, best f1: 0.52407
Epoch 14/20
125/125 [==============================] - 67s 537ms/step - loss: 0.6730 - efficient_global_pointer_1_loss: 0.1986 - efficient_global_pointer_2_loss: 0.2597 - efficient_global_pointer_3_loss: 0.2147
f1: 0.52465, precision: 0.54004, recall: 0.51011: : 500it [00:25, 19.68it/s]
f1: 0.52465, precision: 0.54004, recall: 0.51011, best f1: 0.52465
Epoch 15/20
125/125 [==============================] - 67s 539ms/step - loss: 0.6692 - efficient_global_pointer_1_loss: 0.2074 - efficient_global_pointer_2_loss: 0.2523 - efficient_global_pointer_3_loss: 0.2095
f1: 0.52655, precision: 0.54409, recall: 0.51011: : 500it [00:25, 19.71it/s]
f1: 0.52655, precision: 0.54409, recall: 0.51011, best f1: 0.52655
Epoch 16/20
80/125 [==================>...........] - ETA: 23s - loss: 0.6411 - efficient_global_pointer_1_loss: 0.1935 - efficient_global_pointer_2_loss: 0.2297 - efficient_global_pointer_3_loss: 0.2179
```python
```
```python
```