Lucidrains-系列项目源码解析-四十二-

156 阅读13分钟

Lucidrains 系列项目源码解析(四十二)

.\lucidrains\gigagan-pytorch\gigagan_pytorch\open_clip.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
import open_clip

from einops import rearrange

from beartype import beartype
from beartype.typing import List, Optional

# predicate: is the value something other than None?
def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None

# unit-normalize along the last dimension
def l2norm(t):
    """L2-normalize tensor ``t`` along its final dimension."""
    return F.normalize(t, p = 2, dim = -1)

# adapter exposing an OpenCLIP model's text / image embeddings and a CLIP-style contrastive loss
class OpenClipAdapter(nn.Module):
    @beartype
    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32',
        tokenizer_name = 'ViT-B-32-quickgelu',
        eos_id = 49407
    ):
        super().__init__()

        # instantiate the OpenCLIP model, its preprocessing pipeline, and the tokenizer
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
        tokenizer = open_clip.get_tokenizer(tokenizer_name)

        self.clip = clip
        self.tokenizer = tokenizer
        self.eos_id = eos_id

        # hook to capture the final text token encodings (output of ln_final)
        text_attention_final = self.find_layer('ln_final')
        self._dim_latent = text_attention_final.weight.shape[0]
        self.text_handle = text_attention_final.register_forward_hook(self._text_hook)

        # hooks to capture the encodings of every visual transformer block
        self._dim_image_latent = self.find_layer('visual.ln_post').weight.shape[0]

        num_visual_layers = len(clip.visual.transformer.resblocks)
        self.image_handles = []

        for visual_layer in range(num_visual_layers):
            image_attention_final = self.find_layer(f'visual.transformer.resblocks.{visual_layer}')

            handle = image_attention_final.register_forward_hook(self._image_hook)
            self.image_handles.append(handle)

        # the last transform in the CLIP preprocessing pipeline is the normalization
        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    @property
    def device(self):
        return next(self.parameters()).device

    def find_layer(self, layer):
        """Look up a submodule by dotted name, returning None when absent."""
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    def clear(self):
        """Detach all registered forward hooks (idempotent).

        fix: hook handles must be removed via .remove() - they are not
        callable; the image hooks live in the ``image_handles`` list (a
        singular ``image_handle`` attribute never existed); and ``cleared``
        was never set, so the early-return guard was ineffective.
        """
        if self.cleared:
            return

        self.text_handle.remove()

        for handle in self.image_handles:
            handle.remove()

        self.cleared = True

    def _text_hook(self, _, inputs, outputs):
        # stash the text encodings captured by the forward hook
        self.text_encodings = outputs

    def _image_hook(self, _, inputs, outputs):
        # accumulate per-layer visual encodings captured by the forward hooks
        if not hasattr(self, 'image_encodings'):
            self.image_encodings = []

        self.image_encodings.append(outputs)

    @property
    def dim_latent(self):
        return self._dim_latent

    @property
    def image_size(self):
        # image_size may be an int or an (h, w) tuple depending on the model
        image_size = self.clip.visual.image_size
        if isinstance(image_size, tuple):
            return max(image_size)
        return image_size

    @property
    def image_channels(self):
        return 3

    @property
    def max_text_len(self):
        return self.clip.positional_embedding.shape[0]

    @beartype
    def embed_texts(
        self,
        texts: List[str]
    ):
        """Return (l2-normalized text embeddings, masked per-token encodings)."""
        ids = self.tokenizer(texts)
        ids = ids.to(self.device)
        ids = ids[..., :self.max_text_len]

        # mask covers tokens up to and including EOS; padding (id 0) is excluded
        is_eos_id = (ids == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (ids != 0)
        assert not self.cleared

        text_embed = self.clip.encode_text(ids)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return l2norm(text_embed.float()), text_encodings.float()

    def embed_images(self, images):
        """Return (l2-normalized image embeddings, stacked per-layer visual encodings)."""
        if images.shape[-1] != self.image_size:
            images = F.interpolate(images, self.image_size)

        assert not self.cleared
        images = self.clip_normalize(images)
        image_embeds = self.clip.encode_image(images)

        # stack the per-layer hook outputs and move the batch dimension first
        image_encodings = rearrange(self.image_encodings, 'l n b d -> l b n d')
        del self.image_encodings

        return l2norm(image_embeds.float()), image_encodings.float()

    @beartype
    def contrastive_loss(
        self,
        images,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[torch.Tensor] = None
    ):
        """Symmetric InfoNCE loss between image and text embeddings.

        Exactly one of ``texts`` / ``text_embeds`` must be provided.
        """
        assert exists(texts) ^ exists(text_embeds)

        if not exists(text_embeds):
            text_embeds, _ = self.embed_texts(texts)

        image_embeds, _ = self.embed_images(images)

        n = text_embeds.shape[0]

        # learned temperature from the CLIP model
        temperature = self.clip.logit_scale.exp()
        sim = einsum('i d, j d -> i j', text_embeds, image_embeds) * temperature

        labels = torch.arange(n, device = sim.device)

        # average of the text->image and image->text cross entropies
        return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2

.\lucidrains\gigagan-pytorch\gigagan_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# split parameters into (weight-decayable, non-decayable) lists
def separate_weight_decayable_params(params):
    """Partition parameters by whether they should receive weight decay.

    Tensors with fewer than 2 dimensions (biases, norm gains) are excluded
    from decay; everything else (weight matrices, conv kernels) gets it.
    Single pass, so generator inputs are handled too.
    """
    buckets = ([], [])  # index 0: decayable, index 1: not decayable
    for param in params:
        buckets[param.ndim < 2].append(param)
    return buckets[0], buckets[1]

# construct an Adam/AdamW optimizer from the given settings
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = True,
    group_wd_params = True,
    **kwargs
):
    """Build an optimizer: Adam when ``wd == 0``, otherwise AdamW.

    When ``group_wd_params`` is set and decay is active, parameters with
    fewer than 2 dims (biases, norm gains) are placed in a zero-decay group.

    fix: ``**kwargs`` were accepted but silently dropped; they are now
    forwarded to the underlying optimizer constructor.
    """
    # keep only parameters that require gradients, if requested
    if filter_by_requires_grad:
        params = [p for p in params if p.requires_grad]

    # split into decayable / non-decayable parameter groups
    if group_wd_params and wd > 0:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # no decay at all -> plain Adam
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps, **kwargs)

    # decoupled weight decay -> AdamW
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps, **kwargs)

.\lucidrains\gigagan-pytorch\gigagan_pytorch\unet_upsampler.py

# 从 math 模块中导入 log2 函数
from math import log2
# 从 functools 模块中导入 partial 函数
from functools import partial

# 导入 torch 库
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 和 repeat 函数,以及 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 从 gigagan_pytorch 模块中导入各个自定义类和函数
from gigagan_pytorch.attend import Attend
from gigagan_pytorch.gigagan_pytorch import (
    BaseGenerator,
    StyleNetwork,
    AdaptiveConv2DMod,
    TextEncoder,
    CrossAttentionBlock,
    Upsample
)

# 从 beartype 库中导入 beartype 函数和相关类型注解
from beartype import beartype
from beartype.typing import Optional, List, Union, Dict, Iterable

# 辅助函数

# presence check: any value other than None counts as present
def exists(x):
    """Return True when ``x`` is not None."""
    return x is not None

# fall back to a default when the value is absent
def default(val, d):
    """Return ``val`` when present; otherwise the default ``d``.

    A callable default is invoked lazily so expensive defaults are only
    computed when actually needed.
    """
    if val is not None:
        return val
    return d() if callable(d) else d

# coerce a value into a tuple of the requested length
def cast_tuple(t, length = 1):
    """Return ``t`` unchanged when it is already a tuple, else a tuple
    repeating ``t`` ``length`` times."""
    return t if isinstance(t, tuple) else (t,) * length

# no-op transform: hand back the first argument untouched
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged, ignoring all extra positional / keyword arguments."""
    return t

# test whether a number is a positive power of two
def is_power_of_two(n):
    """Return True when ``n`` is a positive power of two.

    fix: math.log2 raises ValueError for zero or negative input; guard
    those and report False instead of crashing.
    """
    if n <= 0:
        return False
    return log2(n).is_integer()

# endless stream of None values (stands in for "no modulation supplied")
def null_iterator():
    """Yield ``None`` forever; used as a default for modulation iterators."""
    nothing = None
    while True:
        yield nothing

# 小型辅助模块

# 2x spatial upsampling via 1x1 conv to 4x channels followed by pixel shuffle
class PixelShuffleUpsample(nn.Module):
    """Doubles spatial resolution with a pointwise conv + PixelShuffle(2).

    The conv weight is initialized so all four shuffled positions start from
    the same kaiming-initialized kernel, avoiding checkerboard artifacts.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)

        # pointwise conv expanding channels 4x for the shuffle
        proj = nn.Conv2d(dim, dim_out * 4, 1)
        self.init_conv_(proj)

        self.net = nn.Sequential(
            proj,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

    def init_conv_(self, conv):
        # initialize a quarter-sized weight, then tile it across the 4 shuffle groups
        out_channels, *kernel_shape = conv.weight.shape
        base_weight = torch.empty(out_channels // 4, *kernel_shape)
        nn.init.kaiming_uniform_(base_weight)
        tiled = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

# 2x spatial downsample: space-to-depth rearrange, then a 1x1 conv
def Downsample(dim, dim_out = None):
    """Halve spatial resolution by folding each 2x2 patch into channels
    (4x channel expansion) and projecting back with a pointwise conv."""
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

# channel-wise RMS normalization with a learnable per-channel gain
class RMSNorm(nn.Module):
    """Normalize over the channel dimension to unit RMS, then scale by a
    learned gain; the sqrt(dim) factor restores the expected magnitude."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        scale = x.shape[1] ** 0.5
        normed = F.normalize(x, dim = 1)
        return normed * self.g * scale

# 构建块模块

# adaptive-conv -> GroupNorm -> SiLU building block
class Block(nn.Module):
    """Single conv block whose convolution is modulated by style vectors.

    Each forward pass consumes two entries from ``conv_mods_iter``: a
    feature modulation and a kernel-selection modulation.
    """

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        num_conv_kernels = 0
    ):
        super().__init__()
        self.proj = AdaptiveConv2DMod(dim, dim_out, kernel = 3, num_conv_kernels = num_conv_kernels)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x,
        conv_mods_iter: Optional[Iterable] = None
    ):
        # fall back to an endless None iterator when no modulations are given
        mods = default(conv_mods_iter, null_iterator())

        out = self.proj(
            x,
            mod = next(mods),
            kernel_mod = next(mods)
        )

        return self.act(self.norm(out))

# resnet block built from two adaptive-conv Blocks plus a residual projection
class ResnetBlock(nn.Module):
    """Two adaptive-conv ``Block``s with a 1x1 residual projection.

    On construction, appends the four modulation dimensions its blocks will
    consume to ``style_dims``, so the caller can size the style projection.

    fix: ``style_dims`` previously defaulted to a mutable ``[]`` that is
    shared by every instantiation omitting the argument, silently growing
    forever; a None sentinel with a fresh list per call removes that.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        groups = 8,
        num_conv_kernels = 0,
        style_dims: Optional[List] = None
    ):
        super().__init__()

        if style_dims is None:
            style_dims = []

        # each Block consumes (feature mod, kernel mod) - record their dims for the caller
        style_dims.extend([
            dim,
            num_conv_kernels,
            dim_out,
            num_conv_kernels
        ])

        self.block1 = Block(dim, dim_out, groups = groups, num_conv_kernels = num_conv_kernels)
        self.block2 = Block(dim_out, dim_out, groups = groups, num_conv_kernels = num_conv_kernels)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(
        self,
        x,
        conv_mods_iter: Optional[Iterable] = None
    ):
        h = self.block1(x, conv_mods_iter = conv_mods_iter)
        h = self.block2(h, conv_mods_iter = conv_mods_iter)

        return h + self.res_conv(x)

# linear attention over spatial positions (O(n) in sequence length)
class LinearAttention(nn.Module):
    """Linear (kernelized) attention for feature maps.

    fix: the excerpt contained a duplicated, syntactically invalid
    ``def __init__(`` header splitting the parameter list from the body;
    reconstructed into a single valid constructor.
    """

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # pre-normalization
        self.norm = RMSNorm(dim)
        # single conv producing queries, keys and values
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        # output projection followed by a post-norm
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.norm(x)

        # split into per-head q, k, v with flattened spatial axis
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # softmax q over feature dim, k over sequence dim (linear attention trick)
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        # aggregate values into a per-head context matrix
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        # read the context out with the queries
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
class Attention(nn.Module):
    """Full softmax self-attention over the spatial positions of a feature map."""

    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        # pre-normalization
        self.norm = RMSNorm(dim)
        # delegated attention kernel (optionally flash attention)
        self.attend = Attend(flash = flash)

        # project the feature map to queries / keys / values in one conv
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        # merge heads back to the input dimension
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape

        normed = self.norm(x)

        # per-head q, k, v with space flattened into a sequence
        q, k, v = (
            rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads)
            for t in self.to_qkv(normed).chunk(3, dim = 1)
        )

        attended = self.attend(q, k, v)

        # restore the spatial layout and merge heads
        attended = rearrange(attended, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(attended)

# feedforward

def FeedForward(dim, mult = 4):
    """Pointwise feedforward: RMSNorm -> 1x1 expand -> GELU -> 1x1 project."""
    inner_dim = dim * mult
    return nn.Sequential(
        RMSNorm(dim),
        nn.Conv2d(dim, inner_dim, 1),
        nn.GELU(),
        nn.Conv2d(inner_dim, dim, 1)
    )

# transformers

class Transformer(nn.Module):
    """Stack of residual (self-attention, feedforward) layers for feature maps."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 1,
        flash_attn = True,
        ff_mult = 4
    ):
        super().__init__()

        # one (attention, feedforward) pair per depth level
        layers = [
            nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash_attn),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # pre-norm residual updates, attention then feedforward
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)

        return x

class LinearTransformer(nn.Module):
    """Stack of residual (linear-attention, feedforward) layers for feature maps."""

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 1,
        ff_mult = 4
    ):
        super().__init__()

        # one (linear attention, feedforward) pair per depth level
        layers = [
            nn.ModuleList([
                LinearAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # pre-norm residual updates, linear attention then feedforward
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)

        return x

# model

class UnetUpsampler(BaseGenerator):
    """U-Net based GAN upsampler: maps a low-resolution image to a higher
    resolution one, optionally also returning intermediate multi-scale rgbs.

    NOTE(review): the __init__ body below only calls super().__init__() -
    the attribute setup used by forward() (self.downs, self.ups,
    self.init_conv, self.style_network, ...) appears to have been truncated
    from this excerpt; confirm against the upstream source before relying
    on this class.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        image_size,
        input_image_size,
        init_dim = None,
        out_dim = None,
        text_encoder: Optional[Union[TextEncoder, Dict]] = None,
        style_network: Optional[Union[StyleNetwork, Dict]] = None,
        style_network_dim = None,
        dim_mults = (1, 2, 4, 8, 16),
        channels = 3,
        resnet_block_groups = 8,
        full_attn = (False, False, False, True, True),
        cross_attn = (False, False, False, True, True),
        flash_attn = True,
        self_attn_dim_head = 64,
        self_attn_heads = 8,
        self_attn_dot_product = True,
        self_attn_ff_mult = 4,
        attn_depths = (1, 1, 1, 1, 1),
        cross_attn_dim_head = 64,
        cross_attn_heads = 8,
        cross_ff_mult = 4,
        mid_attn_depth = 1,
        num_conv_kernels = 2,
        resize_mode = 'bilinear',
        unconditional = True,
        skip_connect_scale = None
    ):
        # initialize the base generator (body appears truncated - see class docstring)
        super().__init__()

    @property
    def allowable_rgb_resolutions(self):
        # powers of two strictly between the input and output resolutions
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2 ** p, allowed_rgb_res_base)]

    @property
    def device(self):
        # device of the module's parameters
        return next(self.parameters()).device

    @property
    def total_params(self):
        # total number of parameters in the model
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        # resize a square feature map / image to (size, size) with the configured mode
        return F.interpolate(x, (size, size), mode = self.resize_mode)

    # forward pass: low-resolution image (+ optional style / noise / text
    # conditioning) -> upsampled rgb image, optionally with all multi-scale rgbs
    def forward(
        self,
        lowres_image,
        styles = None,
        noise = None,
        texts: Optional[List[str]] = None,
        global_text_tokens = None,
        fine_text_tokens = None,
        text_mask = None,
        return_all_rgbs = False,
        replace_rgb_with_input_lowres_image = True   # discriminator should also receive the low resolution image the upsampler sees
    ):
        # the low-resolution input is the starting feature map
        x = lowres_image
        shape = x.shape
        batch_size = shape[0]

        # the input must be square at the configured input resolution
        assert shape[-2:] == ((self.input_image_size,) * 2)

        # handle text conditioning:
        # global text tokens adaptively select conv kernels in the main contribution,
        # fine text tokens are used for cross attention
        if not self.unconditional:
            if exists(texts):
                assert exists(self.text_encoder)
                global_text_tokens, fine_text_tokens, text_mask = self.text_encoder(texts)
            else:
                assert all([*map(exists, (global_text_tokens, fine_text_tokens, text_mask))])
        else:
            assert not any([*map(exists, (texts, global_text_tokens, fine_text_tokens))])

        # styles: derive from the style network when not given explicitly
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(noise, torch.randn((batch_size, self.style_network.dim), device = self.device))
            styles = self.style_network(noise, global_text_tokens)

        # project styles into per-block conv modulations, consumed lazily via an iterator
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim = -1)
        conv_mods = iter(conv_mods)

        # initial convolution
        x = self.init_conv(x)

        h = []

        # downsampling stages; two skip connections pushed per stage
        for block1, block2, cross_attn, attn, downsample in self.downs:
            x = block1(x, conv_mods_iter = conv_mods)
            h.append(x)

            x = block2(x, conv_mods_iter = conv_mods)

            x = attn(x)

            # cross attend to fine text tokens when this stage is text-conditioned
            if exists(cross_attn):
                x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

            h.append(x)

            x = downsample(x)

        # bottleneck: block -> attention -> block
        x = self.mid_block1(x, conv_mods_iter = conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter = conv_mods)

        # accumulate per-scale rgb outputs
        rgbs = []

        init_rgb_shape = list(x.shape)
        init_rgb_shape[1] = self.channels

        rgb = self.mid_to_rgb(x)
        rgbs.append(rgb)

        # upsampling stages; pop the two matching skip connections per stage
        for upsample, upsample_rgb, to_rgb, block1, block2, cross_attn, attn in self.ups:

            x = upsample(x)
            rgb = upsample_rgb(rgb)

            res1 = h.pop() * self.skip_connect_scale
            res2 = h.pop() * self.skip_connect_scale

            # resize skips when their spatial size does not match the upsampled map
            fmap_size = x.shape[-1]
            residual_fmap_size = res1.shape[-1]

            if residual_fmap_size != fmap_size:
                res1 = self.resize_image_to(res1, fmap_size)
                res2 = self.resize_image_to(res2, fmap_size)

            x = torch.cat((x, res1), dim = 1)
            x = block1(x, conv_mods_iter = conv_mods)

            x = torch.cat((x, res2), dim = 1)
            x = block2(x, conv_mods_iter = conv_mods)

            if exists(cross_attn):
                x = cross_attn(x, context = fine_text_tokens, mask = text_mask)

            x = attn(x)

            # residual rgb accumulation at this scale
            rgb = rgb + to_rgb(x)
            rgbs.append(rgb)

        x = self.final_res_block(x, conv_mods_iter = conv_mods)

        # every style modulation must have been consumed exactly once
        assert len([*conv_mods]) == 0

        rgb = rgb + self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        # keep only the rgbs whose feature maps are larger than the input image
        rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs))

        # and return the original input image as the smallest rgb
        rgbs = [lowres_image, *rgbs]

        return rgb, rgbs

.\lucidrains\gigagan-pytorch\gigagan_pytorch\version.py

# semantic version of the gigagan-pytorch package (read by setup.py via exec)
__version__ = '0.2.20'

.\lucidrains\gigagan-pytorch\gigagan_pytorch\__init__.py

# 从 gigagan_pytorch 模块中导入 GigaGAN 相关类
from gigagan_pytorch.gigagan_pytorch import (
    GigaGAN,
    Generator,
    Discriminator,
    VisionAidedDiscriminator,
    AdaptiveConv2DMod,
    StyleNetwork,
    TextEncoder
)

# 从 gigagan_pytorch 模块中导入 UnetUpsampler 类
from gigagan_pytorch.unet_upsampler import UnetUpsampler

# 从 gigagan_pytorch 模块中导入数据相关类
from gigagan_pytorch.data import (
    ImageDataset,
    TextImageDataset,
    MockTextImageDataset
)

# names exported via `from gigagan_pytorch import *`
# fix: __all__ must contain strings, not the objects themselves -
# non-string entries make star imports raise a TypeError
__all__ = [
    'GigaGAN',
    'Generator',
    'Discriminator',
    'VisionAidedDiscriminator',
    'AdaptiveConv2DMod',
    'StyleNetwork',
    'UnetUpsampler',
    'TextEncoder',
    'ImageDataset',
    'TextImageDataset',
    'MockTextImageDataset'
]

GigaGAN - Pytorch

Implementation of GigaGAN (project page), new SOTA GAN out of Adobe.

I will also add a few findings from lightweight gan, for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in discriminator)

It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.

Please join us on Discord if you are interested in helping out with the replication with the LAION community

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for their accelerate library

  • All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models

  • Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!

  • @CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler!

  • Keerth for the code review and pointing out some discrepancies with the paper!

Install

$ pip install gigagan-pytorch

Usage

Simple unconditional GAN, for starters

import torch

from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    generator = dict(
        dim_capacity = 8,
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        image_size = 256,
        dim_max = 512,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    amp = True
).cuda()

# dataset

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

# you must then set the dataloader for the GAN before training

gan.set_dataloader(dataloader)

# training the discriminator and generator alternating
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

images = gan.generate(batch_size = 4) # (4, 3, 256, 256)

For unconditional Unet Upsampler

import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    train_upsampler = True,     # set this to True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

gan.set_dataloader(dataloader)

# training the discriminator and generator alternating
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

lowres = torch.randn(1, 3, 64, 64).cuda()

images = gan.generate(lowres) # (1, 3, 256, 256)

Losses

  • G - Generator
  • MSG - Multiscale Generator
  • D - Discriminator
  • MSD - Multiscale Discriminator
  • GP - Gradient Penalty
  • SSL - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
  • VD - Vision-aided Discriminator
  • VG - Vision-aided Generator
  • CL - Generator Contrastive Loss
  • MAL - Matching Aware Loss

A healthy run would have G, MSG, D, MSD with values hovering between 0 to 10, and usually staying pretty constant. If at any time after 1k training steps these values persist at triple digits, that would mean something is wrong. It is ok for generator and discriminator values to occasionally dip negative, but it should swing back up to the range above.

GP and SSL should be pushed towards 0. GP can occasionally spike; I like to imagine it as the networks undergoing some epiphany

Multi-GPU Training

The GigaGAN class is now equipped with 🤗 Accelerator. You can easily do multi-gpu training in two steps using their accelerate CLI

At the project root directory, where the training script is, run

$ accelerate config

Then, in the same directory

$ accelerate launch train.py

Todo

  • make sure it can be trained unconditionally

  • read the relevant papers and knock out all 3 auxiliary losses

    • matching aware loss
    • clip loss
    • vision-aided discriminator loss
    • add reconstruction losses on arbitrary stages in the discriminator (lightweight gan)
    • figure out how the random projections are used from projected-gan
    • vision aided discriminator needs to extract N layers from the vision model in CLIP
    • figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case
  • unet upsampler

    • add adaptive conv
    • modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in
    • do pixel shuffle upsamples for unet
  • get a code review for the multi-scale inputs and outputs, as the paper was a bit vague

  • add upsampling network architecture

  • make unconditional work for both base generator and upsampler

  • make text conditioned training work for both base and upsampler

  • make recon more efficient by random sampling patches

  • make sure generator and discriminator can also accept pre-encoded CLIP text encodings

  • do a review of the auxiliary losses

    • add contrastive loss for generator
    • add vision aided loss
    • add gradient penalty for vision aided discr - make optional
    • add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader)
    • make sure gradient accumulation works with matching aware loss
    • matching awareness loss runs and is stable
    • vision aided trains
  • add some differentiable augmentations, proven technique from the old GAN days

    • remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales
    • add horizontal flip for starters
  • move all modulation projections into the adaptive conv2d class

  • add accelerate

    • works single machine
    • works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch
    • make sure it works multi-GPU for one machine
    • have someone else try multiple machines
  • clip should be optional for all modules, and managed by GigaGAN, with text -> text embeds processed once

  • add ability to select a random subset from multiscale dimension, for efficiency

  • port over CLI from lightweight|stylegan2-pytorch

  • hook up laion dataset for text-image

Citations

@misc{https://doi.org/10.48550/arxiv.2303.05511,
    url     = {https://arxiv.org/abs/2303.05511},
    author  = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},  
    title   = {Scaling up GANs for Text-to-Image Synthesis},
    publisher = {arXiv},
    year    = {2023},
    copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{Liu2021TowardsFA,
    title   = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
    author  = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2101.04775}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Karras2020ada,
    title     = {Training Generative Adversarial Networks with Limited Data},
    author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    booktitle = {Proc. NeurIPS},
    year      = {2020}
}

.\lucidrains\gigagan-pytorch\setup.py

# packaging script for gigagan-pytorch
from setuptools import setup, find_packages

# execute the version file so __version__ is defined in this namespace
exec(open('gigagan_pytorch/version.py').read())

# package metadata
setup(
  name = 'gigagan-pytorch',
  packages = find_packages(exclude=[]),
  version = __version__,
  license='MIT',
  description = 'GigaGAN - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  # fix: URL previously pointed at the unrelated ETSformer-pytorch repository
  url = 'https://github.com/lucidrains/gigagan-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'generative adversarial networks'
  ],
  install_requires=[
    'accelerate',
    'beartype',
    'einops>=0.6',
    'ema-pytorch',
    'kornia',
    'numerize',
    'open-clip-torch>=2.0.0,<3.0.0',
    'pillow',
    'torch>=1.6',
    'torchvision',
    'tqdm'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\global-self-attention-network\gsa_pytorch\gsa_pytorch.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange 函数
from einops import rearrange
# 从 inspect 中导入 isfunction 函数

# 辅助函数

# return val if present, else the default d (called lazily when callable)
def default(val, d):
    """Return ``val`` when not None, otherwise the default ``d``.

    fix: the original used ``isfunction``, whose ``from inspect import
    isfunction`` line is missing from this file (NameError at runtime),
    and which rejects non-function callables; use ``callable`` instead,
    consistent with the ``default`` helper in unet_upsampler.py.
    """
    if val is not None:
        return val
    return d() if callable(d) else d

# presence check: anything other than None counts as existing
def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None

# build the relative-position reindexing mask
def calc_reindexing_tensor(l, L, device):
    """
    Appendix B - (5)

    Returns a float tensor of shape (l, l, 2L - 1) whose entry (x, i, r)
    is 1 when position difference i - x equals relative shift r.
    """
    positions = torch.arange(l, device = device)
    shifts = torch.arange(-(L - 1), L, device = device)

    x = positions[:, None, None]
    i = positions[None, :, None]
    r = shifts[None, None, :]

    diff = i - x
    mask = (diff == r) & (diff.abs() <= L)
    return mask.float()

# 类

# GSA (Global Self-Attention) layer over 2d feature maps
class GSA(nn.Module):
    # dim: input channels; rel_pos_length: max relative distance for axial relative-position attention (None disables it)
    # dim_out: output channels (defaults to dim); dim_key: per-head key dimension
    # norm_queries: softmax-normalize queries over channels; batch_norm: normalize row-attention output
    def __init__(self, dim, *, rel_pos_length = None, dim_out = None, heads = 8, dim_key = 64, norm_queries = False, batch_norm = True):
        super().__init__()
        dim_out = default(dim_out, dim)
        dim_hidden = dim_key * heads

        self.heads = heads
        self.dim_out = dim_out
        self.rel_pos_length = rel_pos_length
        self.norm_queries = norm_queries

        # single 1x1 convolution producing queries, keys and values in one pass
        self.to_qkv = nn.Conv2d(dim, dim_hidden * 3, 1, bias = False)
        # 1x1 convolution projecting the attention output to dim_out channels
        self.to_out = nn.Conv2d(dim_hidden, dim_out, 1)

        self.rel_pos_length = rel_pos_length  # NOTE(review): already assigned above; redundant but harmless
        if exists(rel_pos_length):
            num_rel_shifts = 2 * rel_pos_length - 1
            self.norm = nn.BatchNorm2d(dim_key) if batch_norm else None
            # learned relative-position embeddings, one table per axis (rows / columns)
            self.rel_rows = nn.Parameter(torch.randn(num_rel_shifts, dim_key))
            self.rel_columns = nn.Parameter(torch.randn(num_rel_shifts, dim_key))

    # img: (batch, dim, height, width) -> (batch, dim_out, height, width)
    def forward(self, img):
        # unpack input shape alongside module hyperparameters
        b, c, x, y, h, c_out, L, device = *img.shape, self.heads, self.dim_out, self.rel_pos_length, img.device

        # project to q, k, v and fold heads into the batch dimension
        qkv = self.to_qkv(img).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c (x y)', h = h), qkv)

        # linear attention: softmax keys over spatial positions
        k = k.softmax(dim = -1)
        # key-value outer-product "context"
        context = einsum('ndm,nem->nde', k, v)

        # optionally softmax-normalize queries over the channel dimension
        content_q = q if not self.norm_queries else q.softmax(dim=-2)

        # content output = context applied to queries, reshaped back to a feature map
        content_out = einsum('nde,ndm->nem', context, content_q)
        content_out = rearrange(content_out, 'n d (x y) -> n d x y', x = x, y = y)

        # axial relative-position attention, following Appendix B (6) - (8)
        if exists(self.rel_pos_length):
            q, v = map(lambda t: rearrange(t, 'n c (x y) -> n c x y', x = x, y = y), (q, v))

            # rows first: reindex embeddings, score against queries, aggregate values
            Ix = calc_reindexing_tensor(x, L, device)
            Px = einsum('xir,rd->xid', Ix, self.rel_rows)
            Sx = einsum('ndxy,xid->nixy', q, Px)
            Yh = einsum('nixy,neiy->nexy', Sx, v)

            if exists(self.norm):
                Yh = self.norm(Yh)

            # then columns, applied to the row-attended values
            Iy = calc_reindexing_tensor(y, L, device)
            Py = einsum('yir,rd->yid', Iy, self.rel_columns)
            Sy = einsum('ndxy,yid->nixy', q, Py)
            rel_pos_out = einsum('nixy,nexi->nexy', Sy, Yh)

            content_out = content_out + rel_pos_out.contiguous()

        # merge heads back into channels and project to the output dimension
        content_out = rearrange(content_out, '(b h) c x y -> b (h c) x y', h = h)
        return self.to_out(content_out)

.\lucidrains\global-self-attention-network\gsa_pytorch\__init__.py

# 从 gsa_pytorch 模块中导入 GSA 类
from gsa_pytorch.gsa_pytorch import GSA

Global Self-attention Network

An implementation of Global Self-Attention Network, which proposes an all-attention vision backbone that achieves better results than convolutions with less parameters and compute.

They use a previously discovered linear attention variant with a small modification for further gains (no normalization of the queries), paired with relative positional attention, computed axially for efficiency.

The result is an extremely simple circuit composed of 8 einsums, 1 softmax, and normalization.

Install

$ pip install gsa-pytorch

Usage

import torch
from gsa_pytorch import GSA

gsa = GSA(
    dim = 3,
    dim_out = 64,
    dim_key = 32,
    heads = 8,
    rel_pos_length = 256  # in paper, set to max(height, width). you can also turn this off by omitting this line
)

x = torch.randn(1, 3, 256, 256)
gsa(x) # (1, 64, 256, 256)

Citations

@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}

.\lucidrains\global-self-attention-network\setup.py

# Import setup utilities and package discovery helper.
from setuptools import setup, find_packages

# Package metadata for gsa-pytorch.
setup(
  name = 'gsa-pytorch', # package name
  packages = find_packages(), # discover all packages
  version = '0.2.2', # version number
  license='MIT', # license
  description = 'Global Self-attention Network (GSA) - Pytorch', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/global-self-attention-network', # project URL
  keywords = [
    'artificial intelligence', # keyword: artificial intelligence
    'attention mechanism', # keyword: attention mechanism
    'image recognition' # keyword: image recognition
  ],
  install_requires=[
    'torch>=1.6', # requires torch >= 1.6
    'einops>=0.3' # requires einops >= 0.3
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # classifier: Beta development status
    'Intended Audience :: Developers', # classifier: intended for developers
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # classifier: AI topic
    'License :: OSI Approved :: MIT License', # classifier: MIT license
    'Programming Language :: Python :: 3.6', # classifier: Python 3.6
  ],
)

.\lucidrains\glom-pytorch\glom_pytorch\glom_pytorch.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 和 functional 模块
import torch.nn.functional as F
# 从 torch 模块中导入 einsum 函数
from torch import nn, einsum
# 从 einops 模块中导入 rearrange 和 repeat 函数,以及 torch 模块中的 Rearrange 类
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# constants

# small negative value used to softly discourage a token from attending to itself
TOKEN_ATTEND_SELF_VALUE = -5e-4

# helper functions

# a value "exists" when it is anything other than None
def exists(val):
    return val is not None

# fall back to the default `d` when `val` is absent
def default(val, d):
    if exists(val):
        return val
    return d

# 类定义

# Position-wise feedforward applied to all levels at once via grouped 1d convolutions.
class GroupedFeedForward(nn.Module):
    """Project each level up by `mult`, apply GELU, project back down.

    One convolution group per level keeps the levels independent.
    """

    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()
        inner_dim = dim * groups

        # fold (level, dim) into channels, run grouped conv stack, unfold again
        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(inner_dim, inner_dim * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(inner_dim * mult, inner_dim, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups)
        )

    def forward(self, levels):
        return self.net(levels)

# Consensus attention across columns (patches), computed independently per level.
class ConsensusAttention(nn.Module):
    # num_patches_side: patches along each image side; attend_self: whether a column may attend to itself
    # local_consensus_radius: if > 0, restrict attention to patches within this euclidean radius
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        # precompute the mask of patch pairs farther apart than the consensus radius
        if self.local_consensus_radius > 0:
            # (2, h, w) grid of patch coordinates
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)
            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            self.register_buffer('non_local_mask', mask_non_local)

    # levels: (batch, patches, levels, dim) -> attended tensor of the same shape
    def forward(self, levels):
        _, n, _, d, device = *levels.shape, levels.device
        # queries/values are the raw levels, keys are l2-normalized levels
        # (v is identical to levels; the aggregation below uses levels directly)
        q, k, v = levels, F.normalize(levels, dim = -1), levels

        # column-to-column similarity, computed per level, scaled by 1/sqrt(d)
        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            # softly discourage attending to one's own column with a small negative value
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            # fully mask out patches beyond the local consensus radius
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, levels)
        return out

# 主类定义

# Main GLOM module: columns of level embeddings updated by bottom-up, top-down and consensus signals.
class Glom(nn.Module):
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0
    ):
        super().__init__()
        # number of patches along each image side
        num_patches_side = (image_size // patch_size)
        num_patches =  num_patches_side ** 2
        self.levels = levels

        # patchify the image and linearly project each (assumed 3-channel) patch to `dim`
        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * 3, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # learned initial embedding for every level of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down grouped feedforwards
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention across columns at the same level
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)
    # img: (batch, 3, image_size, image_size)
    # iters: number of propagation iterations (defaults to 2 * levels)
    # levels: optionally resume from previous level states
    # return_all: also return the states from every iteration (leading time axis)
    def forward(self, img, iters = None, levels = None, return_all = False):
        b, device = img.shape[0], img.device
        # default to twice the number of levels, enough for information to propagate both up and down
        iters = default(iters, self.levels * 2)

        # convert the image to patch tokens
        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        # positional embeddings, broadcast over batch and level axes
        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        # the bottom "level" is the raw token, given a singleton level axis
        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        # start from the learned initial level embeddings unless resuming
        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        # collect hidden states after every iteration (including the initial state)
        hiddens = [levels]

        # each level normally averages 4 contributions (previous state, bottom-up, top-down, consensus)
        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3  # the top level gets no top-down contribution, so its weighted average uses 3

        # iterative propagation
        for _ in range(iters):
            # prepend the raw input below the bottom level, for the bottom-up pass
            levels_with_input = torch.cat((bottom_level, levels), dim = -2)

            # bottom-up: each level is updated from the level below it
            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            # top-down: each level is updated from the level above it, with positional embeddings added;
            # pad so the (contribution-less) top level lines up
            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs)
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            # consensus across columns
            consensus = self.attention(levels)

            # weighted average of all contributions
            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0)
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            # feed the new state into the next iteration
            levels = levels_mean
            hiddens.append(levels)

        # optionally return states from all iterations, stacked on a leading time axis
        if return_all:
            return torch.stack(hiddens)

        # otherwise return only the final level states
        return levels

.\lucidrains\glom-pytorch\glom_pytorch\__init__.py

# 从 glom_pytorch 模块中导入 Glom 类
from glom_pytorch.glom_pytorch import Glom

GLOM - Pytorch

An implementation of Glom, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.

Yannic Kilcher's video was instrumental in helping me to understand this paper

Install

$ pip install glom-pytorch

Usage

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)

Pass the return_all = True keyword argument on forward, and you will be returned all the column and level states per iteration, (including the initial state, number of iterations + 1). You can then use this to attach any losses to any level outputs at any time step.

It also gives you access to all the level data across iterations for clustering, from which one can inspect for the theorized islands in the paper.

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)

Denoising self-supervised learning for encouraging emergence, as described by Hinton

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(img, recon_img)
loss.backward()

You can pass in the state of the column and levels back into the model to continue where you left off (perhaps if you are processing consecutive frames of a slow video, as mentioned in the paper)

import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations

Appreciation

Thanks goes out to Cfoster0 for reviewing the code

Todo

  • contrastive / consistency regularization of top-ish levels

Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\glom-pytorch\setup.py

# Import setup utilities and package discovery helper.
from setuptools import setup, find_packages

# Package metadata for glom-pytorch.
setup(
  name = 'glom-pytorch', # package name
  packages = find_packages(), # discover all packages
  version = '0.0.14', # version number
  license='MIT', # license
  description = 'Glom - Pytorch', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/glom-pytorch', # project URL
  keywords = [
    'artificial intelligence', # keywords
    'deep learning'
  ],
  install_requires=[
    'einops>=0.3', # required dependencies
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # trove classifiers
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\gradnorm_pytorch.py

# 导入必要的库
from functools import cache, partial
import torch
import torch.distributed as dist
from torch.autograd import grad
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList, Parameter
from einops import rearrange, repeat
from accelerate import Accelerator
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Union, List, Dict, Tuple, NamedTuple

# helper functions

# True when the value has been provided (is not None)
def exists(v):
    return v is not None

# fall back to the default `d` when `v` is missing
def default(v, d):
    if exists(v):
        return v
    return d

# 张量辅助函数

# normalize `t` along `dim` so that the absolute values sum to 1 (L1 normalization)
def l1norm(t, dim = -1):
    normed = F.normalize(t, p = 1, dim = dim)
    return normed

# distributed helpers

# cached check: True only when torch.distributed is initialized with more than one process
@cache
def is_distributed():
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1

# average a tensor across all distributed workers (no-op in single-process runs)
def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    # all-reduce sums across ranks; divide by the world size to get the mean
    dist.all_reduce(t)
    return t / dist.get_world_size()

# 主类

class GradNormLossWeighter(Module):
    @beartype
    def __init__(
        self,
        *,
        num_losses: Optional[int] = None,
        loss_weights: Optional[Union[
            List[float],
            Tensor
        ]] = None,
        loss_names: Optional[Tuple[str, ...]] = None,
        learning_rate = 1e-4,
        restoring_force_alpha = 0.,
        grad_norm_parameters: Optional[Parameter] = None,
        accelerator: Optional[Accelerator] = None,
        frozen = False,
        initial_losses_decay = 1.,
        update_after_step = 0.,
        update_every = 1.
    ):
        super().__init__()
        # must specify either the number of losses or explicit initial weights
        assert exists(num_losses) or exists(loss_weights)

        if exists(loss_weights):
            if isinstance(loss_weights, list):
                loss_weights = torch.tensor(loss_weights)

            num_losses = default(num_losses, loss_weights.numel())
        else:
            # default to uniform weights of 1
            loss_weights = torch.ones((num_losses,), dtype = torch.float32)

        assert len(loss_weights) == num_losses
        assert num_losses > 1, 'only makes sense if you have multiple losses'
        assert loss_weights.ndim == 1, 'loss weights must be 1 dimensional'

        self.accelerator = accelerator
        self.num_losses = num_losses
        self.frozen = frozen

        self.loss_names = loss_names
        assert not exists(loss_names) or len(loss_names) == num_losses

        assert restoring_force_alpha >= 0.

        self.alpha = restoring_force_alpha
        self.has_restoring_force = self.alpha > 0

        # stored in a plain list so the parameter is not registered on this module
        self._grad_norm_parameters = [grad_norm_parameters] # hack

        # loss weights, either learned or static

        self.register_buffer('loss_weights', loss_weights)

        self.learning_rate = learning_rate

        # initial losses
        # if initial_losses_decay is set below 1, the initial losses are EMA-smoothed

        assert 0 <= initial_losses_decay <= 1.
        self.initial_losses_decay = initial_losses_decay

        self.register_buffer('initial_losses', torch.zeros(num_losses))

        # for renormalizing the loss weights at the end

        self.register_buffer('loss_weights_sum', self.loss_weights.sum())

        # for gradient accumulation

        self.register_buffer('loss_weights_grad', torch.zeros_like(loss_weights), persistent = False)

        # step counter, for potential scheduling etc

        self.register_buffer('step', torch.tensor(0.))

        # updates can be run less frequently, to save compute

        self.update_after_step = update_after_step
        self.update_every = update_every

        self.register_buffer('initted', torch.tensor(False))

    @property
    def grad_norm_parameters(self):
        # unwrap the single-element list used to keep the parameter unregistered on the module
        (parameter,) = self._grad_norm_parameters
        return parameter

    # alias so the weighter can be invoked like a loss: weighter.backward(...)
    def backward(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @beartype
    # 定义一个 forward 方法,用于前向传播
    def forward(
        self,
        losses: Union[
            Dict[str, Tensor],    # 损失值可以是字典类型,键为字符串,值为张量
            List[Tensor],         # 损失值可以是张量列表
            Tuple[Tensor],        # 损失值可以是元组中的张量
            Tensor                # 损失值可以是单个张量
        ],
        activations: Optional[Tensor] = None,     # 激活值,默认为 None,在论文中,他们使用了从骨干层次的倒数第二个参数的梯度范数。但这也可以是激活值(例如,共享的图像被馈送到多个鉴别器)
        freeze = False,                           # 可以选择在前向传播时冻结可学习的损失权重
        scale = 1.,                               # 缩放因子,默认为 1
        grad_step = True,                         # 是否进行梯度步骤,默认为 True
        **backward_kwargs                          # 其他后向传播参数

.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\mocks.py

# 导入 torch 中的 nn 模块
from torch import nn

# Toy network that emits one scalar loss per discriminator head, for exercising loss weighters.
class MockNetworkWithMultipleLosses(nn.Module):
    # dim: feature dimension; num_losses: number of discriminator heads / losses
    def __init__(
        self,
        dim,
        num_losses = 2
    ):
        super().__init__()

        # shared trunk: two linear layers with a SiLU nonlinearity in between
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

        # one scalar-output linear head per loss
        self.discriminators = nn.ModuleList(
            nn.Linear(dim, 1) for _ in range(num_losses)
        )

    def forward(self, x):
        # shared features for all heads
        hidden = self.backbone(x)

        # one mean scalar loss from every head
        losses = [head(hidden).mean() for head in self.discriminators]

        # return the per-head losses along with the backbone activations
        return losses, hidden

.\lucidrains\gradnorm-pytorch\gradnorm_pytorch\__init__.py

# 从 gradnorm_pytorch.gradnorm_pytorch 模块中导入 GradNormLossWeighter 类
# 从 gradnorm_pytorch.mocks 模块中导入 MockNetworkWithMultipleLosses 类
from gradnorm_pytorch.gradnorm_pytorch import GradNormLossWeighter
from gradnorm_pytorch.mocks import MockNetworkWithMultipleLosses

GradNorm - Pytorch

A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in Pytorch

Increasingly starting to come across neural network architectures that require more than 3 auxiliary losses, so will build out an installable package that easily handles loss balancing in distributed setting, gradient accumulation, etc. Also open to incorporating any follow up research; just let me know in the issues.

Will be dog-fooded for SoundStream, MagViT2 as well as MetNet3

Appreciation

Install

$ pip install gradnorm-pytorch

Usage

import torch

from gradnorm_pytorch import (
    GradNormLossWeighter,
    MockNetworkWithMultipleLosses
)

# a mock network with multiple discriminator losses

network = MockNetworkWithMultipleLosses(
    dim = 512,
    num_losses = 4
)

# backbone shared parameter

backbone_parameter = network.backbone[-1].weight

# grad norm based loss weighter

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    restoring_force_alpha = 0.,                  # 0. is perfectly balanced losses, while anything greater than 1 would account for the relative training rates of each loss. in the paper, they go as high as 3.
    grad_norm_parameters = backbone_parameter
)

# mock input

mock_input = torch.randn(2, 512)
losses, backbone_output_activations = network(mock_input)

# backwards with the loss weights
# will update on each backward based on gradnorm algorithm

loss_weighter.backward(losses, retain_graph = True)

# if you would like to update the loss weights wrt activations just do the following instead

loss_weighter.backward(losses, backbone_output_activations)

You can also switch it to basic static loss weighting, in case you want to run experiments against fixed weighting.

loss_weighter = GradNormLossWeighter(
    loss_weights = [1., 10., 5., 2.],
    ...,
    frozen = True
)

# or you can also freeze it on invoking the instance

loss_weighter.backward(..., freeze = True)

For use with 🤗 Huggingface Accelerate, just pass in the Accelerator instance into the keyword accelerator on initialization

ex.

accelerator = Accelerator()

network = accelerator.prepare(network)

loss_weighter = GradNormLossWeighter(
    ...,
    accelerator = accelerator
)

# backwards will now use accelerator

Todo

  • take care of gradient accumulation
  • handle sets of loss weights
  • handle freezing of some loss weights, but not others
  • allow for a prior weighting, accounted for when calculating gradient targets

Citations

@article{Chen2017GradNormGN,
    title   = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},
    author  = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},
    journal = {ArXiv},
    year    = {2017},
    volume  = {abs/1711.02257},
    url     = {https://api.semanticscholar.org/CorpusID:4703661}
}