Basic introduction
LSTM is a well-known model for sequence prediction, and language and audio are both time-dependent. We can therefore predict the next state of a sequence from its previous state, so for this kind of problem it is worth trying a sequence model.
The model is put together as follows:
graph LR
embedding --> lstm --> linear
Of course, we can't really pile on parameters here, so a model assembled like this is just for fun.
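To make the data flow concrete, here is a minimal shape check of that embedding → lstm → linear pipeline. The sizes are made up for illustration and are not the ones used below.

# shape_check.py -- sketch of the pipeline, with tensor shapes at each stage
import torch
import torch.nn as nn

vocab_size, embedding_dim, batch, seq_len = 1000, 200, 4, 35  # illustrative sizes

embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, embedding_dim, batch_first=True)
linear = nn.Linear(embedding_dim, vocab_size)

tokens = torch.randint(0, vocab_size, (batch, seq_len))  # a batch of token indices
x = embedding(tokens)   # (batch, seq_len, embedding_dim)
x, _ = lstm(x)          # (batch, seq_len, embedding_dim)
logits = linear(x)      # (batch, seq_len, vocab_size): a score for every next word
print(logits.shape)     # torch.Size([4, 35, 1000])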
Basic symbols
<sos>: start of sentence
<eos>: end of sentence
<unk>: unknown, used for words that appear very rarely
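These specials are just ordinary vocabulary entries: because they are listed first in build_vocab_from_iterator they get the lowest indices, and set_default_index makes <unk> the fallback for any unseen word. A tiny illustration (the toy corpus is made up):

# specials_demo.py -- how the special tokens behave in the vocab
from torchtext.vocab import build_vocab_from_iterator

toy_corpus = [["the", "cat", "sat"], ["the", "dog", "ran"]]  # made-up data
vocab = build_vocab_from_iterator(toy_corpus, specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])

print(vocab['<unk>'], vocab['<sos>'], vocab['<eos>'])  # 0 1 2: specials come first
print(vocab['the'])    # regular words get indices >= 3
print(vocab['zebra'])  # unseen word falls back to the <unk> index, 0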
What to do before training
The language model will load pretrained word embeddings, so we first train them with the following word2vec-style script, which saves the weights to word_embeddings.pth:
# word2vec.py
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = get_tokenizer("basic_english")
train_iter, _, _ = PennTreebank()
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])
EMBEDDING_DIM = 200
VOCAB_SIZE = len(vocab)
EPOCHS = 3
BATCH_SIZE = 16
INTERVAL = 400
class PTBDataset(Dataset):
    def __init__(self, data, vocab, window_size=1):
        self.data = list(data)
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.window_size = window_size
        self.pairs = self.create_pairs()

    def create_pairs(self):
        # build skip-gram style (center word, context words) pairs
        pairs = []
        for sentence in self.data:
            words = tokenizer(sentence)
            indices = [self.vocab[word] for word in words]
            for center_word_pos in range(len(indices)):
                center_word_idx = indices[center_word_pos]
                context_word_idx = list()
                for w in range(-self.window_size, self.window_size + 1):
                    if w == 0:
                        continue
                    context_word_pos = center_word_pos + w
                    if context_word_pos >= len(indices) or context_word_pos < 0:
                        continue
                    context_word_idx.append(indices[context_word_pos])
                pairs.append((center_word_idx, context_word_idx))
        return pairs

    def to_tensor(self, i):
        center, context = self.pairs[i]
        center_idx = torch.tensor(center, device=device)
        # multi-hot vector over the vocabulary marking the context words
        context_tensor = torch.zeros(self.vocab_size, device=device)
        context_tensor[context] = 1
        return center_idx, context_tensor

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.to_tensor(idx)
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        # share the embedding matrix with the output projection (weight tying)
        self.linear.weight = torch.nn.Parameter(self.embeddings.weight)

    def forward(self, center):
        c0 = self.embeddings(center)
        c0 = self.linear(c0)
        return c0
dataset = PTBDataset(train_iter, vocab, window_size=1)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = Word2Vec(VOCAB_SIZE, EMBEDDING_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
loss_list = list()
start_time = time.time()
print('start training...')
for epoch in range(1, EPOCHS + 1):
    iter_loss = 0
    cnt = 0
    for center, context in dataloader:
        optimizer.zero_grad()
        output = model(center)
        # the multi-hot context vector serves as a soft target for CrossEntropyLoss
        loss = criterion(output, context)
        loss.backward()
        optimizer.step()
        iter_loss += loss.item()
        cnt += 1
        if cnt % INTERVAL == 0:
            avg_loss = iter_loss / INTERVAL
            print(f"-epoch: {epoch}\t -loss: {avg_loss:.5f} -"
                  f"time:{time.time() - start_time:.2f}s")
            loss_list.append(avg_loss)
            iter_loss = 0
torch.save(model.embeddings.weight.data, r"word_embeddings.pth")
plt.figure()
plt.plot(loss_list)
plt.show()
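Once the embeddings are saved, a quick sanity check is worthwhile before moving on. A minimal sketch, assuming the vocab object from word2vec.py is still in scope (e.g. these lines are appended to that script) and the word_embeddings.pth file saved above:

# sanity check: nearest neighbours of a word by cosine similarity (sketch)
import torch.nn.functional as F

embeddings = torch.load("word_embeddings.pth")  # (VOCAB_SIZE, EMBEDDING_DIM)
normed = F.normalize(embeddings, dim=1)

def nearest(word, topn=5):
    # cosine similarity of one word's vector against the whole vocabulary
    sims = normed @ normed[vocab[word]]
    values, ids = torch.topk(sims, topn + 1)  # +1 because the word itself ranks first
    return [vocab.get_itos()[i] for i in ids.tolist()[1:]]

print(nearest("market"))  # a frequent PTB word; related neighbours suggest training worked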
Training the text-generation model
# RNNLM.py
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.utils.data.dataset import Dataset
from torchtext.datasets import PennTreebank
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, val_iter, test_iter = PennTreebank()
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])
w2id = vocab.get_stoi()
id2w = vocab.get_itos()
EMBEDDING_DIM = 200
DIFF_LENGTH = 35
VOCAB_SIZE = len(w2id)
BATCH_SIZE = 32
EPOCHS = 30
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, embedding_weight=None):
        super(LanguageModel, self).__init__()
        # fall back to a random matrix if no pretrained embeddings are given
        if embedding_weight is None:
            share_weight = torch.randn(vocab_size, embedding_dim, requires_grad=False)
        else:
            share_weight = embedding_weight
            share_weight.requires_grad = False
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, embedding_dim, batch_first=True, num_layers=3)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        # the same matrix is shared by the input embedding and the output projection
        self.embedding.weight = nn.Parameter(share_weight)
        self.linear.weight = nn.Parameter(share_weight)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x
class TextDataset(Dataset):
    def __init__(self, text, embedding_dim, w2id, id2w, min_len=10):
        super(TextDataset, self).__init__()
        self.lines = text
        self.embedding_dim = embedding_dim
        self.w2id = w2id
        self.id2w = id2w
        self.avg_len = DIFF_LENGTH
        self.min_len = min_len
        self.corpus = self.get_corpus()
        self.len = len(self.corpus) // self.avg_len

    def get_corpus(self):
        # flatten the corpus into one long index sequence, wrapping each line in <sos> ... <eos>
        corpus = []
        for line in self.lines:
            tokens = tokenizer('<sos> ' + line + ' <eos>')  # spaces keep the specials as separate tokens
            if len(tokens) < self.min_len:
                continue
            indices = [w2id.get(token, 0) for token in tokens]  # 0 is the <unk> index
            corpus += indices
        return corpus

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # each sample is an avg_len-sized slice; the target is the input shifted by one token
        idx = idx * self.avg_len
        prv = self.corpus[idx: idx + self.avg_len - 1]
        nxt = self.corpus[idx + 1: idx + self.avg_len]
        prv_t = torch.tensor(prv, device=device)
        nxt_t = torch.tensor(nxt, device=device)
        return prv_t, nxt_t
if __name__ == '__main__':
    dataset = TextDataset(train_iter, EMBEDDING_DIM, w2id, id2w)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    embedding_weight = torch.load('word_embeddings.pth')
    model = LanguageModel(VOCAB_SIZE, EMBEDDING_DIM, embedding_weight).to(device)
    optimizer = optim.NAdam(model.parameters(), lr=1e-3)  # renamed so it does not shadow torch.optim
    criterion = nn.CrossEntropyLoss()
    start = time.time()
    loss_list = []
    for epoch in range(1, EPOCHS + 1):
        cnt = 0
        sum_loss = 0
        for prv, nxt in dataloader:
            optimizer.zero_grad()
            pred = model(prv)
            # flatten (batch, seq_len, vocab) and (batch, seq_len) for CrossEntropyLoss
            pred = pred.view(-1, VOCAB_SIZE)
            nxt = nxt.view(-1)
            loss = criterion(pred, nxt)
            loss.backward()
            optimizer.step()
            cnt += 1
            sum_loss += loss.item()
            if cnt % 50 == 0:
                avg_loss = sum_loss / 50
                sum_loss = 0
                print(f'epoch{epoch:3d} iter{cnt:5d}/{len(dataloader)} loss:{avg_loss:.6f}'
                      f'\ttime {time.time() - start:6.2f}s')
                loss_list.append(avg_loss)
    plt.figure()
    plt.plot(loss_list)
    plt.savefig('loss.png')
    torch.save(model.state_dict(), './model.pth')
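val_iter and test_iter are loaded above but never used; a natural follow-up is to measure perplexity on held-out data. A minimal sketch, meant to sit at the end of the __main__ block (it reuses TextDataset, model and criterion from above; the number you get depends on how long you train):

# evaluate perplexity on the validation split (sketch, unindented for readability)
import math

val_dataset = TextDataset(val_iter, EMBEDDING_DIM, w2id, id2w)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

model.eval()
total_loss, total_batches = 0.0, 0
with torch.no_grad():
    for prv, nxt in val_loader:
        pred = model(prv).view(-1, VOCAB_SIZE)
        total_loss += criterion(pred, nxt.view(-1)).item()
        total_batches += 1
# perplexity is exp of the average per-token cross-entropy
print(f'validation perplexity: {math.exp(total_loss / total_batches):.2f}')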
Loading the model
import random
from RNNLM import *
model = LanguageModel(VOCAB_SIZE, EMBEDDING_DIM).to(device)
model.load_state_dict(torch.load("./model.pth", map_location=device))
def generate_text(model, start_text, tokenizer, w2id, id2w, max_len=20, k=3):
    print(start_text, end=' ')
    start_text = '<sos> ' + start_text
    tokens = tokenizer(start_text)
    indices = [w2id.get(token, 0) for token in tokens]  # fall back to <unk> for unseen words
    model.eval()
    while len(indices) < max_len:
        input_tensor = torch.tensor(indices, device=device)
        with torch.no_grad():
            output = model(input_tensor)
        # pick the next word at random from the top-k candidates at the last position
        data, idx = torch.topk(output, k=k, dim=1)
        idx_list = idx[-1].view(-1).tolist()
        next_word_idx = 0
        cnt = 0
        # indices 0-2 are the specials (<unk>, <sos>, <eos>); retry up to k times to avoid them
        while next_word_idx < 3:
            next_word_idx = random.choice(idx_list)
            cnt += 1
            if next_word_idx < 3 and cnt >= k:
                break
        if next_word_idx < 3 and cnt >= k:
            break
        indices.append(next_word_idx)
        next_word = id2w[next_word_idx]
        print(next_word, end=' ')
    print('\n', indices, '\t len:', len(indices))


if __name__ == '__main__':
    start_text = 'you'
    generate_text(model, start_text, tokenizer, w2id, id2w, k=3, max_len=20)
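The top-k retry loop above is one way to keep the special tokens out of the output; another common option is temperature sampling over the softmax distribution. A minimal sketch of that alternative, using the same model and vocab (the temperature value is just a starting point to play with):

# temperature sampling: an alternative to the top-k retry loop (sketch)
def sample_next(model, indices, temperature=0.8):
    with torch.no_grad():
        logits = model(torch.tensor(indices, device=device))[-1]  # scores at the last position
    logits[:3] = float('-inf')                 # mask the specials <unk>, <sos>, <eos>
    probs = torch.softmax(logits / temperature, dim=0)
    return torch.multinomial(probs, 1).item()  # draw one token id from the distribution

# usage: extend a prompt a few tokens at a time
indices = [w2id.get(t, 0) for t in tokenizer('<sos> you')]
for _ in range(15):
    indices.append(sample_next(model, indices))
print(' '.join(id2w[i] for i in indices[1:]))  # drop the leading <sos> when printing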
The generated text is pretty silly, but it does show at least some notion of grammar.
Then again, with a model this small, we won't ask for much.