
Lucidrains Series Projects: Source Code Walkthrough (63)

Multistream Transformers

Implementation of Multistream Transformers in Pytorch.

This repository deviates slightly from the paper: instead of using a skip connection across all streams, it uses attention pooling across all tokens at the same position. This produced the best results in my experiments when the number of streams is greater than 2.
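A rough sketch of what such attention pooling across streams could look like (an illustration only, not the repository's exact module; the `StreamAttentionPool` name and the `(batch, streams, seq_len, dim)` layout are assumptions):

import torch
from torch import nn, einsum

# pool the representations of the same token position across all streams
class StreamAttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))   # a single learned query
        self.to_k = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        # x: (batch, streams, seq_len, dim)
        k = self.to_k(x)
        sim = einsum('b s n d, d -> b s n', k, self.query) * (x.shape[-1] ** -0.5)
        attn = sim.softmax(dim = 1)                    # softmax over the stream dimension
        return einsum('b s n, b s n d -> b n d', attn, x)

pooled = StreamAttentionPool(512)(torch.randn(2, 4, 1024, 512))  # (2, 1024, 512)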

Install

$ pip install multistream-transformers

Usage

import torch
from multistream_transformers import MultistreamTransformer

model = MultistreamTransformer(
    num_tokens = 256,         # number of tokens
    dim = 512,                # dimension
    depth = 4,                # depth
    causal = True,            # autoregressive or not
    max_seq_len = 1024,       # maximum sequence length
    num_streams = 2           # number of streams - 1 would make it a regular transformer
)

x = torch.randint(0, 256, (2, 1024))
mask = torch.ones((2, 1024)).bool()

logits = model(x, mask = mask) # (2, 1024, 256)
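The repository also provides an `AutoregressiveWrapper` (used by `train.py` further below); a minimal sketch of wrapping the model for language-model training and sampling, assuming the wrapper's forward returns the cross entropy loss and `generate` continues a prime sequence:

import torch
from multistream_transformers import MultistreamTransformer
from multistream_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(MultistreamTransformer(
    num_tokens = 256, dim = 512, depth = 4, causal = True,
    max_seq_len = 1024, num_streams = 2
))

seq = torch.randint(0, 256, (1, 1024))
loss = model(seq)                            # next-token prediction loss
loss.backward()

sampled = model.generate(seq[0, :512], 512)  # continue a 512-token prime by 512 tokens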

Citations

@misc{burtsev2021multistream,
    title   = {Multi-Stream Transformers}, 
    author  = {Mikhail Burtsev and Anna Rumshisky},
    year    = {2021},
    eprint  = {2107.10342},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\multistream-transformers\setup.py

# import setup utilities and package discovery
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'multistream-transformers',  # package name
  packages = find_packages(),  # include all discovered packages
  version = '0.0.4',  # version number
  license='MIT',  # license
  description = 'Multistream Transformers - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/multistream-transformers',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers'
  ],
  install_requires=[  # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\multistream-transformers\train.py

# import required libraries
from multistream_transformers import MultistreamTransformer
from multistream_transformers.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# cycle endlessly through a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token id back to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of token ids back to a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate the GPT-like decoder model

model = MultistreamTransformer(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 4,
    heads = 8,
    causal = True,
    num_streams = 2
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random subsequences of the text
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the training and validation dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\attend.py

# import required modules and functions
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

# import FlashAttentionFunction from the memory-efficient-attention-pytorch library
from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction

# AttentionConfig namedtuple with three boolean flags
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that lets the wrapped function run only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# a print that only fires once, no matter what is passed in
print_once = once(print)
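Since `once` records only a single `called` flag, `print_once` emits just its first message and silently swallows every later call, e.g.:

print_once('A100 GPU detected')      # printed
print_once('a different message')    # ignored - the function has already fired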

# main class
class Attend(nn.Module):
    def __init__(
        self,
        scale = 8,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        # flash attention requires pytorch 2.0 or above
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine the efficient attention configuration for cuda
        self.cuda_config = None
        self.no_hardware_detected = False

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, False)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        default_scale = q.shape[-1] ** -0.5

        is_cuda = q.is_cuda

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # rescale q and k, since a fixed scale is used rather than the default 1 / sqrt(d)
        rescale = self.scale / default_scale
        q = q * (rescale ** 0.5)
        k = k * (rescale ** 0.5)

        # fall back to the naive implementation when not on cuda or when no suitable hardware was detected
        use_naive = not is_cuda or not exists(self.cuda_config)

        if not is_cuda or self.no_hardware_detected:
            return FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)

        # attempt to use pytorch 2.0 flash attention
        try:
            with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
                out = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask = mask,
                    dropout_p = self.dropout if self.training else 0.
                )
        except:
            print_once('no hardware detected, falling back to naive implementation from memory-efficient-attention-pytorch library')
            self.no_hardware_detected = True

            out = FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)

        return out
    # forward pass over queries (q), keys (k) and values (v), with an optional mask and a flag to force the non-flash path
    def forward(self, q, k, v, mask = None, force_non_flash = False):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # use flash attention when enabled and not explicitly forced off
        if self.flash and not force_non_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # masking
        if exists(mask):
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, mask_value)

        # attention
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
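The tensors handed to `Attend` follow the `b h n d` layout from the docstring; a small hedged example of calling the non-flash path directly (shapes are illustrative only):

import torch

attend = Attend(scale = 8, dropout = 0., flash = False)

q = torch.randn(2, 8, 1024, 64)              # (batch, heads, sequence length, head dim)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)
mask = torch.ones(2, 1, 1024, 1024).bool()   # broadcastable over heads

out = attend(q, k, v, mask = mask)           # (2, 8, 1024, 64)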

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\muse_maskgit_pytorch.py

# attention module
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        cross_attend = False,
        scale = 8,
        flash = True,
        dropout = 0.
    ):
        super().__init__()
        # attention scale
        self.scale = scale
        # number of heads
        self.heads = heads
        # inner dimension
        inner_dim = dim_head * heads

        # whether this is a cross attention block
        self.cross_attend = cross_attend
        # pre-normalization
        self.norm = LayerNorm(dim)

        # attention function
        self.attend = Attend(
            flash = flash,
            dropout = dropout,
            scale = scale
        )

        # null key / value pair
        self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))

        # query projection
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # key / value projection
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned query scale
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        # learned key scale
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        context_mask = None
        ):
        # context must be passed in if and only if this block is a cross attention block
        assert not (exists(context) ^ self.cross_attend)

        # sequence length of the input x
        n = x.shape[-2]
        # number of heads, and whether this is a cross attention call
        h, is_cross_attn = self.heads, exists(context)

        # pre-normalize the input
        x = self.norm(x)

        # choose the key / value input depending on whether we are cross attending
        kv_input = context if self.cross_attend else x

        # project to queries, keys and values
        q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        # split out the heads into their own dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # expand the null key / value pair over heads and batch
        nk, nv = self.null_kv
        nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = x.shape[0]), (nk, nv))

        # prepend the null key / value to the real keys and values
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # l2-normalize queries and keys (cosine-sim attention)
        q, k = map(l2norm, (q, k))
        # apply the learned per-dimension scales
        q = q * self.q_scale
        k = k * self.k_scale

        # expand the context mask to the attention matrix shape and pad it for the null key / value
        if exists(context_mask):
            context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = n)
            context_mask = F.pad(context_mask, (1, 0), value = True)

        # attention
        out = self.attend(q, k, v, mask = context_mask)

        # merge the heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')
        # project out
        return self.to_out(out)
# TransformerBlocks stacks multiple transformer layers
class TransformerBlocks(nn.Module):
    def __init__(
        self,
        *,
        dim,            # model dimension
        depth,          # number of transformer layers to stack
        dim_head = 64,  # dimension per attention head
        heads = 8,      # number of attention heads
        ff_mult = 4,    # feedforward expansion factor
        flash = True    # whether to use flash attention
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash),                        # self attention
                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True, flash = flash),   # cross attention
                FeedForward(dim = dim, mult = ff_mult)                                                          # feedforward
            ]))

        self.norm = LayerNorm(dim)

    def forward(self, x, context = None, context_mask = None):
        for attn, cross_attn, ff in self.layers:
            x = attn(x) + x                                                          # self attention with residual

            x = cross_attn(x, context = context, context_mask = context_mask) + x   # cross attention with residual

            x = ff(x) + x                                                            # feedforward with residual

        return self.norm(x)                                                          # final layer norm

# Transformer operating on token sequences, conditioned on T5-encoded text
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,                 # vocabulary size
        dim,                        # model dimension
        seq_len,                    # sequence length
        dim_out = None,             # output dimension (defaults to num_tokens)
        t5_name = DEFAULT_T5_NAME,  # name of the T5 text encoder
        self_cond = False,          # whether to use self conditioning
        add_mask_id = False,        # whether to reserve an extra [mask] token id
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.mask_id = num_tokens if add_mask_id else None     # the [mask] token id, if reserved

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)   # token embedding
        self.pos_emb = nn.Embedding(seq_len, dim)                           # positional embedding
        self.seq_len = seq_len

        self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs)    # transformer stack
        self.norm = LayerNorm(dim)

        self.dim_out = default(dim_out, num_tokens)                         # output dimension
        self.to_logits = nn.Linear(dim, self.dim_out, bias = False)         # projection to logits

        # text conditioning

        self.encode_text = partial(t5_encode_text, name = t5_name)

        text_embed_dim = get_encoded_dim(t5_name)

        self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()  # project text embeddings to the model dimension

        # optional self conditioning

        self.self_cond = self_cond
        self.self_cond_to_init_embed = FeedForward(dim)

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3.,        # classifier-free guidance scale
        return_embed = False,   # whether to also return the embeddings
        **kwargs
    ):
        if cond_scale == 1:     # guidance scale of 1 means no guidance is applied
            return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs)

        logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs)    # conditioned forward pass

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)     # forward pass with the conditioning fully dropped

        scaled_logits = null_logits + (logits - null_logits) * cond_scale    # classifier-free guidance

        if return_embed:
            return scaled_logits, embed

        return scaled_logits

    def forward_with_neg_prompt(
        self,
        *args,
        text_embed: torch.Tensor,
        neg_text_embed: torch.Tensor,
        cond_scale = 3.,        # guidance scale between the negative and positive prompt
        return_embed = False,
        **kwargs
    ):
        neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs)                      # forward pass with the negative text embedding
        pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs)  # forward pass with the positive text embedding

        logits = neg_logits + (pos_logits - neg_logits) * cond_scale   # guide away from the negative prompt

        if return_embed:
            return logits, embed

        return logits

    def forward(
        self,
        x,
        return_embed = False,
        return_logits = False,
        labels = None,
        ignore_index = 0,
        self_cond_embed = None,
        cond_drop_prob = 0.,
        conditioning_token_ids: Optional[torch.Tensor] = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None
        ):
        # device, batch size and sequence length of the input
        device, b, n = x.device, *x.shape
        # the sequence length must not exceed self.seq_len
        assert n <= self.seq_len

        # prepare the text conditioning

        # exactly one of texts / text_embeds must be provided
        assert exists(texts) ^ exists(text_embeds)

        # if raw texts are given, encode them with the T5 encoder
        if exists(texts):
            text_embeds = self.encode_text(texts)

        # project the text embeddings to the model dimension to form the conditioning context
        context = self.text_embed_proj(text_embeds)

        # context mask marking the positions that actually contain text
        context_mask = (text_embeds != 0).any(dim=-1)

        # classifier-free guidance: randomly drop the text conditioning with probability cond_drop_prob
        if cond_drop_prob > 0.:
            mask = prob_mask_like((b, 1), 1. - cond_drop_prob, device)
            context_mask = context_mask & mask

        # if conditioning token ids are given (e.g. from a low resolution image), append them to the context
        if exists(conditioning_token_ids):
            conditioning_token_ids = rearrange(conditioning_token_ids, 'b ... -> b (...)')
            cond_token_emb = self.token_emb(conditioning_token_ids)
            context = torch.cat((context, cond_token_emb), dim=-2)
            context_mask = F.pad(context_mask, (0, conditioning_token_ids.shape[-1]), value=True)

        # embed the input tokens and add positional embeddings
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device=device))

        # self conditioning: default the self conditioning embedding to zeros if not given
        if self.self_cond:
            if not exists(self_cond_embed):
                self_cond_embed = torch.zeros_like(x)
            x = x + self.self_cond_to_init_embed(self_cond_embed)

        # run the transformer blocks
        embed = self.transformer_blocks(x, context=context, context_mask=context_mask)

        # project to logits
        logits = self.to_logits(embed)

        # optionally return the embeddings alongside the logits
        if return_embed:
            return logits, embed

        # if no labels are given, just return the logits
        if not exists(labels):
            return logits

        # compute the loss: binary cross entropy for a single-logit (critic) head, cross entropy otherwise
        if self.dim_out == 1:
            loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
        else:
            loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index=ignore_index)

        # return just the loss unless the logits were also requested
        if not return_logits:
            return loss

        # return both the loss and the logits
        return loss, logits
# self-critic wrapper that reuses the transformer's embeddings to predict which tokens are fake
class SelfCritic(nn.Module):
    # takes the wrapped transformer as its only argument
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.to_pred = nn.Linear(net.dim, 1)

    # forward with classifier-free guidance scaling
    def forward_with_cond_scale(self, x, *args, **kwargs):
        _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed=True, **kwargs)
        return self.to_pred(embeds)

    # forward with a negative prompt
    def forward_with_neg_prompt(self, x, *args, **kwargs):
        _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed=True, **kwargs)
        return self.to_pred(embeds)

    # plain forward
    def forward(self, x, *args, labels=None, **kwargs):
        _, embeds = self.net(x, *args, return_embed=True, **kwargs)
        logits = self.to_pred(embeds)

        # without labels, return the per-token logits
        if not exists(labels):
            return logits

        # otherwise flatten the logits and compute binary cross entropy against the labels
        logits = rearrange(logits, '... 1 -> ...')
        return F.binary_cross_entropy_with_logits(logits, labels)

# specialized transformers

# MaskGitTransformer: a Transformer with an extra [mask] token id reserved
class MaskGitTransformer(Transformer):
    def __init__(self, *args, **kwargs):
        # the mask id is always added, so it must not be passed in
        assert 'add_mask_id' not in kwargs
        super().__init__(*args, add_mask_id=True, **kwargs)

# TokenCritic: a Transformer with a single logit output per token
class TokenCritic(Transformer):
    def __init__(self, *args, **kwargs):
        # the output dimension is fixed to 1, so it must not be passed in
        assert 'dim_out' not in kwargs
        super().__init__(*args, dim_out=1, **kwargs)

# classifier-free guidance helpers

# sample a tensor of uniform random noise in [0, 1] (the min / max arguments are unused)
def uniform(shape, min=0, max=1, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)

# build a boolean mask where each element is True with probability `prob`
def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob
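`prob_mask_like` is what drives the conditioning dropout for classifier-free guidance; e.g. keeping each sample's text conditioning with probability 0.9:

import torch

keep_mask = prob_mask_like((4, 1), 1. - 0.1, device = torch.device('cpu'))
print(keep_mask.shape, keep_mask.dtype)   # torch.Size([4, 1]) torch.bool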

# sampling helpers

# numerically safe log
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel-max sampling from logits
def gumbel_sample(t, temperature=1., dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

# keep only the top-k logits, setting the rest to negative infinity
def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs
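Together, `top_k` and `gumbel_sample` implement filtered stochastic decoding over the codebook logits; a small hedged example (shapes are illustrative):

import torch

logits = torch.randn(2, 16, 1024)                # (batch, sequence, codebook size)
filtered = top_k(logits, thres = 0.9)            # keep roughly the top 10% of logits per position
ids = gumbel_sample(filtered, temperature = 1.)  # (2, 16) sampled token ids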

# noise schedule

# cosine schedule: maps t in [0, 1] to the fraction of tokens that remain masked
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)
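The schedule maps a uniform time `t` in `[0, 1]` to the fraction of tokens that stay masked, so `t = 0` masks everything and `t = 1` masks nothing:

import torch

t = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
print(cosine_schedule(t))   # approximately [1.0000, 0.9239, 0.7071, 0.3827, 0.0000]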

# main MaskGit class

# MaskGit, type-checked with beartype
@beartype
class MaskGit(nn.Module):
    # constructor
    def __init__(
        self,
        image_size,
        transformer: MaskGitTransformer,
        noise_schedule: Callable = cosine_schedule,
        token_critic: Optional[TokenCritic] = None,
        self_token_critic=False,
        vae: Optional[VQGanVAE] = None,
        cond_vae: Optional[VQGanVAE] = None,
        cond_image_size=None,
        cond_drop_prob=0.5,
        self_cond_prob=0.9,
        no_mask_token_prob=0.,
        critic_loss_weight=1.
        ):
        # parent constructor
        super().__init__()
        # if a vae is given, keep a frozen eval copy of it, otherwise None
        self.vae = vae.copy_for_eval() if exists(vae) else None

        # if a conditioning vae is given, put it in eval mode, otherwise reuse the main vae
        if exists(cond_vae):
            self.cond_vae = cond_vae.eval()
        else:
            self.cond_vae = self.vae

        # a conditioning image size must be specified when a conditioning vae is used
        assert not (exists(cond_vae) and not exists(cond_image_size)), 'cond_image_size must be specified if conditioning'

        # image sizes and whether low resolution conditioning images need resizing
        self.image_size = image_size
        self.cond_image_size = cond_image_size
        self.resize_image_for_cond_image = exists(cond_image_size)

        # probability of dropping the conditioning (for classifier-free guidance)
        self.cond_drop_prob = cond_drop_prob

        # the transformer, and whether it self-conditions
        self.transformer = transformer
        self.self_cond = transformer.self_cond
        # the vae / conditioning vae codebook sizes must match the transformer's number of tokens
        assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size'

        # mask token id and noise schedule
        self.mask_id = transformer.mask_id
        self.noise_schedule = noise_schedule

        # a self token critic and an external token critic cannot both be used
        assert not (self_token_critic and exists(token_critic))
        self.token_critic = token_critic

        # if a self token critic is requested, wrap the transformer in a SelfCritic
        if self_token_critic:
            self.token_critic = SelfCritic(transformer)

        # weight of the critic loss
        self.critic_loss_weight = critic_loss_weight

        # probability of applying self conditioning during training
        self.self_cond_prob = self_cond_prob

        # probability of leaving a masked token as-is, so the transformer produces good embeddings for all tokens,
        # as done in the original BERT paper; may be needed for self conditioning
        self.no_mask_token_prob = no_mask_token_prob

    # save the model weights to the given path
    def save(self, path):
        torch.save(self.state_dict(), path)

    # load model weights from the given path
    def load(self, path):
        path = Path(path)
        assert path.exists()
        state_dict = torch.load(str(path))
        self.load_state_dict(state_dict)

    # generation: sample images from text prompts
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        texts: List[str],
        negative_texts: Optional[List[str]] = None,
        cond_images: Optional[torch.Tensor] = None,
        fmap_size = None,
        temperature = 1.,
        topk_filter_thres = 0.9,
        can_remask_prev_masked = False,
        force_not_use_token_critic = False,
        timesteps = 18,  # 18 steps is the ideal number, per the maskgit paper
        cond_scale = 3,
        critic_noise_scale = 1
    ):
        ...  # body of generate omitted in this excerpt

    # forward pass, used for training
    def forward(
        self,
        images_or_ids: torch.Tensor,
        ignore_index = -1,
        cond_images: Optional[torch.Tensor] = None,
        cond_token_ids: Optional[torch.Tensor] = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None,
        cond_drop_prob = None,
        train_only_generator = False,
        sample_temperature = None
        ):
            # tokenize raw images with the vae if needed

            if images_or_ids.dtype == torch.float:
                assert exists(self.vae), 'vqgan vae must be passed in if training from raw images'
                assert all([height_or_width == self.image_size for height_or_width in images_or_ids.shape[-2:]]), 'the image you passed in is not of the correct dimensions'

                with torch.no_grad():
                    _, ids, _ = self.vae.encode(images_or_ids)
            else:
                assert not self.resize_image_for_cond_image, 'you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer'
                ids = images_or_ids

            # resize the input images to the conditioning resolution if required

            if self.resize_image_for_cond_image:
                cond_images_or_ids = F.interpolate(images_or_ids, self.cond_image_size, mode='nearest')

            # get some basic variables

            ids = rearrange(ids, 'b ... -> b (...)')

            batch, seq_len, device, cond_drop_prob = *ids.shape, ids.device, default(cond_drop_prob, self.cond_drop_prob)

            # tokenize the conditioning images if raw images were passed in

            assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids'

            if exists(cond_images):
                assert exists(self.cond_vae), 'cond vqgan vae must be passed in'
                assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]])

                with torch.no_grad():
                    _, cond_token_ids, _ = self.cond_vae.encode(cond_images)

            # prepare the mask

            rand_time = uniform((batch,), device=device)
            rand_mask_probs = self.noise_schedule(rand_time)
            num_token_masked = (seq_len * rand_mask_probs).round().clamp(min=1)

            mask_id = self.mask_id
            batch_randperm = torch.rand((batch, seq_len), device=device).argsort(dim=-1)
            mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

            mask_id = self.transformer.mask_id
            labels = torch.where(mask, ids, ignore_index)

            if self.no_mask_token_prob > 0.:
                no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob)
                mask &= ~no_mask_mask

            x = torch.where(mask, mask_id, ids)

            # get the text embeddings

            if exists(texts):
                text_embeds = self.transformer.encode_text(texts)
                texts = None

            # self conditioning

            self_cond_embed = None

            if self.transformer.self_cond and random() < self.self_cond_prob:
                with torch.no_grad():
                    _, self_cond_embed = self.transformer(
                        x,
                        text_embeds=text_embeds,
                        conditioning_token_ids=cond_token_ids,
                        cond_drop_prob=0.,
                        return_embed=True
                    )

                    self_cond_embed.detach_()

            # get the generator (cross entropy) loss

            ce_loss, logits = self.transformer(
                x,
                text_embeds=text_embeds,
                self_cond_embed=self_cond_embed,
                conditioning_token_ids=cond_token_ids,
                labels=labels,
                cond_drop_prob=cond_drop_prob,
                ignore_index=ignore_index,
                return_logits=True
            )

            if not exists(self.token_critic) or train_only_generator:
                return ce_loss

            # token critic loss

            sampled_ids = gumbel_sample(logits, temperature=default(sample_temperature, random()))

            critic_input = torch.where(mask, sampled_ids, x)
            critic_labels = (ids != critic_input).float()

            bce_loss = self.token_critic(
                critic_input,
                text_embeds=text_embeds,
                conditioning_token_ids=cond_token_ids,
                labels=critic_labels,
                cond_drop_prob=cond_drop_prob
            )

            return ce_loss + self.critic_loss_weight * bce_loss
# Muse: chains a base MaskGit with a super-resolution MaskGit
@beartype
class Muse(nn.Module):
    def __init__(
        self,
        base: MaskGit,      # base (low resolution) maskgit
        superres: MaskGit   # super resolution maskgit
    ):
        super().__init__()
        self.base_maskgit = base.eval()                 # base model, put into eval mode

        assert superres.resize_image_for_cond_image     # the super resolution model must resize images for conditioning
        self.superres_maskgit = superres.eval()         # super resolution model, put into eval mode

    # text-to-image forward pass, with gradients disabled
    @torch.no_grad()
    def forward(
        self,
        texts: List[str],           # input text prompts
        cond_scale = 3.,            # classifier-free guidance scale
        temperature = 1.,           # sampling temperature
        timesteps = 18,             # number of maskgit decoding steps
        superres_timesteps = None,  # decoding steps for the super resolution stage, defaults to timesteps
        return_lowres = False,      # whether to also return the low resolution images
        return_pil_images = True    # whether to convert the outputs to PIL images
    ):
    ):
        # generate low resolution images with the base model
        lowres_image = self.base_maskgit.generate(
            texts = texts,
            cond_scale = cond_scale,
            temperature = temperature,
            timesteps = timesteps
        )

        # generate high resolution images with the super resolution model, conditioned on the low resolution output
        superres_image = self.superres_maskgit.generate(
            texts = texts,
            cond_scale = cond_scale,
            cond_images = lowres_image,
            temperature = temperature,
            timesteps = default(superres_timesteps, timesteps)  # fall back to the base timesteps if not specified
        )
        
        # optionally convert to PIL images
        if return_pil_images:
            # convert the low resolution images to a list of PIL images
            lowres_image = list(map(T.ToPILImage(), lowres_image))
            # convert the high resolution images to a list of PIL images
            superres_image = list(map(T.ToPILImage(), superres_image))

        # if the low resolution images are not requested, return only the high resolution ones
        if not return_lowres:
            return superres_image

        # return both the high and low resolution images
        return superres_image, lowres_image

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\t5.py

# import logging, typing helpers, torch and transformers
import logging
from typing import List, Union

import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# beartype validates the arguments of the text encoding function below
from beartype import beartype

# silence transformers logging below the error level
transformers.logging.set_verbosity_error()

# helper to check whether a value exists
def exists(val):
    return val is not None

# config
MAX_LENGTH = 256
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
T5_CONFIGS = {}

# global singletons

# get the tokenizer for the given model name
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

# get the encoder model for the given model name
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# get (and cache) both the encoder model and the tokenizer
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# get the dimensionality of the encoded text
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the full model just to read its dimension
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encode text

# texts may be a single string or a list of strings (enforced by beartype)
@beartype
def t5_encode_text(
    texts: Union[str, List[str]],
    name = DEFAULT_T5_NAME,
    output_device = None
):
    if isinstance(texts, str):
        texts = [texts]

    # fetch the encoder model and tokenizer
    t5, tokenizer = get_model_and_tokenizer(name)

    # move the model to cuda if available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # tokenize the texts
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()
    encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.)

    # if an output device was given, move the encoded text onto it
    if not exists(output_device):
        return encoded_text

    encoded_text = encoded_text.to(output_device)
    return encoded_text
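A hedged usage example of the encoder above (the default checkpoint is downloaded on first use; `google/t5-v1_1-base` has `d_model = 768`):

embeds = t5_encode_text(['a photo of a corgi wearing a party hat'])
print(embeds.shape)   # (1, sequence length, 768) for the default T5 model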

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\trainers.py

# import sqrt from math
from math import sqrt
# import choice from random
from random import choice
# import Path from pathlib
from pathlib import Path
# import rmtree from shutil
from shutil import rmtree
# import partial from functools
from functools import partial

# import the beartype decorator
from beartype import beartype

# import torch
import torch
# import the nn module
from torch import nn
# import the Adam optimizer
from torch.optim import Adam
# import Dataset, DataLoader and random_split
from torch.utils.data import Dataset, DataLoader, random_split

# import torchvision transforms under the T alias
import torchvision.transforms as T
# import the ImageFolder dataset
from torchvision.datasets import ImageFolder
# import make_grid and save_image from torchvision.utils
from torchvision.utils import make_grid, save_image

# import VQGanVAE from muse_maskgit_pytorch.vqgan_vae
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

# import rearrange from einops
from einops import rearrange

# import Accelerator, DistributedType and DistributedDataParallelKwargs from accelerate
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

# import EMA from ema_pytorch
from ema_pytorch import EMA

# import Image and ImageFile from PIL
from PIL import Image, ImageFile
# allow loading truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# identity function
def identity(t, *args, **kwargs):
    return t

# no-op
def noop(*args, **kwargs):
    pass

# find the index of the first element satisfying a condition
def find_index(arr, cond):
    for ind, el in enumerate(arr):
        if cond(el):
            return ind
    return None

# find and pop the first element satisfying a condition, otherwise return a default
def find_and_pop(arr, cond, default = None):
    ind = find_index(arr, cond)

    if exists(ind):
        return arr.pop(ind)

    if callable(default):
        return default()

    return default

# cycle endlessly through a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# cast a value to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# ask the user a yes / no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate new values into a log dictionary, key by key
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log
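`accum_log` simply sums values per key, which is how per-step losses are accumulated across gradient accumulation passes:

logs = {}
accum_log(logs, {'loss': 0.5})
accum_log(logs, {'loss': 0.25})
print(logs)   # {'loss': 0.75}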

# cast a value to a pair
def pair(val):
    return val if isinstance(val, tuple) else (val, val)

# convert an image to the given mode if needed
def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# image-related helpers and dataset

# image dataset
class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
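A hedged example of wiring this dataset into a dataloader (the folder path is a placeholder):

from torch.utils.data import DataLoader

ds = ImageDataset('./path/to/images', image_size = 256)
dl = DataLoader(ds, batch_size = 4, shuffle = True)
images = next(iter(dl))   # (4, 3, 256, 256) float tensor in [0, 1]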

# main trainer class

# VQGanVAETrainer, type-checked with beartype
@beartype
class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae: VQGanVAE,
        *,
        folder,
        num_train_steps,
        batch_size,
        image_size,
        lr = 3e-4,
        grad_accum_every = 1,
        max_grad_norm = None,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 0,
        ema_update_every = 1,
        apply_grad_penalty_every = 4,
        accelerate_kwargs: dict = dict()
        ):
        # parent constructor
        super().__init__()

        # instantiate the accelerator
        kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])

        # find and pop any DistributedDataParallelKwargs handler, or create a default one
        ddp_kwargs = find_and_pop(
            kwargs_handlers,
            lambda x: isinstance(x, DistributedDataParallelKwargs),
            partial(DistributedDataParallelKwargs, find_unused_parameters = True)
        )

        # force find_unused_parameters to True
        ddp_kwargs.find_unused_parameters = True
        kwargs_handlers.append(ddp_kwargs)
        accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)

        # create the accelerator
        self.accelerator = Accelerator(**accelerate_kwargs)

        # the vae being trained
        self.vae = vae

        # training parameters
        self.register_buffer('steps', torch.Tensor([0]))
        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # separate the discriminator parameters from the vae parameters
        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters
        self.vae_parameters = vae_parameters

        # optimizers
        self.optim = Adam(vae_parameters, lr = lr)
        self.discr_optim = Adam(discr_parameters, lr = lr)
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # dataset
        self.ds = ImageDataset(folder, image_size)

        # split out a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloaders
        self.dl = DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        )

        self.valid_dl = DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        )

        # prepare the model, optimizers and dataloaders with the accelerator
        (
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.vae,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        )

        # whether to keep an exponential moving average of the vae
        self.use_ema = use_ema

        # if using EMA, create the EMA wrapper and prepare it with the accelerator
        if use_ema:
            self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
            self.ema_vae = self.accelerator.prepare(self.ema_vae)

        # endless dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # how often to save the model and sample results
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # how often to apply the gradient penalty
        self.apply_grad_penalty_every = apply_grad_penalty_every

        # results folder
        self.results_folder = Path(results_folder)

        # if the results folder is not empty, ask whether to clear previous experiment checkpoints and results
        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)

    # save the model
    def save(self, path):
        # only the local main process saves
        if not self.accelerator.is_local_main_process:
            return

        # bundle the model weights and both optimizer state dicts
        pkg = dict(
            model = self.accelerator.get_state_dict(self.vae),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )
        torch.save(pkg, path)

    # load model weights and optimizer states
    def load(self, path):
        # convert to a Path object and make sure it exists
        path = Path(path)
        assert path.exists()
        # load the checkpoint
        pkg = torch.load(path)

        # unwrap the vae from the accelerator and load its weights
        vae = self.accelerator.unwrap_model(self.vae)
        vae.load_state_dict(pkg['model'])

        # restore the optimizer states
        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

    # print through the accelerator (only once across processes)
    def print(self, msg):
        self.accelerator.print(msg)

    # device used by the accelerator
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process
    # a single training step
    def train_step(self):
        # device
        device = self.device

        # current step count
        steps = int(self.steps.item())
        # whether to apply the gradient penalty on this step
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        # put the vae in training mode
        self.vae.train()
        # grab the discriminator, unwrapping DDP if distributed
        discr = self.vae.module.discr if self.is_distributed else self.vae.discr
        # grab the EMA vae if exponential moving averaging is used
        if self.use_ema:
            ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae

        # logs for this step
        logs = {}

        # update the vae (generator)

        # accumulate gradients over several batches
        for _ in range(self.grad_accum_every):
            # next batch of images
            img = next(self.dl_iter)
            img = img.to(device)

            # compute the vae loss under autocast (mixed precision)
            with self.accelerator.autocast():
                loss = self.vae(
                    img,
                    add_gradient_penalty = apply_grad_penalty,
                    return_loss = True
                )

            # backward pass, scaled by the number of accumulation steps
            self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate the loss into the logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.vae.parameters(), self.max_grad_norm)

        # step the optimizer
        self.optim.step()
        self.optim.zero_grad()

        # update the discriminator

        if exists(discr):
            self.discr_optim.zero_grad()

            for _ in range(self.grad_accum_every):
                img = next(self.dl_iter)
                img = img.to(device)

                loss = self.vae(img, return_discr_loss = True)

                self.accelerator.backward(loss / self.grad_accum_every)

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

            if exists(self.discr_max_grad_norm):
                self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)

            self.discr_optim.step()

            # log the losses
            self.print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update the exponential moving average of the generator

        if self.use_ema:
            ema_vae.update()

        # periodically sample reconstructions

        if not (steps % self.save_results_every):
            vaes_to_evaluate = ((self.vae, str(steps)),)

            if self.use_ema:
                vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate

            for model, filename in vaes_to_evaluate:
                model.eval()

                valid_data = next(self.valid_dl_iter)
                valid_data = valid_data.to(device)

                recons = model(valid_data, return_recons = True)

                # save a grid of originals and reconstructions

                imgs_and_recons = torch.stack((valid_data, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            self.print(f'{steps}: saving to {str(self.results_folder)}')

        # periodically save the model
        self.accelerator.wait_for_everyone()
        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.accelerator.unwrap_model(self.vae).state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            self.accelerator.save(state_dict, model_path)

            if self.use_ema:
                ema_state_dict = self.accelerator.unwrap_model(self.ema_vae).state_dict()
                model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
                self.accelerator.save(ema_state_dict, model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # advance the step counter and return the logs
        self.steps += 1
        return logs
    # training loop; takes an optional logging function, defaulting to a no-op
    def train(self, log_fn = noop):
        # device of the vae parameters
        device = next(self.vae.parameters()).device

        # keep stepping until the requested number of training steps is reached
        while self.steps < self.num_train_steps:
            # run one training step and collect its logs
            logs = self.train_step()
            # hand the logs to the logging function
            log_fn(logs)

        # done
        self.print('training complete')

.\lucidrains\muse-maskgit-pytorch\muse_maskgit_pytorch\vqgan_vae.py

# import required modules
from pathlib import Path
import copy
import math
from math import sqrt
from functools import partial, wraps

# vector quantization (VQ) and lookup-free quantization (LFQ)
from vector_quantize_pytorch import VectorQuantize as VQ, LFQ

# pytorch
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import grad as torch_grad

# torchvision
import torchvision

# einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# constants
MList = nn.ModuleList

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return a default when the value is missing
def default(val, d):
    return val if exists(val) else d

# decorators

# run the wrapped call with the model temporarily in eval mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
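`eval_decorator` puts the model into eval mode for the duration of the wrapped call and then restores the previous training flag; a small hedged example (the `Sampler` class is made up for illustration):

from torch import nn

class Sampler(nn.Module):
    @eval_decorator
    def sample(self):
        return self.training      # False inside the decorated call

m = Sampler()
m.train()
assert m.sample() is False        # eval mode during the call
assert m.training is True         # training flag restored afterwards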

# temporarily detach the vgg attribute around the wrapped call
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, '_vgg')
        if has_vgg:
            vgg = self._vgg
            delattr(self, '_vgg')

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self._vgg = vgg

        return out
    return inner

# keyword argument helpers

# pick the given keys out of a dict, removing them from it
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# split a dict in two based on a predicate over its keys
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# whether a string starts with the given prefix
def string_begins_with(prefix, string_input):
    return string_input.startswith(prefix)

# group a dict's entries by whether their keys start with a prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# group by prefix, then strip the prefix from the matched keys
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
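This is how `VQGanVAE` below separates `vq_`-prefixed keyword arguments from the rest; a small example:

kwargs = dict(vq_codebook_dim = 256, vq_decay = 0.8, discr_layers = 4)
vq_kwargs, rest = groupby_prefix_and_trim('vq_', kwargs)
print(vq_kwargs)   # {'codebook_dim': 256, 'decay': 0.8}
print(rest)        # {'discr_layers': 4}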

# tensor helpers

# numerically safe log
def log(t, eps = 1e-10):
    return torch.log(t + eps)

# gradient penalty on the discriminator output with respect to the input images
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
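The penalty is computed against the discriminator's logits with respect to the (real) input images, which therefore need `requires_grad` enabled; a hedged sketch using a stand-in one-layer critic:

import torch
from torch import nn

critic = nn.Conv2d(3, 1, 4)   # stand-in critic; the real Discriminator is defined further below
images = torch.randn(4, 3, 64, 64, requires_grad = True)

logits = critic(images)
gp = gradient_penalty(images, logits)   # pushes per-sample gradient norms toward 1
gp.backward()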

# leaky relu helper
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

# division with the denominator clamped away from zero
def safe_div(numer, denom, eps = 1e-8):
    return numer / denom.clamp(min = eps)

# GAN losses

# hinge discriminator loss
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

# hinge generator loss
def hinge_gen_loss(fake):
    return -fake.mean()

# binary cross entropy discriminator loss
def bce_discr_loss(fake, real):
    return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()

# binary cross entropy generator loss
def bce_gen_loss(fake):
    return -log(torch.sigmoid(fake)).mean()

# gradient of a loss with respect to a given layer's weights
def grad_layer_wrt_loss(loss, layer):
    return torch_grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()

# VQGAN VAE

# channel-wise layer norm
class LayerNormChan(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma

# discriminator
class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 16,
        init_kernel_size = 5
    # a simple convolutional patch discriminator
    ):
        # parent constructor
        super().__init__()
        # pair up consecutive dimensions
        dim_pairs = zip(dims[:-1], dims[1:])

        # first layer: a conv followed by a leaky relu
        self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])

        # intermediate layers: strided conv, group norm and leaky relu
        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, dim_out),
                leaky_relu()
            ))

        # final dimension
        dim = dims[-1]
        # output head: two convs producing a 5 x 5 logit map, for PatchGAN-esque training
        self.to_logits = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            leaky_relu(),
            nn.Conv2d(dim, 1, 4)
        )

    # forward pass: run the input through every layer, then the logit head
    def forward(self, x):
        for net in self.layers:
            x = net(x)

        return self.to_logits(x)
# ResnetEncDec: a resnet-style encoder / decoder
class ResnetEncDec(nn.Module):
    # constructor
    def __init__(
        self,
        dim,
        *,
        channels = 3,
        layers = 4,
        layer_mults = None,
        num_resnet_blocks = 1,
        resnet_groups = 16,
        first_conv_kernel_size = 5
    ):
        # parent constructor
        super().__init__()
        # the dimension must be divisible by the group norm groups
        assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'

        # number of down / up sampling layers
        self.layers = layers

        # encoder and decoder module lists
        self.encoders = MList([])
        self.decoders = MList([])

        # default the layer multipliers to powers of two
        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
        # one multiplier per layer
        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'

        # dimensions at every layer
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        # dimension of the encoded feature map
        self.encoded_dim = dims[-1]

        # (in, out) dimension pairs for every layer
        dim_pairs = zip(dims[:-1], dims[1:])

        # helpers to append to the encoder and prepend to the decoder
        append = lambda arr, t: arr.append(t)
        prepend = lambda arr, t: arr.insert(0, t)

        # normalize num_resnet_blocks to a tuple with one entry per layer
        if not isinstance(num_resnet_blocks, tuple):
            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)

        # one resnet block count per layer
        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'

        # build the encoder and the mirrored decoder layer by layer
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
            # downsampling conv for the encoder
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
            # upsampling transposed conv for the decoder
            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))

            # resnet blocks for the encoder, GLU resnet blocks for the decoder
            for _ in range(layer_num_resnet_blocks):
                append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
                prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))

        # stem conv at the start of the encoder
        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
        # final conv back to image channels at the end of the decoder
        append(self.decoders, nn.Conv2d(dim, channels, 1))

    # size of the encoded feature map for a given image size
    def get_encoded_fmap_size(self, image_size):
        return image_size // (2 ** self.layers)

    # weight tensor of the last decoder layer
    @property
    def last_dec_layer(self):
        return self.decoders[-1].weight

    # encode
    def encode(self, x):
        for enc in self.encoders:
            x = enc(x)
        return x

    # decode
    def decode(self, x):
        for dec in self.decoders:
            x = dec(x)
        return x

# GLU residual block
class GLUResBlock(nn.Module):
    # takes the channel count and the number of group norm groups
    def __init__(self, chan, groups = 16):
        super().__init__()
        # conv + GLU + group norm, twice, followed by a final 1x1 conv
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan * 2, 3, padding = 1),
            nn.GLU(dim = 1),
            nn.GroupNorm(groups, chan),
            nn.Conv2d(chan, chan, 1)
        )

    # residual forward pass
    def forward(self, x):
        return self.net(x) + x

# standard residual block
class ResBlock(nn.Module):
    # takes the channel count and the number of group norm groups
    def __init__(self, chan, groups = 16):
        super().__init__()
        # conv + group norm + leaky relu, twice, followed by a final 1x1 conv
        self.net = nn.Sequential(
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 3, padding = 1),
            nn.GroupNorm(groups, chan),
            leaky_relu(),
            nn.Conv2d(chan, chan, 1)
        )

    # residual forward pass
    def forward(self, x):
        return self.net(x) + x

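# a minimal shape check for the two residual blocks above (illustrative values; assumes the
# torch import at the top of this file): both blocks keep channels and spatial size unchanged,
# which is what lets them be stacked freely inside the encoder and decoder

def _resblock_shape_check():
    x = torch.randn(1, 32, 8, 8)
    assert ResBlock(32, groups = 8)(x).shape == x.shape
    assert GLUResBlock(32, groups = 8)(x).shape == x.shape
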
# VQGanVAE: a VQGAN-style autoencoder with an optional GAN and perceptual loss
class VQGanVAE(nn.Module):
    # constructor, taking all model hyperparameters
    def __init__(
        self,
        *,
        dim,                              # base dimension of the model
        channels = 3,                     # number of image channels, 3 by default
        layers = 4,                       # number of encoder / decoder layers, 4 by default
        l2_recon_loss = False,            # use an L2 reconstruction loss instead of L1
        use_hinge_loss = True,            # use hinge losses for the GAN (otherwise BCE)
        vgg = None,                       # optional externally supplied VGG for the perceptual loss
        lookup_free_quantization = True,  # use lookup-free quantization (LFQ) instead of vector quantization
        codebook_size = 65536,            # size of the quantization codebook
        vq_kwargs: dict = dict(           # default arguments for the VQ quantizer
            codebook_dim = 256,
            decay = 0.8,
            commitment_weight = 1.,
            kmeans_init = True,
            use_cosine_sim = True,
        ),
        lfq_kwargs: dict = dict(          # default arguments for the LFQ quantizer
            diversity_gamma = 4.
        ),
        use_vgg_and_gan = True,           # whether to use the perceptual (VGG) and adversarial (GAN) losses
        discr_layers = 4,                 # number of discriminator layers
        **kwargs                          # remaining keyword arguments, grouped by prefix below
    ):
        super().__init__()
        # group keyword arguments by prefix for the quantizer and the encoder / decoder
        vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
        encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)

        # store basic attributes
        self.channels = channels
        self.codebook_size = codebook_size
        self.dim_divisor = 2 ** layers

        enc_dec_klass = ResnetEncDec

        # build the resnet encoder / decoder
        self.enc_dec = enc_dec_klass(
            dim = dim,
            channels = channels,
            layers = layers,
            **encdec_kwargs
        )

        self.lookup_free_quantization = lookup_free_quantization

        # choose the quantizer: lookup-free quantization (LFQ) or standard vector quantization (VQ)
        if lookup_free_quantization:
            self.quantizer = LFQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                **lfq_kwargs
            )
        else:
            self.quantizer = VQ(
                dim = self.enc_dec.encoded_dim,
                codebook_size = codebook_size,
                accept_image_fmap = True,
                **vq_kwargs
            )

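        # note: with the default codebook_size = 65536 = 2 ** 16, lookup-free quantization represents
        # each spatial position as a 16-dimensional sign vector rather than an entry looked up in a
        # learned codebook, whereas the VQ branch learns an explicit codebook with EMA updates (decay)
        # and a commitment loss (both quantizers are assumed to come from vector-quantize-pytorch)
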
        # choose the reconstruction loss (L2 or L1)
        self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss

        # the GAN and perceptual losses are optional (e.g. they can be turned off for grayscale training)
        self._vgg = None
        self.discr = None
        self.use_vgg_and_gan = use_vgg_and_gan

        if not use_vgg_and_gan:
            return

        # perceptual loss: store a provided VGG, otherwise one is lazily created on first use
        if exists(vgg):
            self._vgg = vgg

        # GAN-related losses: build the discriminator
        layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
        layer_dims = [dim * mult for mult in layer_mults]
        dims = (dim, *layer_dims)

        self.discr = Discriminator(dims = dims, channels = channels)

        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

    # device the model parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    # lazily build the VGG16 feature extractor used for the perceptual loss
    @property
    def vgg(self):
        if exists(self._vgg):
            return self._vgg

        vgg = torchvision.models.vgg16(pretrained = True)
        vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
        self._vgg = vgg.to(self.device)
        return self._vgg

    # dimension of the encoded feature map
    @property
    def encoded_dim(self):
        return self.enc_dec.encoded_dim

    # spatial size of the encoded feature map for a given image size
    def get_encoded_fmap_size(self, image_size):
        return self.enc_dec.get_encoded_fmap_size(image_size)

    # deep copy of the model for evaluation, with the discriminator and VGG removed
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())

        if vae_copy.use_vgg_and_gan:
            del vae_copy.discr
            del vae_copy._vgg

        vae_copy.eval()
        return vae_copy.to(device)

    # state dict with the VGG weights stripped out
    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    # load a state dict, ignoring any VGG weights
    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    # save the model weights to a path
    def save(self, path):
        torch.save(self.state_dict(), path)

    # load model weights from a path
    def load(self, path):
        path = Path(path)
        assert path.exists()
        state_dict = torch.load(str(path))
        self.load_state_dict(state_dict)

    # encode an image into a quantized feature map, discrete code indices, and the auxiliary quantizer loss
    def encode(self, fmap):
        fmap = self.enc_dec.encode(fmap)
        fmap, indices, vq_aux_loss = self.quantizer(fmap)
        return fmap, indices, vq_aux_loss
    # decode an image from discrete code ids
    def decode_from_ids(self, ids):

        # with lookup-free quantization, flatten the spatial dimensions of the ids,
        # convert them back to codes, then restore the original arrangement
        if self.lookup_free_quantization:
            ids, ps = pack([ids], 'b *')
            fmap = self.quantizer.indices_to_codes(ids)
            fmap, = unpack(fmap, ps, 'b * c')
        else:
            # with vector quantization, look the ids up in the codebook and project the codes out
            codes = self.codebook[ids]
            fmap = self.quantizer.project_out(codes)

        # move channels to the front before decoding
        fmap = rearrange(fmap, 'b h w c -> b c h w')
        return self.decode(fmap)

    # decode a (continuous) feature map back into an image
    def decode(self, fmap):
        return self.enc_dec.decode(fmap)

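    # a minimal round-trip sketch for the three methods above (shapes are illustrative; the input
    # height / width are assumed divisible by 2 ** layers):
    #
    #   fmap, indices, aux_loss = vae.encode(img)   # fmap: (b, encoded_dim, h // 2**layers, w // 2**layers)
    #   recon = vae.decode(fmap)                    # same shape as img
    #   recon = vae.decode_from_ids(indices)        # reconstruct directly from the discrete ids
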
    # forward pass: reconstruct the image, optionally returning the autoencoder or discriminator loss
    def forward(
        self,
        img,
        return_loss = False,
        return_discr_loss = False,
        return_recons = False,
        add_gradient_penalty = True
    ):
        # unpack batch, channels, spatial size and device
        batch, channels, height, width, device = *img.shape, img.device

        # height and width must be divisible by the total downsampling factor
        for dim_name, size in (('height', height), ('width', width)):
            assert (size % self.dim_divisor) == 0, f'{dim_name} must be divisible by {self.dim_divisor}'

        # the image must have the number of channels this VQGanVAE was configured with
        assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'

        # encode the input image into a quantized feature map
        fmap, indices, commit_loss = self.encode(img)

        # decode back to image space
        fmap = self.decode(fmap)

        # if no loss is requested, return the reconstruction
        if not return_loss and not return_discr_loss:
            return fmap

        # exactly one of the autoencoder loss or the discriminator loss may be requested
        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'

        # discriminator branch
        if return_discr_loss:
            assert exists(self.discr), 'discriminator must exist to train it'

            # stop gradients from flowing into the generator, and track gradients on the real image for the gradient penalty
            fmap.detach_()
            img.requires_grad_()

            # discriminator logits for the reconstruction (fake) and the input (real)
            fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

            # discriminator loss
            discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

            # optionally add the gradient penalty; otherwise keep the plain discriminator loss
            loss = discr_loss
            if add_gradient_penalty:
                gp = gradient_penalty(img, img_discr_logits)
                loss = discr_loss + gp

            # optionally also return the reconstruction
            if return_recons:
                return loss, fmap

            return loss

        # reconstruction loss
        recon_loss = self.recon_loss_fn(fmap, img)

        # without VGG and GAN, the reconstruction loss is the whole loss
        if not self.use_vgg_and_gan:
            if return_recons:
                return recon_loss, fmap

            return recon_loss

        # perceptual loss
        img_vgg_input = img
        fmap_vgg_input = fmap

        if img.shape[1] == 1:
            # repeat grayscale images to 3 channels for VGG
            img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))

        img_vgg_feats = self.vgg(img_vgg_input)
        recon_vgg_feats = self.vgg(fmap_vgg_input)
        perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)

        # generator (adversarial) loss
        gen_loss = self.gen_loss(self.discr(fmap))

        # adaptive weight, as in VQGAN: ratio of the perceptual-loss gradient norm to the
        # generator-loss gradient norm at the last decoder layer, clamped for stability
        last_dec_layer = self.enc_dec.last_dec_layer

        norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
        norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

        adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
        adaptive_weight.clamp_(max = 1e4)

        # combine reconstruction, perceptual, commitment and adaptively weighted generator losses
        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss

        # optionally also return the reconstruction
        if return_recons:
            return loss, fmap

        return loss
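
# a minimal usage sketch of the two training branches in forward; the hyperparameters below are
# illustrative assumptions (and LFQ / VQ are assumed to be provided by vector-quantize-pytorch)

if __name__ == '__main__':
    vae = VQGanVAE(dim = 32, layers = 2, codebook_size = 256, use_vgg_and_gan = False)

    images = torch.randn(1, 3, 32, 32)              # height / width must be divisible by 2 ** layers

    recon = vae(images)                             # plain reconstruction, same shape as the input
    ae_loss = vae(images, return_loss = True)       # autoencoder branch (just the reconstruction loss here)

    # with use_vgg_and_gan = True, alternate the two branches during training:
    #   discr_loss = vae(images, return_discr_loss = True)   # update the discriminator
    #   gen_loss   = vae(images, return_loss = True)          # update encoder / decoder / quantizer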