1 Overview
Fairseq can be extended through user-supplied plug-ins. Five kinds of plug-ins are supported (a registration sketch follows the list):
- Models define the neural network architecture and encapsulate all of the learnable parameters.
- Criterions compute the loss function given the model outputs and the targets.
- Tasks store dictionaries and provide helpers for loading/iterating datasets, initializing the model/criterion, and calculating the loss.
- Optimizers update the model parameters based on the gradients.
- Learning Rate Schedulers update the learning rate over the course of training.
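As a quick orientation, the sketch below shows where each plug-in type's @register decorator lives and what a registered plug-in looks like. The import paths and the MyCriterion stub are assumptions based on recent fairseq releases, not part of the original overview; the tutorial in section 2 walks through a complete model plug-in.
```python
# Sketch only: import paths are assumed from recent fairseq releases and may
# differ slightly between versions.
from fairseq.models import register_model                     # Models
from fairseq.criterions import register_criterion             # Criterions
from fairseq.tasks import register_task                       # Tasks
from fairseq.optim import register_optimizer                  # Optimizers
from fairseq.optim.lr_scheduler import register_lr_scheduler  # LR Schedulers

from fairseq.criterions import FairseqCriterion

# A hypothetical criterion plug-in: the decorator makes it selectable by name
# from the command line; the body (elided) would compute the loss.
@register_criterion('my_criterion')
class MyCriterion(FairseqCriterion):
    ...
```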
Training Flow
Given a model, criterion, task, optimizer, and lr_scheduler, fairseq implements the following high-level training flow:
```python
for epoch in range(num_epochs):
    itr = task.get_batch_iterator(task.dataset('train'))
    for num_updates, batch in enumerate(itr):
        task.train_step(batch, model, criterion, optimizer)
        average_and_clip_gradients()
        optimizer.step()
        lr_scheduler.step_update(num_updates)
    lr_scheduler.step(epoch)
```
where the default implementation of task.train_step is roughly:
```python
def train_step(self, batch, model, criterion, optimizer, **unused):
    loss = criterion(model, batch)
    optimizer.backward(loss)
    return loss
```
Registering new plug-ins
New plug-ins are registered through a set of @register function decorators, for example:
```python
@register_model('my_lstm')
class MyLSTM(FairseqEncoderDecoderModel):
    (...)
```
Once registered, new plug-ins can be used with the existing command-line tools. See the next section for a more detailed walkthrough of adding new plug-ins.
Loading plug-ins from another directory
New plug-ins can be defined in a custom module stored elsewhere on the user's system. To import the module and make its plug-ins available to fairseq, the command-line tools support the --user-dir flag, which specifies a custom location for additional modules to be loaded into fairseq.
For example, assume this directory tree:
```
/home/user/my-module/
└── __init__.py
```
where the content of __init__.py is:
```python
from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big

@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
    transformer_vaswani_wmt_en_de_big(args)
```
Then the fairseq-train script can be invoked with the new architecture:
```
fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
```
2 Tutorial: Simple LSTM
In this tutorial we extend fairseq by adding a new FairseqEncoderDecoderModel that encodes the source sentence with an LSTM and then passes the final hidden state to a second LSTM that decodes the target sentence (without attention).
This tutorial covers:
- Writing an Encoder and a Decoder that encode/decode the source/target sentence, respectively.
- Registering a new Model so that it can be used with the existing command-line tools.
- Training the Model using the existing command-line tools.
- Making generation faster by modifying the Decoder to use incremental decoding.
2.1 Building an Encoder and Decoder
In this section we define a simple LSTM Encoder and Decoder. All Encoders should implement the FairseqEncoder interface, and Decoders should implement the FairseqDecoder interface. These interfaces themselves extend torch.nn.Module, so FairseqEncoders and FairseqDecoders can be written and used in the same way as ordinary PyTorch Modules.
2.1.1 Encoder
Our Encoder will embed the tokens in the source sentence, feed them to a torch.nn.LSTM, and return the final hidden state. To create the Encoder, save the following in a new file named fairseq/models/simple_lstm.py:
```python
import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoder

class SimpleLSTMEncoder(FairseqEncoder):

    def __init__(
        self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
    ):
        super().__init__(dictionary)
        self.args = args

        # Our encoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, src_tokens, src_lengths):
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We discuss Tasks in the next tutorial, but for now just
        # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
        # has shape `(batch)`.

        # Note that the source is typically padded on the left. This can be
        # configured by adding the `--left-pad-source "False"` command-line
        # argument, but here we'll make the Encoder handle either kind of
        # padding by converting everything to be right-padded.
        if self.args.left_pad_source:
            # Convert left-padding to right-padding.
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True
            )

        # Embed the source.
        x = self.embed_tokens(src_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Pack the sequence into a PackedSequence object to feed to the LSTM.
        x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

        # Get the output from the LSTM.
        _outputs, (final_hidden, _final_cell) = self.lstm(x)

        # Return the Encoder's output. This can be any object and will be
        # passed directly to the Decoder.
        return {
            # this will have shape `(bsz, hidden_dim)`
            'final_hidden': final_hidden.squeeze(0),
        }

    # Encoders are required to implement this method so that we can rearrange
    # the order of the batch elements during inference (e.g., beam search).
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        final_hidden = encoder_out['final_hidden']
        return {
            'final_hidden': final_hidden.index_select(0, new_order),
        }
```
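Below is a minimal smoke test for the Encoder, to make the expected input and output shapes concrete. It is only a sketch, not part of the original tutorial: the toy Dictionary and the argparse.Namespace standing in for the parsed args are assumptions for illustration; during real training the Task builds both.
```python
import argparse
import torch
from fairseq.data import Dictionary

# Build a toy dictionary (the special pad/eos/unk symbols are added automatically).
dictionary = Dictionary()
for word in 'hello world example'.split():
    dictionary.add_symbol(word)

# Stand-in for the parsed command-line args the Encoder reads left_pad_source from.
args = argparse.Namespace(left_pad_source=False)
encoder = SimpleLSTMEncoder(args, dictionary, embed_dim=32, hidden_dim=32)

# Two right-padded source sentences of lengths 3 and 2 (descending order, as
# required by pack_padded_sequence).
src_tokens = torch.LongTensor([
    [dictionary.index('hello'), dictionary.index('world'), dictionary.eos()],
    [dictionary.index('example'), dictionary.eos(), dictionary.pad()],
])
src_lengths = torch.LongTensor([3, 2])

encoder_out = encoder(src_tokens, src_lengths)
print(encoder_out['final_hidden'].shape)  # (bsz, hidden_dim) == (2, 32)
```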
2.1.2 Decoder
Our Decoder will predict the next word conditioned on the Encoder's final hidden state and an embedded representation of the previous target word, which is sometimes called teacher forcing. More specifically, we use a torch.nn.LSTM to produce a sequence of hidden states that we project to the size of the output vocabulary to predict each target word. Append the following to the same fairseq/models/simple_lstm.py file (it reuses the nn import from the Encoder above):
```python
import torch
from fairseq.models import FairseqDecoder

class SimpleLSTMDecoder(FairseqDecoder):

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        super().__init__(dictionary)

        # Our decoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            # For the first layer we'll concatenate the Encoder's final hidden
            # state with the embedded target tokens.
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

        # Define the output projection.
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # During training Decoders are expected to take the entire target sequence
    # (shifted right by one position) and produce logits over the vocabulary.
    # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
    # ``dictionary.eos()``, followed by the target sequence.
    def forward(self, prev_output_tokens, encoder_out):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the last decoder layer's output of shape
                  `(batch, tgt_len, vocab)`
                - the last decoder layer's attention weights of shape
                  `(batch, tgt_len, src_len)`
        """
        bsz, tgt_len = prev_output_tokens.size()

        # Extract the final hidden state from the Encoder.
        final_encoder_hidden = encoder_out['final_hidden']

        # Embed the target sequence, which has been shifted right by one
        # position and now starts with the end-of-sentence symbol.
        x = self.embed_tokens(prev_output_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Concatenate the Encoder's final hidden state to *every* embedded
        # target token.
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # Using PackedSequence objects in the Decoder is harder than in the
        # Encoder, since the targets are not sorted in descending length order,
        # which is a requirement of ``pack_padded_sequence()``. Instead we'll
        # feed nn.LSTM directly.
        initial_state = (
            final_encoder_hidden.unsqueeze(0),  # hidden
            torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
        )
        output, _ = self.lstm(
            x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
            initial_state,
        )
        x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

        # Project the outputs to the size of the vocabulary.
        x = self.output_projection(x)

        # Return the logits and ``None`` for the attention weights
        return x, None
```
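And a matching smoke test for the Decoder, again only a sketch: instead of running the real Encoder we fake its output with a random final_hidden tensor, and the toy Dictionary is the same illustrative assumption as in the Encoder sketch.
```python
import torch
from fairseq.data import Dictionary

dictionary = Dictionary()
for word in 'hello world example'.split():
    dictionary.add_symbol(word)

decoder = SimpleLSTMDecoder(dictionary, encoder_hidden_dim=32, embed_dim=32, hidden_dim=32)

# Fake Encoder output: one `final_hidden` vector per sentence in the batch.
encoder_out = {'final_hidden': torch.randn(2, 32)}

# Teacher forcing: prev_output_tokens starts with EOS, followed by the target.
prev_output_tokens = torch.LongTensor([
    [dictionary.eos(), dictionary.index('hello'), dictionary.index('world')],
    [dictionary.eos(), dictionary.index('example'), dictionary.pad()],
])

logits, _ = decoder(prev_output_tokens, encoder_out)
print(logits.shape)  # (bsz, tgt_len, vocab) == (2, 3, len(dictionary))
```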