An LSTM-Based Text Generator


Basic Introduction

LSTM is a well-known sequence-prediction model. Language and audio are both time-dependent, so we can predict the next state from the previous one; for this kind of problem it is natural to try a sequence model.

The model is structured as follows:

graph LR
	embedding --> lstm --> linear

Of course, we can't pile on parameters here, so a model built like this is really just a toy.
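
In PyTorch this stack is only a few lines. A minimal sketch for orientation (TinyLM and its hyperparameters are placeholders; the actual LanguageModel further down adds weight tying with pre-trained embeddings and three stacked LSTM layers):

# tiny_lm_sketch.py -- illustrative only
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, embedding_dim, batch_first=True)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):          # x: (batch, seq_len) of token ids
        x = self.embedding(x)      # (batch, seq_len, embedding_dim)
        x, _ = self.lstm(x)        # (batch, seq_len, embedding_dim)
        return self.linear(x)      # per-position logits over the vocabulary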

Basic Symbols

<sos>: start of sentence

<eos>: end of sentence

<unk>: unknown, a placeholder for words that occur very rarely
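
With torchtext, passing these as specials when building the vocabulary places them at the front (indices 0, 1, 2 in the order given), and set_default_index routes every out-of-vocabulary word to <unk>. A quick check, using the same calls as the scripts below:

# vocab_check.py -- sanity check only, same setup as word2vec.py / RNNLM.py
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")
train_iter, _, _ = PennTreebank()
vocab = build_vocab_from_iterator(map(tokenizer, train_iter),
                                  specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])

print(vocab(['<unk>', '<sos>', '<eos>']))  # [0, 1, 2]
print(vocab['notarealword'])               # 0 -> falls back to <unk>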

What to Do Before Training

The language model needs word vectors trained in advance; the following script trains them on Penn Treebank and saves the embedding matrix for later reuse.

# word2vec.py
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = get_tokenizer("basic_english")
train_iter, _, _ = PennTreebank()
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])
EMBEDDING_DIM = 200
VOCAB_SIZE = len(vocab)
EPOCHS = 3
BATCH_SIZE = 16
INTERVAL = 400


class PTBDataset(Dataset):
    def __init__(self, data, vocab, window_size=1):
        self.data = list(data)
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.window_size = window_size
        self.pairs = self.create_pairs()

    def create_pairs(self):
        pairs = []
        for sentence in self.data:
            words = tokenizer(sentence)
            indices = [self.vocab[word] for word in words]
            for center_word_pos in range(len(indices)):
                center_word_idx = indices[center_word_pos]
                context_word_idx = list()
                for w in range(-self.window_size, self.window_size + 1):
                    if w == 0:
                        continue
                    context_word_pos = center_word_pos + w
                    if context_word_pos >= len(indices) or context_word_pos < 0:
                        continue
                    context_word_idx.append(indices[context_word_pos])
                pairs.append((center_word_idx, context_word_idx))
        return pairs

    def to_tensor(self, i):
        center, context = self.pairs[i]
        center_idx = torch.tensor(center, device=device)
        # Multi-hot target over the vocabulary: 1 for each word in the context
        # window (CrossEntropyLoss accepts such probability-style targets in
        # recent PyTorch versions)
        context_tensor = torch.zeros(self.vocab_size, device=device)
        context_tensor[context] = 1
        return center_idx, context_tensor

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.to_tensor(idx)


class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        # Tie the output projection to the embedding matrix (weight sharing)
        self.linear.weight = torch.nn.Parameter(self.embeddings.weight)

    def forward(self, center):
        c0 = self.embeddings(center)
        c0 = self.linear(c0)
        return c0


dataset = PTBDataset(train_iter, vocab, window_size=1)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = Word2Vec(VOCAB_SIZE, EMBEDDING_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

loss_list = list()
start_time = time.time()
print('start training...')
for epoch in range(1, EPOCHS + 1):
    total_loss = 0
    iter_loss = 0
    cnt = 0
    for center, context in dataloader:
        optimizer.zero_grad()
        output = model(center)
        loss = criterion(output, context)
        loss.backward()
        optimizer.step()
        iter_loss += loss.item()
        cnt += 1
        if cnt % INTERVAL == 0:
            avg_loss = iter_loss / INTERVAL
            print(f"-epoch: {epoch}\t -loss: {avg_loss:.5f} -"
                  f"time:{time.time() - start_time:.2f}s")
            loss_list.append(avg_loss)
            iter_loss = 0

torch.save(model.embeddings.weight.data, r"word_embeddings.pth")

fig = plt.figure()
ax = plt.plot(loss_list)
plt.show()
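
After training, the saved matrix can be sanity-checked by looking at the cosine-nearest neighbours of a word. A small sketch, assuming the vocab object from the script above is still available (the query word is an arbitrary example):

# embedding_check.py -- optional sanity check, not part of the training script
import torch
import torch.nn.functional as F

emb = torch.load("word_embeddings.pth", map_location="cpu")  # (VOCAB_SIZE, EMBEDDING_DIM)
emb = F.normalize(emb, dim=1)                                # unit-length rows

query = "money"
sims = emb @ emb[vocab[query]]                 # cosine similarity to every word
top = torch.topk(sims, k=6).indices.tolist()   # the word itself + 5 neighbours
print([vocab.get_itos()[i] for i in top])
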
Training the Text-Generation Model
# RNNLM.py
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.utils.data.dataset import Dataset
from torchtext.datasets import PennTreebank

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, val_iter, test_iter = PennTreebank()
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>', '<sos>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])
w2id = vocab.get_stoi()
id2w = vocab.get_itos()
EMBEDDING_DIM = 200
DIFF_LENGTH = 35
VOCAB_SIZE = len(w2id)
BATCH_SIZE = 32
EPOCHS = 30


class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, embedding_weight=None):
        super(LanguageModel, self).__init__()
        # Use the pre-trained word2vec matrix if provided, otherwise random init
        if embedding_weight is None:
            share_weight = torch.randn(vocab_size, embedding_dim, requires_grad=False)
        else:
            share_weight = embedding_weight
            share_weight.requires_grad = False
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, embedding_dim, batch_first=True, num_layers=3)
        self.linear = nn.Linear(embedding_dim, vocab_size)

        # Tie embedding and output-projection weights. Note that nn.Parameter
        # defaults to requires_grad=True, so the shared matrix remains
        # trainable despite the requires_grad = False set above.
        self.embedding.weight = nn.Parameter(share_weight)
        self.linear.weight = nn.Parameter(share_weight)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x


class TextDataset(Dataset):
    def __init__(self, text, embedding_dim, w2id, id2w, min_len=10):
        super(TextDataset, self).__init__()
        self.lines = text
        self.embedding_dim = embedding_dim
        self.w2id = w2id
        self.id2w = id2w
        self.avg_len = DIFF_LENGTH
        self.min_len = min_len
        self.corpus = self.get_corpus()
        self.len = len(self.corpus) // self.avg_len

    def get_corpus(self):
        corpus = []
        for line in self.lines:
            # Add explicit spaces so the markers tokenize as '<sos>' / '<eos>'
            tokens = tokenizer('<sos> ' + line + ' <eos>')
            if len(tokens) < self.min_len:
                continue
            # Unknown words fall back to index 0, i.e. '<unk>'
            indices = [w2id.get(token, 0) for token in tokens]
            corpus += indices
        return corpus

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        idx = idx * self.avg_len
        prv = self.corpus[idx: idx + self.avg_len - 1]
        nxt = self.corpus[idx + 1: idx + self.avg_len]
        prv_t = torch.tensor(prv, device=device)
        nxt_t = torch.tensor(nxt, device=device)
        return prv_t, nxt_t


if __name__ == '__main__':
    dataset = TextDataset(train_iter, EMBEDDING_DIM, w2id, id2w)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    embedding_weight = torch.load('word_embeddings.pth')
    model = LanguageModel(VOCAB_SIZE, EMBEDDING_DIM, embedding_weight).to(device)
    optimizer = optim.NAdam(model.parameters(), lr=1e-3)  # avoid shadowing the optim module
    criterion = nn.CrossEntropyLoss()

    start = time.time()
    loss_list = []
    for epoch in range(1, EPOCHS + 1):
        cnt = 0
        sum_loss = 0
        for prv, nxt in dataloader:
            optimizer.zero_grad()
            pred = model(prv)
            pred = pred.view(-1, VOCAB_SIZE)
            nxt = nxt.view(-1)
            loss = criterion(pred, nxt)
            loss.backward()
            optimizer.step()
            cnt += 1
            sum_loss += loss.item()
            if cnt % 50 == 0:
                avg_loss = sum_loss / 50
                sum_loss = 0
                print(f'epoch{epoch:3d} iter{cnt:5d}/{len(dataloader)} loss:{avg_loss:.6f}'
                      f'\ttime {time.time() - start:6.2f}s')
                loss_list.append(avg_loss)
    plt.figure()
    plt.plot(loss_list)
    plt.savefig('loss.png')

    torch.save(model.state_dict(), './model.pth')
Loading the Model
import random
from RNNLM import *

model = LanguageModel(VOCAB_SIZE, EMBEDDING_DIM)
# map_location keeps this loadable on a CPU-only machine
model.load_state_dict(torch.load("./model.pth", map_location=device))


def generate_text(model, start_text, tokenizer, w2id, id2w, max_len=20, k=3):
    print(start_text, end=' ')
    start_text = '<sos> ' + start_text
    tokens = tokenizer(start_text)
    indices = [w2id.get(token, 0) for token in tokens]  # unknown words -> <unk>
    model.eval()

    while len(indices) < max_len:
        input_tensor = torch.tensor(indices)
        with torch.no_grad():
            output = model(input_tensor)
        # Top-k over the last position's distribution; indices 0-2 are the
        # specials (<unk>, <sos>, <eos>), which we try to avoid emitting
        data, idx = torch.topk(output, k=k, dim=1)
        idx_list = idx[-1].view(-1).tolist()
        next_word_idx = 0

        cnt = 0
        while next_word_idx < 3:
            next_word_idx = random.choice(idx_list)
            cnt += 1
            if next_word_idx < 3 and cnt >= k:
                break
        # Stop generating if the top-k candidates are all special tokens
        if next_word_idx < 3 and cnt >= k:
            break

        indices.append(next_word_idx)
        next_word = id2w[next_word_idx]
        print(next_word,  end=' ')
    print('\n', indices, '\t len:', len(indices))


if __name__ == '__main__':
    start_text = 'you'
    generate_text(model, start_text, tokenizer, w2id, id2w, k=3, max_len=20)

The generated text is pretty dumb, but at least it shows some sense of grammar.

(screenshot of a generated sample)

Of course, with a model this small, we can't expect much.
