作为一种脱敏后的数据集,我们一般的技术实现方案是:首先训练一个预训练语言模型,然后再训练一个文本生成模型。目前训练了一天的结果:624576166@qq.com 的团队,得分 0.03219911,提交时间 23-03-30 09:46,排名第 242 名。
啊算了,不发baseline了。
预训练部分修改
def mlm_encode(text):
    """Build one MLM (source, target) pair from a space-separated id string.

    Standard BERT masking: each token is selected with probability 0.15;
    of the selected tokens, 80% are replaced by [MASK], 10% are kept
    unchanged, and 10% are swapped for a random non-zero vocab id.
    Unselected positions get label 0, meaning "no prediction".
    Both sequences are capped at `maxlen` including the final end token.
    """
    token_strs = text.split(" ")
    noise = np.random.random(len(token_strs))
    source = [tokenizer._token_start_id]
    target = [0]
    for prob, tok in zip(noise, token_strs):
        tok_ids = [int(tok)]
        if prob < 0.15 * 0.8:
            # masked position: the model must recover the original id
            source.extend([tokenizer._token_mask_id] * len(tok_ids))
            target.extend(tok_ids)
        elif prob < 0.15 * 0.9:
            # selected but left unchanged; still predicted
            source.extend(tok_ids)
            target.extend(tok_ids)
        elif prob < 0.15:
            # selected and replaced by a random non-zero vocab id
            source.extend(
                np.random.choice(tokenizer._vocab_size - 1,
                                 size=len(tok_ids)) + 1
            )
            target.extend(tok_ids)
        else:
            # not selected: label 0 is ignored by the loss
            source.extend(tok_ids)
            target.extend([0] * len(tok_ids))
    source = source[:maxlen - 1] + [tokenizer._token_end_id]
    target = target[:maxlen - 1] + [0]
    return source, target
文本生成数据读取部分的修改
class data_generator(DataGenerator):
    """Data generator (yields one sample at a time).

    The corpus is de-identified: every token is already an integer id and
    tokens are separated by spaces, so inputs are assembled by hand
    instead of via `tokenizer.encode`. 101 and 102 serve as the usual
    BERT [CLS] and [SEP] ids.
    """
    def __iter__(self, random=False):
        for is_end, (title, content) in self.sample(random):
            # Segment 0: [CLS] + content ids + [SEP]
            token_ids = [101] + [int(t) for t in content.split(" ")] + [102]
            segment_ids = [0] * len(token_ids)
            # Segment 1: title ids + [SEP]
            title_ids = [int(t) for t in title.split(" ")] + [102]
            token_ids = token_ids + title_ids
            segment_ids = segment_ids + [1] * len(title_ids)
            yield token_ids, segment_ids
文本生成修改
class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder for the de-identified corpus.

    Inputs are space-separated integer-id strings, so encoding is done by
    hand; 101 and 102 play the roles of [CLS] and [SEP].
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # Append the ids decoded so far (segment 1) after the source
        # (segment 0) and ask the model for the next-token distribution.
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(model).predict([token_ids, segment_ids])

    def generate(self, text, topk=1):
        """Beam-search a title for *text*; returns a list of token ids.

        text: space-separated integer token ids.
        topk: beam width.
        """
        # Reserve self.maxlen positions for the decoded title.
        max_c_len = maxlen - self.maxlen
        token_ids = [101] + [int(t) for t in text.split(" ")] + [102]
        # Fix: the original computed max_c_len but never used it, so an
        # over-long source overflowed the model's input length. Truncate
        # while keeping the trailing [SEP], mirroring the intent of the
        # replaced tokenizer.encode(text, maxlen=max_c_len) call.
        if len(token_ids) > max_c_len:
            token_ids = token_ids[:max_c_len - 1] + [102]
        segment_ids = [0] * len(token_ids)
        output_ids = self.beam_search([token_ids, segment_ids],
                                      topk=topk)  # beam search
        return output_ids
# Decoder instance: start_id=None (the encoded source acts as the prompt),
# end_id=102 ([SEP]) terminates decoding, generated titles are capped at
# 128 tokens.
autotitle = AutoTitle(start_id=None, end_id=102, maxlen=128)
微信 15246115202,欢迎来撩(经主办方允许)。