三元组抽取任务,基于GlobalPointer的仿TPLinker设计 for doccano数据集设计

237 阅读5分钟
```python
#! -*- coding:utf-8 -*-
# 三元组抽取任务,基于GlobalPointer的仿TPLinker设计
# 文章介绍:https://kexue.fm/archives/8888
# 数据集:http://ai.baidu.com/broad/download?dataset=sked
# 最优f1=0.827
# 说明:由于使用了EMA,需要跑足够多的步数(5000步以上)才生效,如果
#      你的数据总量比较少,那么请务必跑足够多的epoch数,或者去掉EMA。

import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.backend import sparse_multilabel_categorical_crossentropy
from bert4keras.tokenizers import Tokenizer
from bert4keras.layers import EfficientGlobalPointer as GlobalPointer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from tqdm import tqdm

maxlen = 128      # max token length fed to the encoder (longer texts are truncated)
batch_size = 16
# Paths to the pretrained Chinese RoBERTa-wwm-ext (BERT-base sized) checkpoint.
config_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'


def load_data(filename):
    """Load a doccano-style JSONL export of annotated triples.

    Each returned sample has the shape {'text': text, 'spo_list': [(s, p, o)]}.
    Also builds and returns the id2predicate / predicate2id maps from the
    relation types encountered, in first-seen order.
    """
    samples = []
    id2predicate, predicate2id = {}, {}
    with open(filename, encoding='utf-8') as fp:
        for line in fp:
            item = json.loads(line)
            text = item['text']
            # entity id -> surface string (doccano offsets: end is exclusive)
            span_of = {
                ent['id']: text[ent['start_offset']:ent['end_offset']]
                for ent in item['entities']
            }
            samples.append({
                'text': text,
                'spo_list': [
                    (span_of[rel['from_id']], rel['type'], span_of[rel['to_id']])
                    for rel in item['relations']
                ]
            })
            # Register any predicate types not seen before.
            for rel in item['relations']:
                if rel['type'] not in predicate2id:
                    id2predicate[len(predicate2id)] = rel['type']
                    predicate2id[rel['type']] = len(predicate2id)

    return samples, id2predicate, predicate2id


# Load the doccano export and split it 80/20 into train/valid.
all_data,id2predicate,predicate2id = load_data('untitled.txt')
train_data = all_data[:int(len(all_data)*0.8)]
valid_data = all_data[int(len(all_data)*0.8):]

# Build the WordPiece tokenizer from the BERT vocabulary.
tokenizer = Tokenizer(dict_path, do_lower_case=True)


def search(pattern, sequence):
    """Find the sub-list *pattern* inside *sequence*.

    Returns the start index of the first occurrence, or -1 if absent.
    """
    width = len(pattern)
    return next(
        (i for i in range(len(sequence)) if sequence[i:i + width] == pattern),
        -1
    )


class data_generator(DataGenerator):
    """Batch generator producing TPLinker-style GlobalPointer labels.

    Yields ([token_ids, segment_ids],
            [entity_labels, head_labels, tail_labels]) where each label
    group is a padded list of index pairs:
      - entity_labels[0]: subject (start, end) spans; [1]: object spans
      - head_labels[p]:   (subject_start, object_start) pairs for predicate p
      - tail_labels[p]:   (subject_end,   object_end)   pairs for predicate p
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(d['text'], maxlen=maxlen)
            # Map each (s, p, o) triple to token-level spans (sh, st, p, oh, ot).
            spoes = set()
            for s, p, o in d['spo_list']:
                s = tokenizer.encode(s)[0][1:-1]  # drop [CLS]/[SEP] ids
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                sh = search(s, token_ids)
                oh = search(o, token_ids)
                # Entities truncated away by maxlen are silently dropped.
                if sh != -1 and oh != -1:
                    spoes.add((sh, sh + len(s) - 1, p, oh, oh + len(o) - 1))
            # Build the three label groups.
            entity_labels = [set() for _ in range(2)]
            head_labels = [set() for _ in range(len(predicate2id))]
            tail_labels = [set() for _ in range(len(predicate2id))]
            for sh, st, p, oh, ot in spoes:
                entity_labels[0].add((sh, st))
                entity_labels[1].add((oh, ot))
                head_labels[p].add((sh, oh))
                tail_labels[p].add((st, ot))
            for label in entity_labels + head_labels + tail_labels:
                if not label:  # every head needs at least one label
                    label.add((0, 0))  # pad with a dummy (0, 0) pair
            entity_labels = sequence_padding([list(l) for l in entity_labels])
            head_labels = sequence_padding([list(l) for l in head_labels])
            tail_labels = sequence_padding([list(l) for l in tail_labels])
            # Accumulate into the current batch.
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_entity_labels.append(entity_labels)
            batch_head_labels.append(head_labels)
            batch_tail_labels.append(tail_labels)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_entity_labels = sequence_padding(
                    batch_entity_labels, seq_dims=2
                )
                batch_head_labels = sequence_padding(
                    batch_head_labels, seq_dims=2
                )
                batch_tail_labels = sequence_padding(
                    batch_tail_labels, seq_dims=2
                )
                yield [batch_token_ids, batch_segment_ids], [
                    batch_entity_labels, batch_head_labels, batch_tail_labels
                ]
                batch_token_ids, batch_segment_ids = [], []
                batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []


def globalpointer_crossentropy(y_true, y_pred):
    """Sparse multilabel cross-entropy for GlobalPointer outputs.

    y_true holds (row, col) index pairs; each pair is flattened into a
    single index row * L + col (L = y_pred.shape[2]) so every heads x L x L
    score map can be treated as one long multilabel logit vector.
    """
    shape = K.shape(y_pred)
    # flatten (row, col) -> row * L + col
    y_true = y_true[..., 0] * K.cast(shape[2], K.floatx()) + y_true[..., 1]
    y_pred = K.reshape(y_pred, (shape[0], -1, K.prod(shape[2:])))
    loss = sparse_multilabel_categorical_crossentropy(y_true, y_pred, True)
    return K.mean(K.sum(loss, axis=1))


# Load the pretrained transformer (return_keras_model=False returns the
# bert4keras wrapper, hence base.model below).
base = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False
)

# Three scoring heads (TPLinker-style decomposition):
#   entity_output: 2 heads -> subject spans and object spans
#   head_output:   per-predicate (subject_start, object_start) links
#   tail_output:   per-predicate (subject_end,   object_end)   links
# NOTE(review): RoPE/tril_mask are disabled for the two linking heads —
# presumably because their score matrices pair s/o positions rather than
# span (start <= end) positions; confirm against the reference article.
entity_output = GlobalPointer(heads=2, head_size=64)(base.model.output)
head_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
tail_output = GlobalPointer(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)(base.model.output)
outputs = [entity_output, head_output, tail_output]

# Build and compile; EMA weights are swapped in at evaluation time by the
# Evaluator callback.
AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
optimizer = AdamEMA(learning_rate=1e-5)
model = keras.models.Model(base.model.inputs, outputs)
model.compile(loss=globalpointer_crossentropy, optimizer=optimizer)
# model.summary()


def extract_spoes(text, threshold=0):
    """Extract (subject, predicate, object) triples from *text*.

    Decoding: entity-head spans scoring above *threshold* become
    subject/object candidates; a triple (s, p, o) is kept only when both
    the head-link and the tail-link score for predicate p exceed the
    threshold.
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)  # token index -> char offsets
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    outputs = model.predict([token_ids, segment_ids])
    outputs = [o[0] for o in outputs]  # strip the batch dimension
    # Collect subject and object spans from the entity head.
    subjects, objects = set(), set()
    outputs[0][:, [0, -1]] -= np.inf     # mask spans starting at [CLS]/[SEP]
    outputs[0][:, :, [0, -1]] -= np.inf  # mask spans ending at [CLS]/[SEP]
    for l, h, t in zip(*np.where(outputs[0] > threshold)):
        if l == 0:  # head 0 = subject spans, head 1 = object spans
            subjects.add((h, t))
        else:
            objects.add((h, t))
    # Keep predicates confirmed by BOTH the head-link and tail-link scores.
    spoes = set()
    for sh, st in subjects:
        for oh, ot in objects:
            p1s = np.where(outputs[1][:, sh, oh] > threshold)[0]
            p2s = np.where(outputs[2][:, st, ot] > threshold)[0]
            ps = set(p1s) & set(p2s)
            for p in ps:
                # Map token spans back to character spans of the raw text.
                spoes.add((
                    text[mapping[sh][0]:mapping[st][-1] + 1], id2predicate[p],
                    text[mapping[oh][0]:mapping[ot][-1] + 1]
                ))
    return list(spoes)


class SPO(tuple):
    """Tuple-like container for a (subject, predicate, object) triple.

    Behaves like a plain tuple, except that __hash__ and __eq__ compare a
    tokenized view of subject and object, which makes equality checks
    between predicted and gold triples tolerant to surface differences
    that vanish after tokenization.
    """
    def __init__(self, spo):
        subject, predicate, obj = spo[0], spo[1], spo[2]
        self.spox = (
            tuple(tokenizer.tokenize(subject)),
            predicate,
            tuple(tokenizer.tokenize(obj)),
        )

    def __hash__(self):
        return hash(self.spox)

    def __eq__(self, spo):
        return self.spox == spo.spox


def evaluate(data):
    """Compute (f1, precision, recall) of triple extraction over *data*.

    *data* items are dicts holding 'text' and the gold 'spo_list'.
    Per-sample predictions (including the spurious/missing triples) are
    dumped to 'dev_pred.json'.

    Fixes vs. the original: the dump file is managed by a context manager
    so it is closed even if extraction raises, and the returned metrics
    are initialised so an empty *data* no longer raises NameError.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10  # correct / predicted / gold counts
    f1, precision, recall = 0., 0., 0.  # well-defined even for empty data
    pbar = tqdm()
    with open('dev_pred.json', 'w', encoding='utf-8') as f:
        for d in data:
            # SPO wraps each triple so equality is tokenization-tolerant.
            R = set([SPO(spo) for spo in extract_spoes(d['text'])])
            T = set([SPO(spo) for spo in d['spo_list']])
            X += len(R & T)
            Y += len(R)
            Z += len(T)
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
            pbar.update()
            pbar.set_description(
                'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
            )
            s = json.dumps({
                'text': d['text'],
                'spo_list': list(T),
                'spo_list_pred': list(R),
                'new': list(R - T),      # predicted but not in gold
                'lack': list(T - R),     # gold but not predicted
            },
                           ensure_ascii=False,
                           indent=4)
            f.write(s + '\n')
    pbar.close()
    return f1, precision, recall


class Evaluator(keras.callbacks.Callback):
    """Epoch-end callback: evaluate under EMA weights, keep the best model.

    apply_ema_weights() temporarily swaps the EMA-averaged weights in for
    evaluation/saving; reset_old_weights() restores the raw training
    weights before the next epoch resumes.
    """
    def __init__(self):
        self.best_val_f1 = 0.  # best validation f1 observed so far

    def on_epoch_end(self, epoch, logs=None):
        optimizer.apply_ema_weights()
        f1, precision, recall = evaluate(valid_data)
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            model.save_weights('best_model.weights')
        optimizer.reset_old_weights()
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )

```

    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint8 = np.dtype([("qint8", np.int8, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint16 = np.dtype([("qint16", np.int16, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint32 = np.dtype([("qint32", np.int32, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      np_resource = np.dtype([("resource", np.ubyte, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint8 = np.dtype([("qint8", np.int8, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint16 = np.dtype([("qint16", np.int16, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint32 = np.dtype([("qint32", np.int32, 1)])
    /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      np_resource = np.dtype([("resource", np.ubyte, 1)])
    Using TensorFlow backend.


    WARNING:tensorflow:From /opt/conda/envs/tf1.14/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:2403: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    Use tf.where in 2.0, which has the same broadcast rule as np.where



```python


if __name__ == '__main__':
    # Run as a script: train, letting the Evaluator callback checkpoint
    # the best weights each epoch.

    train_generator = data_generator(train_data, batch_size)
    evaluator = Evaluator()

    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=20,
        callbacks=[evaluator]
    )

else:
    # Imported (e.g. from a notebook): restore the best checkpoint instead.

    model.load_weights('best_model.weights')
```


```python
import json

# Candidate texts for the inference cells below.  Use a context manager so
# the file handle is released (json.load(open(...)) leaked it) and pin the
# encoding explicitly.
with open("word_list_10000.json", "r", encoding="utf-8") as fp:
    word_list_10000 = json.load(fp)
```


```python

```

    0it [00:00, ?it/s]


    ---------------------------------------------------------------------------

    TypeError                                 Traceback (most recent call last)

    <ipython-input-3-6edc72d74b13> in <module>
    ----> 1 evaluate(word_list_10000)
    

    <ipython-input-1-f4de40003f6b> in evaluate(data)
        225     pbar = tqdm()
        226     for d in data:
    --> 227         R = set([SPO(spo) for spo in extract_spoes(d['text'])])
        228         T = set([SPO(spo) for spo in d['spo_list']])
        229         X += len(R & T)


    TypeError: string indices must be integers



```python
word_list_50000 = json.load(open("word_set_list_50000.jsonl","r"))
model.load_weights('best_model.weights')
def evaluate_data(data):
    """Run triple extraction over raw strings and dump results as JSONL.

    NOTE(review): despite the name, this is inference, not evaluation —
    the gold set T is always empty, so X stays ~0, precision ~0 and
    recall ~1 by construction (matching the logged output); the reported
    f1 here is meaningless.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('dev_pred_50000.jsonl', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        # d is a plain string here, not a {'text': ...} dict.
        R = set([SPO(spo) for spo in extract_spoes(d)])
        if len(R) < 1:  # skip texts yielding no triples
            continue
        T = set()  # no gold annotations available
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d,
            'spo_list': list(R),
        },
                       ensure_ascii=False)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall
evaluate_data(word_list_50000)
```

    0it [00:00, ?it/s]

    WARNING:tensorflow:From /root/.local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
    


    f1: 0.00000, precision: 0.00000, recall: 1.00000: : 8521it [26:50,  5.29it/s]





    (1.263823064770916e-14, 6.3191153238546206e-15, 1.0)




```python
word_list_50000 = json.load(open("word_set_list.json","r"))
model.load_weights('best_model.weights')
def evaluate_data(data):
    """Run triple extraction over raw strings and dump results as JSONL.

    NOTE(review): this is inference, not evaluation — the gold set T is
    always empty, so the reported precision/recall/f1 are meaningless
    by construction.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('dev_pred_all.jsonl', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        # d is a plain string here, not a {'text': ...} dict.
        R = set([SPO(spo) for spo in extract_spoes(d)])
        if len(R) < 1:  # skip texts yielding no triples
            continue
        T = set()  # no gold annotations available
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d,
            'spo_list': list(R),
        },
                       ensure_ascii=False)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall
evaluate_data(word_list_50000)
```

    f1: 0.00000, precision: 0.00000, recall: 1.00000: : 83367it [4:16:07,  6.11it/s]


```python
model.load_weights('best_model.weights')
def evaluate_data(data):
    """Run triple extraction over raw strings and dump results as JSONL.

    NOTE(review): this is inference, not evaluation — the gold set T is
    always empty, so the reported precision/recall/f1 are meaningless
    by construction.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('dev_pred.jsonl', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        # d is a plain string here, not a {'text': ...} dict.
        R = set([SPO(spo) for spo in extract_spoes(d)])
        if len(R) < 1:  # skip texts yielding no triples
            continue
        T = set()  # no gold annotations available
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d,
            'spo_list': list(R),
        },
                       ensure_ascii=False)
        f.write(s + '\n')
    pbar.close()
    f.close()
    return f1, precision, recall
evaluate_data(word_list_10000)
```


```python
# Training cell.  The original commented out `evaluator = Evaluator()` yet
# still passed `evaluator` to callbacks — that only worked because the name
# survived from an earlier cell; in a fresh session it raises NameError.
# Instantiate it explicitly so the cell is self-contained.
train_generator = data_generator(train_data, batch_size)
evaluator = Evaluator()

model.fit(
    train_generator.forfit(),
    steps_per_epoch=len(train_generator),
    epochs=20,
    callbacks=[evaluator]
)
```

    Epoch 1/20
    125/125 [==============================] - 67s 538ms/step - loss: 1.6920 - efficient_global_pointer_1_loss: 0.5173 - efficient_global_pointer_2_loss: 0.6092 - efficient_global_pointer_3_loss: 0.5656


    f1: 0.49655, precision: 0.48770, recall: 0.50572: : 500it [00:24, 20.01it/s]


    f1: 0.49655, precision: 0.48770, recall: 0.50572, best f1: 0.49871
    
    Epoch 2/20
    125/125 [==============================] - 68s 541ms/step - loss: 1.5598 - efficient_global_pointer_1_loss: 0.4498 - efficient_global_pointer_2_loss: 0.5779 - efficient_global_pointer_3_loss: 0.5321


    f1: 0.50217, precision: 0.49530, recall: 0.50923: : 500it [00:25, 19.56it/s]


    f1: 0.50217, precision: 0.49530, recall: 0.50923, best f1: 0.50217
    
    Epoch 3/20
    125/125 [==============================] - 67s 540ms/step - loss: 1.4841 - efficient_global_pointer_1_loss: 0.4538 - efficient_global_pointer_2_loss: 0.5419 - efficient_global_pointer_3_loss: 0.4883


    f1: 0.50766, precision: 0.50523, recall: 0.51011: : 500it [00:25, 19.76it/s]


    f1: 0.50766, precision: 0.50523, recall: 0.51011, best f1: 0.50766
    
    Epoch 4/20
    125/125 [==============================] - 66s 532ms/step - loss: 1.3896 - efficient_global_pointer_1_loss: 0.4203 - efficient_global_pointer_2_loss: 0.5053 - efficient_global_pointer_3_loss: 0.4639


    f1: 0.51189, precision: 0.51280, recall: 0.51099: : 500it [00:25, 19.70it/s]


    f1: 0.51189, precision: 0.51280, recall: 0.51099, best f1: 0.51189
    
    Epoch 5/20
    125/125 [==============================] - 68s 543ms/step - loss: 1.2606 - efficient_global_pointer_1_loss: 0.3801 - efficient_global_pointer_2_loss: 0.4629 - efficient_global_pointer_3_loss: 0.4177


    f1: 0.51218, precision: 0.51607, recall: 0.50836: : 500it [00:25, 19.84it/s]


    f1: 0.51218, precision: 0.51607, recall: 0.50836, best f1: 0.51218
    
    Epoch 6/20
    125/125 [==============================] - 67s 538ms/step - loss: 1.1541 - efficient_global_pointer_1_loss: 0.3397 - efficient_global_pointer_2_loss: 0.4406 - efficient_global_pointer_3_loss: 0.3738


    f1: 0.51337, precision: 0.52033, recall: 0.50660: : 500it [00:24, 20.07it/s]


    f1: 0.51337, precision: 0.52033, recall: 0.50660, best f1: 0.51337
    
    Epoch 7/20
    125/125 [==============================] - 68s 541ms/step - loss: 1.0183 - efficient_global_pointer_1_loss: 0.2941 - efficient_global_pointer_2_loss: 0.3768 - efficient_global_pointer_3_loss: 0.3474


    f1: 0.51544, precision: 0.52459, recall: 0.50660: : 500it [00:25, 19.86it/s]


    f1: 0.51544, precision: 0.52459, recall: 0.50660, best f1: 0.51544
    
    Epoch 8/20
    125/125 [==============================] - 68s 544ms/step - loss: 1.0351 - efficient_global_pointer_1_loss: 0.3146 - efficient_global_pointer_2_loss: 0.3885 - efficient_global_pointer_3_loss: 0.3320


    f1: 0.51959, precision: 0.53229, recall: 0.50748: : 500it [00:25, 19.69it/s]


    f1: 0.51959, precision: 0.53229, recall: 0.50748, best f1: 0.51959
    
    Epoch 9/20
    125/125 [==============================] - 67s 538ms/step - loss: 0.9599 - efficient_global_pointer_1_loss: 0.2925 - efficient_global_pointer_2_loss: 0.3578 - efficient_global_pointer_3_loss: 0.3095


    f1: 0.52022, precision: 0.53168, recall: 0.50923: : 500it [00:25, 19.90it/s]


    f1: 0.52022, precision: 0.53168, recall: 0.50923, best f1: 0.52022
    
    Epoch 10/20
    125/125 [==============================] - 67s 533ms/step - loss: 0.9140 - efficient_global_pointer_1_loss: 0.2787 - efficient_global_pointer_2_loss: 0.3511 - efficient_global_pointer_3_loss: 0.2843


    f1: 0.52299, precision: 0.53654, recall: 0.51011: : 500it [00:25, 19.65it/s]


    f1: 0.52299, precision: 0.53654, recall: 0.51011, best f1: 0.52299
    
    Epoch 11/20
    125/125 [==============================] - 67s 533ms/step - loss: 0.8556 - efficient_global_pointer_1_loss: 0.2696 - efficient_global_pointer_2_loss: 0.3154 - efficient_global_pointer_3_loss: 0.2706


    f1: 0.52385, precision: 0.53641, recall: 0.51187: : 500it [00:25, 19.71it/s]


    f1: 0.52385, precision: 0.53641, recall: 0.51187, best f1: 0.52385
    
    Epoch 12/20
    125/125 [==============================] - 67s 533ms/step - loss: 0.7335 - efficient_global_pointer_1_loss: 0.2199 - efficient_global_pointer_2_loss: 0.2733 - efficient_global_pointer_3_loss: 0.2403


    f1: 0.52351, precision: 0.53860, recall: 0.50923: : 500it [00:25, 19.82it/s]


    f1: 0.52351, precision: 0.53860, recall: 0.50923, best f1: 0.52385
    
    Epoch 13/20
    125/125 [==============================] - 67s 538ms/step - loss: 0.7204 - efficient_global_pointer_1_loss: 0.2177 - efficient_global_pointer_2_loss: 0.2536 - efficient_global_pointer_3_loss: 0.2491


    f1: 0.52407, precision: 0.54178, recall: 0.50748: : 500it [00:25, 19.85it/s]


    f1: 0.52407, precision: 0.54178, recall: 0.50748, best f1: 0.52407
    
    Epoch 14/20
    125/125 [==============================] - 67s 537ms/step - loss: 0.6730 - efficient_global_pointer_1_loss: 0.1986 - efficient_global_pointer_2_loss: 0.2597 - efficient_global_pointer_3_loss: 0.2147


    f1: 0.52465, precision: 0.54004, recall: 0.51011: : 500it [00:25, 19.68it/s]


    f1: 0.52465, precision: 0.54004, recall: 0.51011, best f1: 0.52465
    
    Epoch 15/20
    125/125 [==============================] - 67s 539ms/step - loss: 0.6692 - efficient_global_pointer_1_loss: 0.2074 - efficient_global_pointer_2_loss: 0.2523 - efficient_global_pointer_3_loss: 0.2095


    f1: 0.52655, precision: 0.54409, recall: 0.51011: : 500it [00:25, 19.71it/s]


    f1: 0.52655, precision: 0.54409, recall: 0.51011, best f1: 0.52655
    
    Epoch 16/20
     80/125 [==================>...........] - ETA: 23s - loss: 0.6411 - efficient_global_pointer_1_loss: 0.1935 - efficient_global_pointer_2_loss: 0.2297 - efficient_global_pointer_3_loss: 0.2179


```python

```


```python

```