Preface
Previous posts in this series covered data collection, data processing, the model code and part of the training code. This post walks through the training script and actually trains the model.
The training code is as follows:
import tensorflow as tf
import numpy as np
import os, argparse, time, random
from model import BiLSTM_CRF
from utils import str2bool, get_logger
from data_helper import read_dictionary, random_embedding, read_files, read_tag_id
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
parser = argparse.ArgumentParser(description='BiLSTM-CRF for Chinese NER task')  # argument description
parser.add_argument('--data_path', type=str,
                    default=r'E:\project\Vlun_NER_LSTM\NerData\train.txt',
                    help='train data source')
parser.add_argument('--word2id', type=str, default=r'E:\project\Vlun_NER_LSTM\NerData\word_dic.pkl',
                    help='word2id source')
parser.add_argument('--tag2id', type=str, default=r'E:\project\Vlun_NER_LSTM\NerData\label_dic.pkl',
                    help='tag2id source')
parser.add_argument('--save_path', type=str,
                    default=r'E:\project\Vlun_NER_LSTM\data_path_save',
                    help='model save path')
parser.add_argument('--batch_size', type=int, default=32, help='#sample of each minibatch')
parser.add_argument('--epoch', type=int, default=30, help='#epoch of training')
parser.add_argument('--hidden_dim', type=int, default=300, help='#dim of hidden state')
parser.add_argument('--optimizer', type=str, default='Adam',
                    help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
parser.add_argument('--CRF', type=str2bool, default=True,
                    help='use CRF at the top layer. if False, use Softmax')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout keep_prob')
parser.add_argument('--update_embedding', type=str2bool, default=True,
                    help='update embedding during training')
parser.add_argument('--pretrain_embedding', type=str, default='random',
                    help='use pretrained char embedding or init it randomly')
parser.add_argument('--embedding_dim', type=int, default=300,
                    help='random init char embedding_dim')
parser.add_argument('--shuffle', type=str2bool, default=True,
                    help='shuffle training data before each epoch')
parser.add_argument('--mode', type=str, default='test', help='train/test/demo')
parser.add_argument('--demo_model', type=str, default='1662621164',
                    help='model for test and demo')
args = parser.parse_args()
word2id = read_dictionary(args.word2id)
tag2label = read_tag_id(args.tag2id)
if args.pretrain_embedding == 'random':
    embeddings = random_embedding(word2id, args.embedding_dim)
else:
    embedding_path = 'pretrain_embedding.npy'
    embeddings = np.array(np.load(embedding_path), dtype='float32')
# -----------------read data--------------------
lines, label, seq_length = read_files(args.data_path)
assert len(lines) == len(label)
index = int(len(lines) * 0.9)
train_data, dev_data = lines[:index], lines[index:]
train_label, dev_label = label[:index], label[index:]
paths = {}
timestamp = str(int(time.time())) if args.mode == 'train' else args.demo_model
output_path = os.path.join(args.save_path, timestamp)
if not os.path.exists(output_path): os.makedirs(output_path)
summary_path = os.path.join(output_path, "summaries")
paths['summary_path'] = summary_path
if not os.path.exists(summary_path): os.makedirs(summary_path)
model_path = os.path.join(output_path, "checkpoints/")
if not os.path.exists(model_path): os.makedirs(model_path)
ckpt_prefix = os.path.join(model_path, "model")
paths['model_path'] = ckpt_prefix
result_path = os.path.join(output_path, "results")
paths['result_path'] = result_path
if not os.path.exists(result_path): os.makedirs(result_path)
log_path = os.path.join(result_path, "log.txt")
paths['log_path'] = log_path
get_logger(log_path).info(str(args))
if args.mode == 'train':
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("train data: {}".format(len(train_data)))
    model.train(train_data, dev_data, train_label, dev_label)
## testing model
elif args.mode == 'test':
    test_data = read_files(r"E:\project\Vlun_NER_LSTM\NerData\test.txt")
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_data))
    model.test(test_data)
The data processing code is as follows:
import pickle
import numpy as np
import codecs
import random
def pickle_writer(inputs, name):
    # Serialize an object to disk with pickle
    output = open(name, 'wb')
    pickle.dump(inputs, output, protocol=2)
    output.close()
    print("Finish save {}".format(name))

def pickle_reader(inputs):
    # Load a pickled object from disk
    f = open(inputs, 'rb')
    lines = pickle.load(f)
    f.close()
    print("Finish load {}".format(inputs))
    return lines

def read_dictionary(path):
    return pickle_reader(path)

def read_tag_id(path):
    return pickle_reader(path)

def random_embedding(vocab, embedding_dim):
    # Randomly initialize an embedding matrix of shape (vocab size, embedding_dim)
    embedding_mat = np.random.uniform(-0.25, 0.25, (len(vocab), embedding_dim))
    embedding_mat = np.float32(embedding_mat)
    return embedding_mat

def read_files(data_path):
    # Read a file of "char<TAB>label" lines; blank lines separate sentences
    temp = []
    lines = []
    label = []
    labels = []
    for line in codecs.open(data_path, 'r', 'UTF-8'):
        if len(line.strip().split('\t')) < 2:
            # blank (or malformed) line: close the current sentence
            if temp and label:
                lines.append(temp)
                labels.append(label)
            temp = []
            label = []
        else:
            temp.append(line.strip().split('\t')[0])
            label.append(line.strip().split('\t')[1])
    seq_length = max([len(line) for line in lines])
    return lines, labels, seq_length

def sentence2id(sent, word2id):
    # Map each character to its index, falling back to '[UNK]' for unknown characters
    sentence_id = []
    for word in sent:
        if word not in word2id:
            word = '[UNK]'
        sentence_id.append(word2id[word])
    return sentence_id

def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
    # Yield (sentence ids, label ids) batches of size batch_size
    temp = []
    for i in range(len(data[0])):
        temp.append([data[0][i], data[1][i]])
    temp = np.array(temp)
    if shuffle:
        random.shuffle(temp)
    seqs, labels = [], []
    for (sent_, tag_) in temp:
        sent_ = sentence2id(sent_, vocab)
        label_ = [tag2label[tag] for tag in tag_]
        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []
        seqs.append(sent_)
        labels.append(label_)
    if len(seqs) != 0:
        yield seqs, labels

def pad_sequences(sequences, pad_mark=0):
    # Pad every sequence in the batch to the length of the longest one
    max_len = max(map(lambda x: len(x), sequences))
    seq_list, seq_len_list = [], []
    for seq in sequences:
        seq = list(seq)
        seq_ = seq[:max_len] + [pad_mark] * max(max_len - len(seq), 0)
        seq_list.append(seq_)
        seq_len_list.append(min(len(seq), max_len))
    return seq_list, seq_len_list
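A quick check of how pad_sequences behaves, with illustrative values:

seqs = [[4, 7, 2], [9, 1]]
padded, lengths = pad_sequences(seqs, pad_mark=0)
print(padded)   # [[4, 7, 2], [9, 1, 0]]
print(lengths)  # [3, 2]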
Understanding the parameters:
config.gpu_options.allow_growth = True
This tells TensorFlow not to pre-allocate all of the GPU memory but to grow its usage as needed, so that when several users share a GPU, one job cannot grab all of the memory and leave the others with none.
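The config object is presumably handed to tf.Session inside BiLSTM_CRF; a minimal TF 1.x sketch of the usual pattern:

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of grabbing it all
with tf.Session(config=config) as sess:
    pass  # build and run the graph here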
- data_path: path to the training data
- word2id: the character-to-index dictionary generated by the earlier scripts, i.e. word_dic.pkl
- tag2id: the label-to-index dictionary file, i.e. label_dic.pkl
- save_path: directory where the model outputs are saved
- batch_size: number of samples fed to the model at each training step
- epoch: number of training epochs; one epoch is a full pass over the training data
- hidden_dim: number of hidden units in the LSTM; more units means more model parameters
- optimizer: which optimizer to use, Adam by default
- CRF: whether to use a CRF layer for sequence prediction (this and the other boolean flags are parsed with str2bool; see the sketch after this list)
- lr: learning rate, typically a small float such as 0.001, used when gradients are backpropagated to update the parameters
- clip: gradient clipping threshold; gradients above this value are clipped to avoid exploding gradients during backpropagation
- dropout: dropout keep probability (keep_prob), used to reduce overfitting
- update_embedding: whether to update the embedding matrix during training
- pretrain_embedding: whether to use pretrained character embeddings or initialize them randomly
- embedding_dim: dimensionality of the embedding matrix
- shuffle: whether to shuffle the training data before each epoch
- mode: train, test or demo
- demo_model: which saved checkpoint directory to use for test/demo
Loading the dictionaries
word2id = read_dictionary(args.word2id)
tag2label = read_tag_id(args.tag2id)
def read_dictionary(path):
    return pickle_reader(path)

def read_tag_id(path):
    return pickle_reader(path)
Both the word2id and tag2id dictionaries are loaded directly from the pickled files produced earlier.
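Both pickle files simply hold Python dicts; the exact keys depend on how the earlier scripts built them, so the entries below are only illustrative:

word2id = read_dictionary(args.word2id)  # e.g. {'[UNK]': 1, '漏': 2, '洞': 3, ...} (illustrative)
tag2label = read_tag_id(args.tag2id)     # e.g. {'O': 0, 'B-VUL': 1, 'I-VUL': 2, ...} (illustrative)
print(len(word2id), len(tag2label))      # vocabulary size and number of labels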
Initializing the embedding matrix
if args.pretrain_embedding == 'random':
    embeddings = random_embedding(word2id, args.embedding_dim)
else:
    embedding_path = 'pretrain_embedding.npy'
    embeddings = np.array(np.load(embedding_path), dtype='float32')
If the embedding matrix is initialized randomly, it is generated as a NumPy array of shape (vocabulary size, embedding_dim).
def random_embedding(vocab, embedding_dim):
    embedding_mat = np.random.uniform(-0.25, 0.25, (len(vocab), embedding_dim))
    embedding_mat = np.float32(embedding_mat)
    return embedding_mat
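A quick sanity check of the resulting matrix:

embeddings = random_embedding(word2id, args.embedding_dim)
print(embeddings.shape)  # (len(word2id), 300)
print(embeddings.dtype)  # float32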
Reading the data
# -----------------read data--------------------
lines, label, seq_length = read_files(args.data_path)
assert len(lines) == len(label)
index = int(len(lines) * 0.9)
train_data, dev_data = lines[:index], lines[index:]
train_label, dev_label = label[:index], label[index:]
Load the training data generated earlier to obtain the sentences, labels and maximum sequence length, then split it 9:1 into a training set (90%) and a dev set (10%) for validation.
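As a reminder of the format read_files expects: one character and its tag per line, separated by a tab, with a blank line between sentences (the tag names below are only illustrative):

# train.txt (illustrative):
#   某   O
#   漏   B-VUL
#   洞   I-VUL
#   <blank line ends the sentence>
lines, label, seq_length = read_files(args.data_path)
print(lines[0])    # ['某', '漏', '洞']
print(label[0])    # ['O', 'B-VUL', 'I-VUL']
print(seq_length)  # length of the longest sentence in the file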
Output paths
paths = {}
timestamp = str(int(time.time())) if args.mode == 'train' else args.demo_model
output_path = os.path.join(args.save_path, timestamp)
if not os.path.exists(output_path): os.makedirs(output_path)
summary_path = os.path.join(output_path, "summaries")
paths['summary_path'] = summary_path
if not os.path.exists(summary_path): os.makedirs(summary_path)
model_path = os.path.join(output_path, "checkpoints/")
if not os.path.exists(model_path): os.makedirs(model_path)
ckpt_prefix = os.path.join(model_path, "model")
paths['model_path'] = ckpt_prefix
result_path = os.path.join(output_path, "results")
paths['result_path'] = result_path
if not os.path.exists(result_path): os.makedirs(result_path)
log_path = os.path.join(result_path, "log.txt")
paths['log_path'] = log_path
get_logger(log_path).info(str(args))
Create the output directories and files: the checkpoint path, the summaries path, the results path and the log file.
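With the defaults above and mode 'train', the layout on disk ends up roughly like this (the timestamp is whatever int(time.time()) returned at launch):

# data_path_save/
# └── <timestamp>/
#     ├── summaries/       # TensorBoard event files     -> paths['summary_path']
#     ├── checkpoints/     # checkpoints, prefix "model" -> paths['model_path']
#     └── results/         #                             -> paths['result_path']
#         └── log.txt      # training log                -> paths['log_path']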
train
if args.mode == 'train':
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("train data: {}".format(len(train_data)))
    model.train(train_data, dev_data, train_label, dev_label)
In training mode, the script first instantiates BiLSTM_CRF, builds the computation graph with build_graph(), and then calls train(), which trains on the samples according to the batch size, number of epochs and the other parameters.
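The BiLSTM_CRF internals were covered in a previous post; purely as a reminder of how train() typically consumes the data helpers above, here is a simplified sketch (not the project's exact code):

for epoch in range(args.epoch):
    batches = batch_yield((train_data, train_label), args.batch_size,
                          word2id, tag2label, shuffle=args.shuffle)
    for seqs, labels in batches:
        word_ids, seq_len_list = pad_sequences(seqs, pad_mark=0)
        labels_, _ = pad_sequences(labels, pad_mark=0)
        # feed word_ids, labels_ and seq_len_list into the graph, run one
        # optimizer step, and periodically evaluate on dev_data / dev_label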
test
## testing model
elif args.mode == 'test':
    test_data = read_files(r"{projectPath}\NerData\test.txt")
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_data))
    model.test(test_data)
If test mode is selected, the script loads test.txt, predicts the label sequence for each sentence, and compares the predictions against the gold labels to compute metrics such as accuracy and loss.
Training process
The training run can be replayed with TensorBoard:
tensorboard --logdir= {projectPath}\data_path_save\1662621164\summaries\
The command above fails with tensorboard: error: invalid choice: '{projectPath}\data_path_save\1662621164\summaries\' (choose from 'serve', 'dev'), because the space after --logdir= leaves the flag empty and TensorBoard then treats the path as a positional subcommand.
The correct usage is:
tensorboard --logdir "{projectPath}\\data_path_save\\1662621164\\summaries"
The correct output is:
Skipping registering GPU devices...
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.10.0 at http://localhost:6006/ (Press CTRL+C to quit)
You can see the loss decreasing steadily during training and eventually converging.
The final model reaches an accuracy of about 0.97 and a recall of about 0.83, which still leaves room for optimization. This post covered the model's parameters and the training process; the next one will cover loading the model and generating labels for test samples. Thanks for reading!