Lucidrains-系列项目源码解析-六十-

116 阅读17分钟

Lucidrains 系列项目源码解析(六十)

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\trainer.py

# 导入必要的库
from pathlib import Path
from functools import partial
from packaging import version
from contextlib import nullcontext, contextmanager

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler

# 导入自定义的工具函数和类
from pytorch_custom_utils import (
    get_adam_optimizer,
    OptimizerWithWarmupSchedule,
    add_wandb_tracker_contextmanager
)

# 导入加速库
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# 导入类型检查相关库
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Tuple, Type, List

# 导入指数移动平均库
from ema_pytorch import EMA

# 导入数据处理相关函数
from meshgpt_pytorch.data import custom_collate

# 导入版本号
from meshgpt_pytorch.version import __version__

# 导入 MeshGPT 相关模型
from meshgpt_pytorch.meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# constants

# default DistributedDataParallel settings handed to the Accelerator.
# find_unused_parameters = True — presumably because some submodules do not
# receive gradients every step (e.g. optional conditioning paths); TODO confirm
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)

# 辅助函数定义

# helper predicates

def exists(v):
    """Return True when ``v`` is not None."""
    return not (v is None)

def default(v, d):
    """Return ``v`` unless it is None, in which case return the fallback ``d``."""
    if v is None:
        return d
    return v

def divisible_by(num, den):
    """Return True when ``num`` is an exact multiple of ``den``."""
    return num % den == 0

def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration each time it is exhausted."""
    while True:
        yield from dl

def maybe_del(d: dict, *keys):
    """Remove each of ``keys`` from dict ``d`` in place, ignoring absent keys."""
    for key in keys:
        d.pop(key, None)

# 自动编码器训练器类定义

# 添加 WandB 追踪上下文管理器
@add_wandb_tracker_contextmanager()
class MeshAutoencoderTrainer(Module):
    """Distributed trainer for ``MeshAutoencoder``, built on 🤗 Accelerate.

    Responsibilities:
    - wraps model + dataloader with ``Accelerator`` (DDP, mixed precision)
    - maintains an EMA copy of the autoencoder on the main process
    - gradient accumulation, optional periodic validation, checkpointing
    - optional wandb experiment tracking (via the class decorator)
    """

    @beartype
    def __init__(
        self,
        model: MeshAutoencoder,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int,
        grad_accum_every: int,
        val_dataset: Optional[Dataset] = None,
        val_every: int = 100,
        val_num_batches: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.,
        max_grad_norm: Optional[float] = None,
        ema_kwargs: dict = dict(),
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_folder = './checkpoints',
        data_kwargs: Tuple[str, ...] = ('vertices', 'faces', 'face_edges'),
        warmup_steps = 1000,
        use_wandb_tracking = False
    ):
        # NOTE(review): the source contained two conflicting `__init__`
        # definitions (the first was never closed with `):`); they are merged
        # here into this single typed constructor. `data_kwargs` default is
        # now an immutable tuple, matching its `Tuple[str, ...]` annotation.
        super().__init__()

        # experiment tracker

        self.use_wandb_tracking = use_wandb_tracking

        if use_wandb_tracking:
            accelerator_kwargs['log_with'] = 'wandb'

        if 'kwargs_handlers' not in accelerator_kwargs:
            accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]

        self.accelerator = Accelerator(**accelerator_kwargs)

        self.model = model

        # EMA weights are maintained on the main process only
        if self.is_main:
            self.ema_model = EMA(model, **ema_kwargs)

        # optimizer with warmup and optional lr scheduler; gradient clipping
        # is handled inside OptimizerWithWarmupSchedule via `max_grad_norm`
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = get_adam_optimizer(model.parameters(), lr = learning_rate, wd = weight_decay, **optimizer_kwargs),
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        self.dataloader = DataLoader(
            dataset,
            batch_size = batch_size,
            shuffle = True,
            drop_last = True,
            collate_fn = partial(custom_collate, pad_id = model.pad_id)
        )

        self.should_validate = exists(val_dataset)

        if self.should_validate:
            assert len(val_dataset) > 0, 'your validation dataset is empty'

            self.val_every = val_every
            self.val_num_batches = val_num_batches

            self.val_dataloader = DataLoader(
                val_dataset,
                batch_size = batch_size,
                shuffle = True,
                drop_last = True,
                collate_fn = partial(custom_collate, pad_id = model.pad_id)
            )

        # a dataset may declare the keyword names for its tuple entries;
        # that declaration takes precedence over the `data_kwargs` argument
        if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
            assert is_bearable(dataset.data_kwargs, List[str])
            self.data_kwargs = dataset.data_kwargs
        else:
            self.data_kwargs = data_kwargs

        (
            self.model,
            self.dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.dataloader
        )

        self.grad_accum_every = grad_accum_every
        self.num_train_steps = num_train_steps
        # step counter lives in a buffer so it is saved/loaded with state dicts
        self.register_buffer('step', torch.tensor(0))

        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

    @property
    def ema_tokenizer(self):
        # the underlying EMA-smoothed autoencoder (exists on main process only)
        return self.ema_model.ema_model

    def tokenize(self, *args, **kwargs):
        """Tokenize a mesh using the EMA-smoothed autoencoder."""
        return self.ema_tokenizer.tokenize(*args, **kwargs)

    def log(self, **data_kwargs):
        """Log metrics to the accelerator's tracker at the current step."""
        self.accelerator.log(data_kwargs, step = self.step.item())

    @property
    def device(self):
        return self.unwrapped_model.device

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def unwrapped_model(self):
        # model without the DDP wrapper added by accelerator.prepare
        return self.accelerator.unwrap_model(self.model)

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def wait(self):
        """Synchronization barrier across all processes."""
        return self.accelerator.wait_for_everyone()

    def print(self, msg):
        # accelerator.print only prints on the main process
        return self.accelerator.print(msg)

    def save(self, path, overwrite = True):
        """Save model / EMA / optimizer state plus version, step and model config.

        NOTE(review): reads `self.ema_model`, which is created on the main
        process only — call `save` from the main process.
        """
        path = Path(path)
        assert overwrite or not path.exists()

        pkg = dict(
            model = self.unwrapped_model.state_dict(),
            ema_model = self.ema_model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            version = __version__,
            step = self.step.item(),
            config = self.unwrapped_model._config
        )

        torch.save(pkg, str(path))

    def load(self, path):
        """Load a checkpoint previously produced by ``save``."""
        path = Path(path)
        assert path.exists()

        pkg = torch.load(str(path))

        # warn (but do not fail) on package version mismatch
        if version.parse(__version__) != version.parse(pkg['version']):
            self.print(f'loading saved mesh autoencoder at version {pkg["version"]}, but current package version is {__version__}')

        self.model.load_state_dict(pkg['model'])
        self.ema_model.load_state_dict(pkg['ema_model'])
        self.optimizer.load_state_dict(pkg['optimizer'])

        self.step.copy_(pkg['step'])

    def next_data_to_forward_kwargs(self, dl_iter) -> dict:
        """Pull the next batch and normalize it into forward() keyword arguments."""
        data = next(dl_iter)

        if isinstance(data, tuple):
            # positional batch entries are named via self.data_kwargs
            forward_kwargs = dict(zip(self.data_kwargs, data))
        elif isinstance(data, dict):
            forward_kwargs = data
        else:
            raise ValueError(f'batch must be a tuple or dict, got {type(data).__name__}')

        # the autoencoder does not condition on text
        maybe_del(forward_kwargs, 'texts', 'text_embeds')
        return forward_kwargs

    def forward(self):
        """Run the training loop until ``num_train_steps`` is reached."""
        step = self.step.item()
        dl_iter = cycle(self.dataloader)

        if self.is_main and self.should_validate:
            val_dl_iter = cycle(self.val_dataloader)

        while step < self.num_train_steps:

            # gradient accumulation; skip DDP gradient sync on all but the
            # last microbatch
            for i in range(self.grad_accum_every):
                is_last = i == (self.grad_accum_every - 1)
                maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext

                forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)

                with self.accelerator.autocast(), maybe_no_sync():

                    total_loss, (recon_loss, commit_loss) = self.model(
                        **forward_kwargs,
                        return_loss_breakdown = True
                    )

                    self.accelerator.backward(total_loss / self.grad_accum_every)

            self.print(f'recon loss: {recon_loss.item():.3f} | commit loss: {commit_loss.sum().item():.3f}')

            self.log(
                total_loss = total_loss.item(),
                commit_loss = commit_loss.sum().item(),
                recon_loss = recon_loss.item()
            )

            self.optimizer.step()
            self.optimizer.zero_grad()

            step += 1
            self.step.add_(1)

            self.wait()

            # EMA update happens on the main process only
            if self.is_main:
                self.ema_model.update()

            self.wait()

            # periodic validation with the EMA model (main process only)
            if self.is_main and self.should_validate and divisible_by(step, self.val_every):

                total_val_recon_loss = 0.
                self.ema_model.eval()

                num_val_batches = self.val_num_batches * self.grad_accum_every

                for _ in range(num_val_batches):
                    with self.accelerator.autocast(), torch.no_grad():

                        forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)

                        val_loss, (val_recon_loss, val_commit_loss) = self.ema_model(
                            **forward_kwargs,
                            return_loss_breakdown = True
                        )

                        total_val_recon_loss += (val_recon_loss / num_val_batches)

                self.print(f'valid recon loss: {total_val_recon_loss:.3f}')

                self.log(val_loss = total_val_recon_loss)

            self.wait()

            # periodic checkpointing (main process only)
            if self.is_main and divisible_by(step, self.checkpoint_every):
                checkpoint_num = step // self.checkpoint_every
                self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.{checkpoint_num}.pt')

            self.wait()

        self.print('training complete')
# mesh transformer trainer

# 添加 WandB跟踪上下文管理器
@add_wandb_tracker_contextmanager()
class MeshTransformerTrainer(Module):
    """Distributed trainer for ``MeshTransformer``, built on 🤗 Accelerate.

    Wraps the model and dataloader with ``Accelerator`` (DDP, mixed precision),
    runs gradient accumulation, optional periodic validation, checkpointing,
    and optional wandb experiment tracking (via the class decorator).
    """

    @beartype
    def __init__(
        self,
        model: MeshTransformer,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int,
        grad_accum_every: int,
        learning_rate: float = 2e-4,
        weight_decay: float = 0.,
        max_grad_norm: Optional[float] = 0.5,
        val_dataset: Optional[Dataset] = None,
        val_every = 1,
        val_num_batches = 5,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_folder = './checkpoints',
        # NOTE(review): mutable list default annotated as Tuple[str, ...] —
        # harmless because it is never mutated here, but a tuple would be safer
        data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges', 'texts'],
        warmup_steps = 1000,
        use_wandb_tracking = False
    ):
        super().__init__()

        # experiment tracker

        self.use_wandb_tracking = use_wandb_tracking

        # when tracking with wandb, route accelerator logging through it
        if use_wandb_tracking:
            accelerator_kwargs['log_with'] = 'wandb'

        # fall back to the default DDP kwargs handler if none was supplied
        if 'kwargs_handlers' not in accelerator_kwargs:
            accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]

        self.accelerator = Accelerator(**accelerator_kwargs)

        self.model = model

        # Adam optimizer over trainable parameters only
        optimizer = get_adam_optimizer(
            model.parameters(),
            lr = learning_rate,
            wd = weight_decay,
            filter_by_requires_grad = True,
            **optimizer_kwargs
        )

        # optimizer with warmup and optional lr scheduler; gradient clipping
        # is handled inside OptimizerWithWarmupSchedule via `max_grad_norm`
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = optimizer,
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        # training dataloader; batches are padded with the model's pad_id
        self.dataloader = DataLoader(
            dataset,
            batch_size = batch_size,
            shuffle = True,
            drop_last = True,
            collate_fn = partial(custom_collate, pad_id = model.pad_id)
        )

        # validation is enabled only when a validation dataset is provided
        self.should_validate = exists(val_dataset)

        if self.should_validate:
            assert len(val_dataset) > 0, 'your validation dataset is empty'

            self.val_every = val_every
            self.val_num_batches = val_num_batches

            self.val_dataloader = DataLoader(
                val_dataset,
                batch_size = batch_size,
                shuffle = True,
                drop_last = True,
                collate_fn = partial(custom_collate, pad_id = model.pad_id)
            )

        # a dataset may declare keyword names for its tuple entries; that
        # declaration takes precedence over the `data_kwargs` argument
        if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
            assert is_bearable(dataset.data_kwargs, List[str])
            self.data_kwargs = dataset.data_kwargs
        else:
            self.data_kwargs = data_kwargs

        # let the accelerator wrap model and dataloader for distributed training
        (
            self.model,
            self.dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.dataloader
        )

        self.grad_accum_every = grad_accum_every
        self.num_train_steps = num_train_steps
        # step counter lives in a buffer so it travels with state dicts
        self.register_buffer('step', torch.tensor(0))

        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

    def log(self, **data_kwargs):
        """Log metrics to the accelerator's tracker at the current step."""
        self.accelerator.log(data_kwargs, step = self.step.item())

    @property
    def device(self):
        return self.unwrapped_model.device

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def unwrapped_model(self):
        # model without the DDP wrapper added by accelerator.prepare
        return self.accelerator.unwrap_model(self.model)

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def wait(self):
        """Synchronization barrier across all processes."""
        return self.accelerator.wait_for_everyone()

    def print(self, msg):
        # accelerator.print only prints on the main process
        return self.accelerator.print(msg)

    def next_data_to_forward_kwargs(self, dl_iter) -> dict:
        """Pull the next batch and normalize it into forward() keyword arguments."""
        data = next(dl_iter)

        # positional batch entries are named via self.data_kwargs
        if isinstance(data, tuple):
            forward_kwargs = dict(zip(self.data_kwargs, data))

        # dict batches are passed through unchanged
        elif isinstance(data, dict):
            forward_kwargs = data

        return forward_kwargs

    def save(self, path, overwrite = True):
        """Save model and optimizer state plus step counter and package version."""
        path = Path(path)
        assert overwrite or not path.exists()

        pkg = dict(
            model = self.unwrapped_model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step.item(),
            version = __version__
        )

        torch.save(pkg, str(path))

    def load(self, path):
        """Load a checkpoint previously produced by ``save``."""
        path = Path(path)
        assert path.exists()

        pkg = torch.load(str(path))

        # warn (but do not fail) on package version mismatch
        if version.parse(__version__) != version.parse(pkg['version']):
            self.print(f'loading saved mesh transformer at version {pkg["version"]}, but current package version is {__version__}')

        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step.copy_(pkg['step'])

    def forward(self):
        """Run the training loop until ``num_train_steps`` is reached."""
        step = self.step.item()
        dl_iter = cycle(self.dataloader)

        if self.should_validate:
            val_dl_iter = cycle(self.val_dataloader)

        while step < self.num_train_steps:

            # gradient accumulation; skip DDP gradient sync on all but the
            # last microbatch
            for i in range(self.grad_accum_every):
                is_last = i == (self.grad_accum_every - 1)
                maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext

                forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)

                # forward under autocast for mixed precision
                with self.accelerator.autocast(), maybe_no_sync():
                    loss = self.model(**forward_kwargs)

                    # scale loss so accumulated gradients average correctly
                    self.accelerator.backward(loss / self.grad_accum_every)

            self.print(f'loss: {loss.item():.3f}')

            self.log(loss = loss.item())

            self.optimizer.step()
            self.optimizer.zero_grad()

            step += 1
            self.step.add_(1)

            self.wait()

            # periodic validation (main process only)
            if self.is_main and self.should_validate and divisible_by(step, self.val_every):

                total_val_loss = 0.
                # NOTE(review): the model is switched to eval mode here but
                # never switched back to train mode afterwards — confirm intent
                self.unwrapped_model.eval()

                num_val_batches = self.val_num_batches * self.grad_accum_every

                for _ in range(num_val_batches):
                    with self.accelerator.autocast(), torch.no_grad():

                        forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)

                        val_loss = self.unwrapped_model(**forward_kwargs)

                        # accumulate the mean over all validation batches
                        total_val_loss += (val_loss / num_val_batches)

                # NOTE(review): message says 'recon loss' but this is the
                # transformer's loss — label kept for output compatibility
                self.print(f'valid recon loss: {total_val_loss:.3f}')

                self.log(val_loss = total_val_loss)

            self.wait()

            # periodic checkpointing (main process only)
            if self.is_main and divisible_by(step, self.checkpoint_every):
                checkpoint_num = step // self.checkpoint_every
                self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.{checkpoint_num}.pt')

            self.wait()

        self.print('training complete')

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\version.py

# current package version; consumed by setup.py and embedded in saved checkpoints
__version__ = '1.1.1'

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\__init__.py

# 从 meshgpt_pytorch 包中导入 MeshAutoencoder 和 MeshTransformer 类
from meshgpt_pytorch.meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# 从 meshgpt_pytorch 包中导入 MeshAutoencoderTrainer 和 MeshTransformerTrainer 类
from meshgpt_pytorch.trainer import (
    MeshAutoencoderTrainer,
    MeshTransformerTrainer
)

# 从 meshgpt_pytorch 包中导入 DatasetFromTransforms、cache_text_embeds_for_dataset 和 cache_face_edges_for_dataset 函数
from meshgpt_pytorch.data import (
    DatasetFromTransforms,
    cache_text_embeds_for_dataset,
    cache_face_edges_for_dataset
)

MeshGPT - Pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch

Will also add text conditioning, for eventual text-to-3d asset

Please join us on Discord if you are interested in collaborating with others to replicate this work

Appreciation

  • StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

  • Einops for making my life easy

  • Marcus for the initial code review (pointing out some missing derived features) as well as running the first successful end-to-end experiments

  • Marcus for the first successful training of a collection of shapes conditioned on labels

  • Quexi Ma for finding numerous bugs with automatic eos handling

  • Yingtian for finding a bug with the gaussian blurring of the positions for spatial label smoothing

  • Marcus yet again for running the experiments to validate that it is possible to extend the system from triangles to quads

Install

$ pip install meshgpt-pytorch

Usage

import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# autoencoder

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128
)

# mock inputs

vertices = torch.randn((2, 121, 3))            # (batch, num vertices, coor (3))
faces = torch.randint(0, 121, (2, 64, 3))      # (batch, num faces, vertices (3))

# make sure faces are padded with `-1` for variable lengthed meshes

# forward in the faces

loss = autoencoder(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training...
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768
)

loss = transformer(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates, face_mask = transformer.generate()

# (batch, num faces, vertices (3), coordinates (3)), (batch, num faces)
# now post process for the generated 3d asset

For text-conditioned 3d shape synthesis, simply set condition_on_text = True on your MeshTransformer, and then pass in your list of descriptions as the texts keyword argument

ex.

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True
)


loss = transformer(
    vertices = vertices,
    faces = faces,
    texts = ['a high chair', 'a small teapot'],
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets conditioned on text

faces_coordinates, face_mask = transformer.generate(texts = ['a long table'])

If you want to tokenize meshes, for use in your multimodal transformer, simply invoke .tokenize on your autoencoder (or same method on autoencoder trainer instance for the exponentially smoothed model)


mesh_token_ids = autoencoder.tokenize(
    vertices = vertices,
    faces = faces
)

# (batch, num face vertices, residual quantized layer)

Todo

  • autoencoder

    • encoder sageconv with torch geometric
    • proper scatter mean accounting for padding for meaning the vertices and RVQ the vertices before gathering back for decoder
    • complete decoder and reconstruction loss + commitment loss
    • handle variable lengthed faces
    • add option to use residual LFQ, latest quantization development that scales code utilization
    • xcit linear attention in encoder and decoder
    • figure out how to auto-derive face_edges directly from faces and vertices
    • embed any derived values (area, angles, etc) from the vertices before sage convs
    • add an extra graph conv stage in the encoder, where vertices are enriched with their connected vertex neighbors, before aggregating into faces. make optional
    • allow for encoder to noise the vertices, so autoencoder is a bit denoising. consider conditioning decoder on noise level, if varying
  • transformer

    • properly mask out eos logit during generation
    • make sure it trains
      • take care of sos token automatically
      • take care of eos token automatically if sequence length or mask is passed in
    • handle variable lengthed faces
      • on forwards
      • on generation, do all eos logic + substitute everything after eos with pad id
    • generation + cache kv
  • trainer wrapper with hf accelerate

    • autoencoder - take care of ema
    • transformer
  • text conditioning using own CFG library

    • complete preliminary text conditioning
    • make sure CFG library can support passing in arguments to the two separate calls when cond scaling (as well as aggregating their outputs)
    • polish up the magic dataset decorator and see if it can be moved to CFG library
  • hierarchical transformers (using the RQ transformer)

  • fix caching in simple gateloop layer in other repo

  • local attention

  • fix kv caching for two-staged hierarchical transformer - 7x faster now, and faster than original non-hierarchical transformer

  • fix caching for gateloop layers

  • allow for customization of model dimensions of fine vs coarse attention network

  • figure out if autoencoder is really necessary - it is necessary, ablations are in the paper

    • when mesh discretizer is passed in, one can inject inter-face attention with the relative distance
    • additional embeddings (angles, area, normal), can also be appended before coarse transformer attention
  • make transformer efficient

    • reversible networks
  • speculative decoding option

  • spend a day on documentation

Citations

@inproceedings{Siddiqui2023MeshGPTGT,
    title   = {MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers},
    author  = {Yawar Siddiqui and Antonio Alliegro and Alexey Artemov and Tatiana Tommasi and Daniele Sirigatti and Vladislav Rosov and Angela Dai and Matthias Nie{\ss}ner},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265457242}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation}, 
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Lee2022AutoregressiveIG,
    title   = {Autoregressive Image Generation using Residual Quantization},
    author  = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11513-11522},
    url     = {https://api.semanticscholar.org/CorpusID:247244535}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

.\lucidrains\meshgpt-pytorch\setup.py

# setuptools helpers for building and discovering packages
from setuptools import setup, find_packages

# execute version.py so that `__version__` is defined in this scope
exec(open('meshgpt_pytorch/version.py').read())

# package metadata
setup(
  name = 'meshgpt-pytorch',  # distribution name on PyPI
  packages = find_packages(exclude=[]),  # include every discovered package
  version = __version__,  # version sourced from meshgpt_pytorch/version.py
  license='MIT',
  description = 'MeshGPT Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/meshgpt-pytorch',
  keywords = [  # search keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'mesh generation'
  ],
  install_requires=[  # runtime dependencies
    'accelerate>=0.25.0',
    'beartype',
    'classifier-free-guidance-pytorch>=0.5.1',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch',
    'local-attention>=1.9.0',
    'gateloop-transformer>=0.2.2',
    'numpy',
    'pytorch-custom-utils>=0.0.9',
    'taylor-series-linear-attention>=0.1.6',
    'torch>=2.1',
    'torch_geometric',
    'torchtyping',
    'tqdm',
    'vector-quantize-pytorch>=1.12.8',
    'x-transformers>=1.26.0',
  ],
  classifiers=[  # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    # NOTE(review): classifier claims 3.6, but torch>=2.1 requires Python >= 3.8 — confirm/update
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: prize.hutter1.net/

.\lucidrains\metaformer-gpt\metaformer_gpt\autoregressive_wrapper.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 库中导入 nn 模块
from torch import nn

# helper to distinguish "not provided" from provided values
def exists(val):
    """Return True when ``val`` is not None."""
    return not (val is None)

# decorator: run the wrapped function with the model in eval mode,
# then restore whatever training mode the model was in before
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        prev_mode = model.training
        model.eval()
        result = fn(model, *args, **kwargs)
        model.train(prev_mode)
        return result
    return inner

# top-k filtering over the last (vocab) dimension
def top_k(logits, thres = 0.9):
    """Keep only the top-k logits along the last dim; set the rest to -inf.

    ``k`` is derived as ``(1 - thres) * vocab_size``, clamped to at least 1 so
    at least one token always survives (the original could produce k == 0,
    yielding an all -inf row and NaNs after softmax). Scatter now uses dim -1
    (the original hard-coded dim 1, which only worked for 2-D logits).
    """
    k = max(1, int((1 - thres) * logits.shape[-1]))
    val, ind = torch.topk(logits, k, dim = -1)
    filtered = torch.full_like(logits, float("-inf"))
    filtered.scatter_(-1, ind, val)
    return filtered

# autoregressive wrapper: adds sampling and a next-token cross-entropy loss
# around a base sequence network
class AutoregressiveWrapper(nn.Module):
    """Wrap ``net`` (token ids in, per-position logits out) with:

    - ``generate``: temperature sampling with top-k filtering and optional
      eos handling
    - ``forward``: next-token prediction cross-entropy loss
    """

    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    # sampling is done without gradients and with the model in eval mode
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        """Sample up to ``seq_len`` new tokens after ``start_tokens`` (b, t).

        Returns only the newly generated tokens. When ``eos_token`` is given,
        stops once every batch row has emitted it and pads everything after
        eos with ``pad_value``.
        """
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # logits for the next position only
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token

                # stop once every sequence in the batch contains an eos token
                if is_eos_token.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    # (fix: original referenced the undefined name `is_eos_tokens`)
                    shifted_is_eos_token = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_token.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # strip the prompt — return only newly generated tokens
        out = out[:, t:]
        return out

    def forward(self, x, **kwargs):
        """Next-token cross entropy: predict ``x[:, 1:]`` from ``x[:, :-1]``."""
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        # transpose the trailing two dims so classes sit at dim 1 as
        # F.cross_entropy expects (the pattern's axis names are arbitrary labels)
        return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)

.\lucidrains\metaformer-gpt\metaformer_gpt\metaformer_gpt.py

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from scipy.fftpack import next_fast_len

# 辅助函数

def cummean(x, *, dim):
    """Cumulative mean of `x` along `dim`.

    Fix vs original: the divisor was built from `x.shape[1]` and broadcast
    against the trailing dim, which is only correct when `dim` is the
    second-to-last axis of a 3-D tensor. The count is now taken from
    `x.shape[dim]` and reshaped to broadcast along `dim` itself.
    """
    numer = x.cumsum(dim = dim)
    n = x.shape[dim]
    # counts 1..n, viewed so they line up with `dim` under broadcasting
    shape = [1] * x.ndim
    shape[dim] = n
    denom = (torch.arange(n, device = x.device) + 1).view(shape)
    return numer / denom

def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # O(N log N) causal 1-d convolution of `x` (along `dim`) with `weights`
    # (taken along `weight_dim`) via the FFT convolution theorem; output has
    # the same shape as `x`.

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    # pad the transform length to a 5-smooth size so the FFT is fast
    fast_len = next_fast_len(N + M - 1)

    # real FFT of signal and kernel at the padded length
    f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
    f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)

    # multiply in the frequency domain; conj turns convolution into
    # correlation, and the appended unit axis broadcasts the per-head kernel
    # over the feature dimension of `x`
    f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
    out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    # keep the last N positions — the causal slice of the full convolution
    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out

# 类

class MeanCenteringPool(nn.Module):
    """Parameter-light token mixing: subtract each token from the causal
    running mean of the sequence, then project linearly."""

    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)
        # causal mean-centering along the sequence dimension
        centered = cummean(normed, dim = 1) - normed
        return self.proj(centered)

class MultiheadExponentialTimeDecay(nn.Module):
    """Multi-head exponential-smoothing token mixer (ETSformer-style): each
    head mixes tokens with an exponentially decaying causal kernel, evaluated
    in O(n log n) via conv1d_fft."""

    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        # one learnable smoothing coefficient per head (squashed by sigmoid)
        self.alpha = nn.Parameter(torch.randn(heads))

        self.project_in = nn.Linear(dim, inner_dim, bias = False)
        self.project_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        b, n, d, h, device = *x.shape, self.heads, x.device

        x = self.norm(x)

        # linear projection in
        x = self.project_in(x)

        # split heads
        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # smoothing coefficient in (0, 1), one per head
        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # causal kernel alpha * (1 - alpha)^k, flipped so that the most
        # recent position gets the largest weight
        arange = torch.arange(n, device = device)
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))
        output = conv1d_fft(x, weights)

        # merge heads and project back out
        output = rearrange(output, 'b h n d -> b n (h d)')
        return self.project_out(output)

def FeedForward(dim, mult = 4):
    """Pre-norm MLP: LayerNorm -> expand by `mult` -> GELU -> project back."""
    inner_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias = False),
    ]
    return nn.Sequential(*layers)

class MetaformerGPT(nn.Module):
    """GPT-style decoder where token mixing is done by exponential time decay
    and mean-centering pooling instead of attention."""

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 16,
        dim_head = 32,
        max_seq_len = 2048,
        ff_mult = 4
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # each layer: time-decay mixer, mean-centering pool, feedforward
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MultiheadExponentialTimeDecay(dim, heads = heads, dim_head = dim_head),
                MeanCenteringPool(dim),
                FeedForward(dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(self, x):
        seq_len, device = x.shape[1], x.device

        positions = torch.arange(seq_len, device = device)
        hidden = self.token_emb(x) + self.pos_emb(positions)

        # residual connection around every sub-layer
        for time_decay, pool, ff in self.layers:
            hidden = time_decay(hidden) + hidden
            hidden = pool(hidden) + hidden
            hidden = ff(hidden) + hidden

        return self.to_logits(hidden)

.\lucidrains\metaformer-gpt\metaformer_gpt\__init__.py

# 从 metaformer_gpt 包中导入 MetaformerGPT 和 MultiheadExponentialTimeDecay 类
from metaformer_gpt.metaformer_gpt import MetaformerGPT, MultiheadExponentialTimeDecay

Metaformer - GPT (wip)

Implementation of Metaformer, but in an autoregressive manner. In particular, they propose simply using mean centering as a way to do token mixing in a parameter-less fashion, alternating with feedforwards.

Install

$ pip install metaformer-gpt

Usage

import torch
from metaformer_gpt import MetaformerGPT

gpt = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

ids = torch.randint(0, 256, (1, 1024))
logits = gpt(ids) # (1, 1024, 256)

Citations

@article{Yu2021MetaFormerIA,
    title   = {MetaFormer is Actually What You Need for Vision},
    author  = {Weihao Yu and Mi Luo and Pan Zhou and Chenyang Si and Yichen Zhou and Xinchao Wang and Jiashi Feng and Shuicheng Yan},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2111.11418}
}
@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\metaformer-gpt\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# package metadata for the `metaformer-gpt` PyPI distribution
setup(
  # distribution name
  name = 'metaformer-gpt',
  # include every package found in the source tree
  packages = find_packages(exclude=[]),
  # release version
  version = '0.0.5',
  # license identifier
  license='MIT',
  # short description shown on PyPI
  description = 'Metaformer - GPT',
  # author
  author = 'Phil Wang',
  # contact email
  author_email = 'lucidrains@gmail.com',
  # README is rendered as markdown
  long_description_content_type = 'text/markdown',
  # project homepage
  url = 'https://github.com/lucidrains/metaformer-gpt',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-less'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\metaformer-gpt\train.py

# 导入所需的库
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from metaformer_gpt import MetaformerGPT
from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

# training hyperparameters
NUM_BATCHES = int(1e5)          # total optimizer steps
BATCH_SIZE = 4                  # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 4   # micro-batches per optimizer step
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100            # run one validation batch every N steps
GENERATE_EVERY = 500            # sample from the model every N steps
GENERATE_LENGTH = 512           # number of tokens to generate per sample
SEQ_LEN = 1024                  # training context length

# helper: endlessly re-iterate a (data) loader
def cycle(loader):
    while True:
        yield from loader

def decode_token(token):
    """Map a byte value to a printable character; control bytes (< 32) become spaces."""
    # chr() already returns a str — the original's extra str() wrapper was redundant
    return chr(max(32, token))

def decode_tokens(tokens):
    """Decode an iterable of byte values into a human-readable string."""
    # join consumes the map lazily; the original's list() was unnecessary
    return "".join(map(decode_token, tokens))

# instantiate the GPT-like decoder model
model = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 16,
    dim_head = 32
)

model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
model.cuda()

# prepare enwik8 data: first 90MB train, next 5MB validation
with gzip.open("./data/enwik8.gz") as file:
    # np.frombuffer replaces np.fromstring, which is deprecated and removed
    # in NumPy 2.x for binary input; copy() because frombuffer returns a
    # read-only view, which torch.from_numpy cannot wrap safely
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset yielding random fixed-length slices of a flat token tensor
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # pick a uniformly random window of seq_len + 1 tokens
        # (the extra token provides the shifted prediction target)
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,))
        window = self.data[start : start + self.seq_len + 1].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# cycling dataloaders for the train / validation splits
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# plain Adam optimizer
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module
# imported at the top of the file
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop: gradient accumulation, periodic validation and sampling
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        # scale so the accumulated gradient is the mean over micro-batches,
        # matching the stated learning rate (original summed them)
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # fix: the original mixed an f-string with %-style arguments
        # (`print(f"%s \n\n %s", (prime, ...))`), printing the literal
        # format string followed by a tuple instead of the prime text
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\metnet3-pytorch\metnet3_pytorch\metnet3_pytorch.py

# 导入必要的库
from pathlib import Path
from functools import partial
from collections import namedtuple
from contextlib import contextmanager

import torch
from torch import nn, Tensor, einsum
import torch.distributed as dist
from torch.autograd import Function
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

# 导入 einops 库中的函数和层
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# 导入 beartype 库中的类型注解
from beartype import beartype
from beartype.typing import Tuple, Union, List, Optional, Dict, Literal

import pickle

# 定义一些辅助函数

def exists(val):
    """True when a value has been provided (is not None)."""
    return not (val is None)

def default(val, d):
    """Return `val` if it was provided, otherwise the fallback `d`."""
    if val is None:
        return d
    return val

def pack_one(x, pattern):
    """Pack a single tensor with einops.pack; returns (packed, shape spec)."""
    packed, ps = pack([x], pattern)
    return packed, ps

def unpack_one(x, ps, pattern):
    """Inverse of pack_one: unpack with the stored shape spec, return the lone tensor."""
    unpacked = unpack(x, ps, pattern)
    return unpacked[0]

def cast_tuple(val, length = 1):
    """Repeat `val` into a tuple of `length` copies, unless it already is a tuple."""
    if isinstance(val, tuple):
        return val
    return (val,) * length

def safe_div(num, den, eps = 1e-10):
    """Division with the denominator clamped away from zero to avoid inf/NaN."""
    clamped_den = den.clamp(min = eps)
    return num / clamped_den

def l2norm(t):
    """Normalize `t` to unit L2 norm along the last dimension."""
    normed = F.normalize(t, dim = -1)
    return normed

# 准备在分布式训练中使用的批量归一化

# choose SyncBatchNorm under multi-process distributed training, else BatchNorm2d
def MaybeSyncBatchnorm2d(is_distributed = None):
    """Return the batchnorm class appropriate for the current process group."""
    if is_distributed is None:
        # auto-detect: distributed iff a process group exists with world size > 1
        is_distributed = dist.is_initialized() and dist.get_world_size() > 1
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

# temporarily freeze a batchnorm layer (eval mode, no running-stat updates)
@contextmanager
def freeze_batchnorm(bn):
    """Context manager that puts a parameter-free batchnorm into eval mode and
    stops running-stat tracking, restoring both on exit.

    Fix vs original: restoration is now in a try/finally, so the module state
    is restored even when the with-body raises.
    """
    # only parameter-free batchnorms (affine = False) may be frozen this way
    assert next(bn.parameters(), None) is None, 'batchnorm must be parameter-free (affine = False)'

    was_training = bn.training
    was_tracking_stats = bn.track_running_stats
    bn.eval()
    bn.track_running_stats = False

    try:
        yield bn
    finally:
        # restore prior mode regardless of how the block exited
        bn.train(was_training)
        bn.track_running_stats = was_tracking_stats

# 损失缩放

# custom autograd op: identity in the forward pass; on backward it reweights
# the incoming gradient per channel (inverse spatial L2 norm, L1-normalized
# over channels) so that every channel contributes comparably
class LossScaleFunction(Function):
    @staticmethod
    def forward(ctx, x, eps):
        ctx.eps = eps
        # expects NCHW feature maps
        assert x.ndim == 4
        return x

    @staticmethod
    def backward(ctx, grads):
        num_channels = grads.shape[1]

        safe_div_ = partial(safe_div, eps = ctx.eps)

        # weight each channel by the inverse of its spatial gradient norm
        weight = safe_div_(1., grads.norm(p = 2, keepdim = True, dim = (-1, -2)))
        # L1-normalize the weights across the channel dimension
        l1_normed_weight = safe_div_(weight, weight.sum(keepdim = True, dim = 1))

        # multiply by num_channels so the overall gradient scale is preserved
        scaled_grads = num_channels * l1_normed_weight * grads

        # second slot is the (unused) gradient wrt `eps`
        return scaled_grads, None

# module wrapper around LossScaleFunction
class LossScaler(Module):
    """Identity in the forward pass; rebalances gradients per channel on backward."""

    def __init__(self, eps = 1e-5):
        super().__init__()
        # numerical floor used when normalizing gradient magnitudes
        self.eps = eps

    def forward(self, x):
        return LossScaleFunction.apply(x, self.eps)

# 中心裁剪

# center-pad module
class CenterPad(Module):
    """Zero-pad the last two (spatial) dims of `x` up to a `target_dim` square,
    splitting the padding evenly around each side.

    Fix vs original: F.pad orders pad pairs starting from the LAST dimension,
    i.e. (left, right, top, bottom) for NCHW input. The original passed the
    height pair first, so non-square inputs were padded along the wrong axes.
    """

    def __init__(self, target_dim):
        super().__init__()
        self.target_dim = target_dim

    def forward(self, x):
        target_dim = self.target_dim
        *_, height, width = x.shape
        assert target_dim >= height and target_dim >= width

        height_pad = target_dim - height
        width_pad = target_dim - width
        top_pad = height_pad // 2
        left_pad = width_pad // 2

        # width (last dim) pair first, then height
        return F.pad(
            x,
            (left_pad, width_pad - left_pad, top_pad, height_pad - top_pad),
            value = 0.
        )

# center-crop module
class CenterCrop(Module):
    """Crop the last two (spatial) dims of `x` to a centered `crop_dim` square."""

    def __init__(self, crop_dim):
        super().__init__()
        self.crop_dim = crop_dim

    def forward(self, x):
        crop = self.crop_dim
        *_, height, width = x.shape
        assert (height >= crop) and (width >= crop)

        # offsets that center the crop window
        top = (height - crop) // 2
        left = (width - crop) // 2
        return x[..., top : top + crop, left : left + crop]

# 下采样和上采样

# 下采样使用最大池化,上采样使用转置卷积
# todo: 弄清楚从 4km 到 1km 的 4 倍上采样

# 2x spatial downsample via max pooling (kernel 2, stride 2)
Downsample2x = partial(nn.MaxPool2d, kernel_size = 2, stride = 2)

# 2x spatial upsample via a stride-2 transposed convolution
def Upsample2x(dim, dim_out = None):
    """Return a ConvTranspose2d that doubles the spatial resolution.

    Output channels default to the input channel count when not given.
    """
    if dim_out is None:
        dim_out = dim
    return nn.ConvTranspose2d(dim, dim_out, kernel_size = 2, stride = 2)
# conv -> channel layernorm -> ReLU block, with optional FiLM-style scale/shift
class Block(Module):
    def __init__(self, dim, dim_out):
        super().__init__()
        # 3x3 convolution preserving spatial size
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = ChanLayerNorm(dim_out)
        self.act = nn.ReLU()

    def forward(self, x, scale_shift = None):
        out = self.norm(self.proj(x))

        # optional conditioning: affine-modulate the normalized features
        if scale_shift is not None:
            scale, shift = scale_shift
            out = out * (scale + 1) + shift

        return self.act(out)

# resnet block: two conv Blocks plus a (possibly projected) residual, with
# optional FiLM conditioning driven by a small MLP
class ResnetBlock(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        cond_dim = None
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        # when conditioned, an MLP produces per-channel scale and shift
        self.mlp = Sequential(
            nn.ReLU(),
            nn.Linear(cond_dim, dim_out * 2)
        ) if exists(cond_dim) else None

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)
        # 1x1 projection on the skip path only when channel counts differ
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, cond = None):
        # conditioning must be supplied iff the block was built with cond_dim
        assert not (exists(self.mlp) ^ exists(cond))

        scale_shift = None
        if exists(self.mlp):
            film = self.mlp(cond)
            film = rearrange(film, 'b c -> b c 1 1')
            scale_shift = film.chunk(2, dim = 1)

        out = self.block1(x, scale_shift = scale_shift)
        out = self.block2(out)

        # residual connection
        return out + self.res_conv(x)

# a stack of `depth` ResnetBlocks; the first maps dim_in -> dim, the rest dim -> dim
class ResnetBlocks(Module):
    def __init__(
        self,
        dim,
        *,
        dim_in = None,
        depth = 1,
        cond_dim = None
    ):
        super().__init__()
        first_dim = default(dim_in, dim)
        # input channel count for each block in the stack
        in_dims = [first_dim] + [dim] * (depth - 1)

        self.blocks = ModuleList([
            ResnetBlock(dim = d, dim_out = dim, cond_dim = cond_dim)
            for d in in_dims
        ])

    def forward(self, x, cond = None):
        for block in self.blocks:
            x = block(x, cond = cond)

        return x

# multi-head RMS norm, used to normalize queries/keys in attention
class RMSNorm(Module):
    def __init__(
        self,
        dim,
        *,
        heads
    ):
        super().__init__()
        # restore magnitude after unit-normalization
        self.scale = dim ** 0.5
        # learned per-head gain
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma

# layer norm over the channel dimension of NCHW tensors (used inside resnet blocks)
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # learned per-channel gain and bias, broadcast over batch and space
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        inv_std = var.clamp(min = self.eps).rsqrt()
        return (x - mean) * inv_std * self.g + self.b

# MBConv

# squeeze-and-excitation: gate each channel using globally pooled statistics
class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck_dim = int(dim * shrinkage_rate)

        # global average pool -> bottleneck MLP -> sigmoid gate per channel
        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, bottleneck_dim, bias = False),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        gates = self.gate(x)
        return x * gates

# residual wrapper for MBConv with stochastic depth on the branch
class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x

# stochastic depth: during training, zero out whole samples with prob `prob`
class Dropsample(Module):
    """Drop entire batch samples with probability `prob`, rescaling the
    survivors so the expected activation is unchanged.

    Fix vs original: `torch.FloatTensor((b, 1, 1, 1), device=device)` builds a
    tensor *from* the tuple values (and FloatTensor accepts no device kwarg,
    raising TypeError) — use torch.empty to allocate by shape instead.
    """

    def __init__(self, prob = 0):
        super().__init__()
        # probability of dropping each sample in the batch
        self.prob = prob

    def forward(self, x):
        device = x.device

        # identity at prob 0 or during evaluation
        if self.prob == 0. or (not self.training):
            return x

        # one Bernoulli keep/drop decision per sample, broadcast over C/H/W
        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        # inverted-dropout rescaling keeps the expectation constant
        return x * keep_mask / (1 - self.prob)
# MBConv (inverted bottleneck): expand 1x1 -> depthwise 3x3 -> SE -> project 1x1
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build an MBConv block; 2x spatial downsampling happens in the depthwise conv."""
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    # SyncBatchNorm under distributed training, plain BatchNorm2d otherwise
    batchnorm_klass = MaybeSyncBatchnorm2d()

    stages = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        batchnorm_klass(hidden_dim),
        nn.GELU(),
        # depthwise convolution (groups == channels)
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        batchnorm_klass(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        batchnorm_klass(dim_out),
    ]
    net = Sequential(*stages)

    # residual (with stochastic depth) only when shapes line up
    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

# XCAttention: cross-covariance ("transposed") linear attention — the d x d
# attention map is computed over feature channels rather than over positions
class XCAttention(Module):
    """
    this specific linear attention was proposed in https://arxiv.org/abs/2106.09681 (El-Nouby et al.)
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        cond_dim: Optional[int] = None,
        dim_head = 32,
        heads = 8,
        scale = 8,
        flash = False,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.has_cond = exists(cond_dim)

        self.film = None

        # when conditioned, a small MLP produces FiLM gamma/beta applied
        # right after the layernorm
        if self.has_cond:
            self.film = Sequential(
                nn.Linear(cond_dim, dim * 2),
                nn.SiLU(),
                nn.Linear(dim * 2, dim * 2),
                Rearrange('b (r d) -> r b 1 d', r = 2)
            )

        # no learned affine when FiLM supplies the scale/shift instead
        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)

        # fused qkv projection; output laid out as (qkv, b, h, d, n) so the
        # attention runs over the feature axis
        self.to_qkv = Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
        )

        self.scale = scale

        # learned per-head temperature for the cosine-sim attention
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        # NOTE(review): attn_dropout is created but not applied in forward —
        # confirm whether dropout on `attn` was intended
        self.attn_dropout = nn.Dropout(dropout)

        # merge heads and project back to `dim`
        self.to_out = Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim)
        )

    def forward(
        self,
        x,
        cond: Optional[Tensor] = None
    ):
        # NCHW -> sequence of spatial tokens with channels last
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack_one(x, 'b * c')

        x = self.norm(x)

        # FiLM conditioning
        if exists(self.film):
            assert exists(cond)

            gamma, beta = self.film(cond)
            x = x * gamma + beta

        # cosine-similarity linear attention over the feature dimension
        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j n -> b h i n', attn, v)

        out = self.to_out(out)

        # restore the original spatial layout
        out = unpack_one(out, ps, 'b * c')
        return rearrange(out, 'b h w c -> b c h w')

# windowed multi-head attention with relative position bias, register tokens,
# query/key RMS normalization, and optional FiLM conditioning
class Attention(Module):
    def __init__(
        self,
        dim,
        cond_dim = None,
        heads = 32,
        dim_head = 32,
        dropout = 0.,
        window_size = 8,
        num_registers = 1
    ):
        super().__init__()
        assert num_registers > 0
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        dim_inner = dim_head * heads
        self.heads = heads
        # NOTE(review): self.scale is computed but never used in forward —
        # the qk RMSNorm (gain * sqrt(dim_head)) supplies the scaling instead
        self.scale = dim_head ** -0.5

        self.has_cond = exists(cond_dim)

        self.film = None

        # when conditioned (e.g. on a lead-time embedding), an MLP produces
        # FiLM gamma/beta applied after the layernorm
        if self.has_cond:
            self.film = Sequential(
                nn.Linear(cond_dim, dim * 2),
                nn.SiLU(),
                nn.Linear(dim * 2, dim * 2),
                Rearrange('b (r d) -> r b 1 d', r = 2)
            )

        # no learned affine when FiLM supplies scale/shift instead
        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)

        # fused projection to queries, keys and values
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

        # per-head RMS normalization of queries and keys
        self.q_norm = RMSNorm(dim_head, heads = heads)
        self.k_norm = RMSNorm(dim_head, heads = heads)

        # softmax + dropout over attention weights
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative position bias over a window_size x window_size grid; the
        # extra (+1) embedding row is a shared bias for register-token pairs

        num_rel_pos_bias = (2 * window_size - 1) ** 2

        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        # precompute the table of pairwise relative-position indices
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # pad the index table so register tokens map to the shared bias slot
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(
        self,
        x: Tensor,
        cond: Optional[Tensor] = None
    ):
        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

        # pre-norm
        x = self.norm(x)

        # FiLM conditioning
        if exists(self.film):
            assert exists(cond)

            gamma, beta = self.film(cond)
            x = x * gamma + beta

        # project to queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # qk RMS normalization (provides the attention scaling)
        q, k = self.q_norm(q), self.k_norm(k)

        # pairwise similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add the relative position bias
        bias = self.rel_pos_bias(bias_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention weights
        attn = self.attend(sim)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# MaxViT backbone: stages of MBConv followed by alternating block-local and
# grid-dilated windowed attention, with per-layer register tokens and FiLM
# conditioning on a lead-time embedding
class MaxViT(Module):
    def __init__(
        self,
        *,
        dim,  # base feature dimension (doubled each stage)
        depth,  # int or tuple: number of layers per stage
        cond_dim = 32,   # dimension of the conditioning (lead-time) embedding
        heads = 32,
        dim_head = 32,
        window_size = 8,  # side length of attention windows / grids
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        num_register_tokens = 4
    ):
        super().__init__()
        # normalize depth to a per-stage tuple
        depth = (depth,) if isinstance(depth, int) else depth
        assert num_register_tokens > 0

        self.cond_dim = cond_dim

        # variables

        num_stages = len(depth)

        # channel dims double each stage: dim, 2*dim, 4*dim, ...
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        # NOTE(review): dim_pairs has num_stages - 1 entries while `depth`
        # has num_stages, so zip below drops the last stage — and a
        # single-stage depth builds no layers at all. Confirm how `depth`
        # is passed by the caller.

        self.layers = nn.ModuleList([])

        # window size

        self.window_size = window_size

        self.register_tokens = nn.ParameterList([])

        # build each stage

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                # the first layer of a stage changes channels and downsamples
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                # MBConv feature extractor
                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )

                # attention within local windows
                block_attn = Attention(dim = layer_dim, cond_dim = cond_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)

                # attention across the dilated grid
                grid_attn = Attention(dim = layer_dim, cond_dim = cond_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)

                # learned register tokens, shared by every window of this layer
                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))

                self.layers.append(ModuleList([
                    conv,
                    block_attn,
                    grid_attn
                ]))

                self.register_tokens.append(register_tokens)

    def forward(
        self,
        x: Tensor,
        cond: Tensor
    ):
        # conditioning must be one vector of cond_dim per batch element
        assert cond.shape == (x.shape[0], self.cond_dim)

        b, w = x.shape[0], self.window_size

        for (conv, block_attn, grid_attn), register_tokens in zip(self.layers, self.register_tokens):
            x = conv(x)

            # block-like attention

            # partition the feature map into non-overlapping w x w windows
            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # broadcast the register tokens to every window
            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            # flatten windows into a batch of token sequences, prepend registers
            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # attend within each window (residual)
            x = block_attn(x, cond = cond) + x

            # split registers back off and restore the spatial layout
            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention

            # partition into a dilated grid (stride-w sampling per cell)
            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # carry register information across: average over windows, rebroadcast
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            # attend across the grid (residual)
            x = grid_attn(x, cond = cond) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        return x
# container for the three prediction target groups
Predictions = namedtuple('Predictions', [
    'surface',
    'hrrr',
    'precipitation'
])

# per-target-group loss breakdown, mirroring Predictions
LossBreakdown = namedtuple('LossBreakdown', [
    'surface',
    'hrrr',
    'precipitation'
])

# 定义一个类 MetNet3,继承自 Module
class MetNet3(Module):
    # Constructor: configures the MetNet-3 model's dimensions, attention
    # settings, input channel counts, and per-target bin sizes.
    #
    # NOTE(review): this signature is not closed — there is no `):` and no
    # method body before the next decorator, so this block will not parse
    # as-is. The body appears to have been lost in extraction; TODO restore
    # it from the upstream source.
    @beartype
    def __init__(
        self,
        *,
        dim = 512,
        num_lead_times = 722,
        lead_time_embed_dim = 32,
        input_spatial_size = 624,
        attn_depth = 12,
        attn_dim_head = 64,
        attn_heads = 32,
        attn_dropout = 0.1,
        vit_window_size = 8,
        vit_mbconv_expansion_rate = 4,
        vit_mbconv_shrinkage_rate = 0.25,
        # channel counts written as sums of per-source contributions —
        # presumably one term per input data source; TODO confirm which
        input_2496_channels = 2 + 14 + 1 + 2 + 20,
        input_4996_channels = 16 + 1,
        surface_and_hrrr_target_spatial_size = 128,
        # number of discretization bins per precipitation target head
        precipitation_target_bins: Dict[str, int] = dict(
            mrms_rate = 512,
            mrms_accumulation = 512
        ),
        # number of discretization bins per surface-observation target head
        surface_target_bins: Dict[str, int] = dict(
            omo_temperature = 256,
            omo_dew_point = 256,
            omo_wind_speed = 256,
            omo_wind_component_x = 256,
            omo_wind_component_y = 256,
            omo_wind_direction = 180
        ),
        # how HRRR channels are normalized: not at all, with precomputed
        # statistics, or with a synchronized batchnorm
        hrrr_norm_strategy: Union[
            Literal['none'],
            Literal['precalculated'],
            Literal['sync_batchnorm']
        ] = 'none',
        hrrr_channels = 617,
        hrrr_norm_statistics: Optional[Tensor] = None,
        hrrr_loss_weight = 10,
        crop_size_post_16km = 48,
        resnet_block_depth = 2,
    
    # Alternate constructor: rebuild a model from the config stored inside a checkpoint.
    @classmethod
    def init_and_load_from(cls, path, strict = True):
        """Instantiate the model from a checkpoint's pickled config, then load its weights.

        Args:
            path: checkpoint file produced by `save`
            strict: forwarded to `load_state_dict` via `load`
        """
        checkpoint_path = Path(path)
        assert checkpoint_path.exists()

        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints
        pkg = torch.load(str(checkpoint_path), map_location = 'cpu')

        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

        # reconstruct the model from its saved constructor kwargs, then restore weights
        model = cls(**pickle.loads(pkg['config']))
        model.load(checkpoint_path, strict = strict)

        return model

    # Serialize model weights along with the pickled constructor config.
    def save(self, path, overwrite = True):
        """Write this model's `state_dict` and stored config to `path`.

        Raises:
            AssertionError: if the file already exists and `overwrite` is False.
        """
        path = Path(path)

        # refuse to clobber an existing checkpoint unless explicitly allowed
        assert overwrite or not path.exists(), f'{str(path)} already exists'

        # bundle weights and the config captured at construction time,
        # so `init_and_load_from` can rebuild the model later
        torch.save(dict(
            model_state_dict = self.state_dict(),
            config = self._configs
        ), str(path))

    # Load model weights from a checkpoint produced by `save`.
    def load(self, path, strict = True):
        """Restore this model's parameters from the checkpoint at `path`.

        Args:
            path: checkpoint file path
            strict: forwarded to `load_state_dict`

        Raises:
            AssertionError: if the file does not exist or lacks a
                'model_state_dict' entry.
        """
        path = Path(path)
        assert path.exists()

        # map to cpu so checkpoints saved on gpu still load on cpu-only
        # machines — consistent with `init_and_load_from`
        pkg = torch.load(str(path), map_location = 'cpu')
        state_dict = pkg.get('model_state_dict')

        assert exists(state_dict)

        self.load_state_dict(state_dict, strict = strict)

    # 前向传播方法
    @beartype
    def forward(
        self,
        *,
        lead_times,
        hrrr_input_2496,
        hrrr_stale_state,
        input_2496,
        input_4996,
        surface_targets: Optional[Dict[str, Tensor]] = None,
        precipitation_targets: Optional[Dict[str, Tensor]] = None,
        hrrr_target: Optional[Tensor] = None,