
Lucidrains Series: Source Code Walkthrough (70)

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\optimizer.py

# import the AdamW and Adam optimizers from torch.optim
from torch.optim import AdamW, Adam
# import the Lion optimizer from lion_pytorch
from lion_pytorch import Lion

# split parameters into those that should and should not receive weight decay
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # parameters with fewer than 2 dimensions (biases, norm gains) get no weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# build an optimizer
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    use_lion = True,
    **kwargs
):
    # optionally keep only parameters that require gradients
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # optionally split parameters into weight-decay / no-weight-decay groups
    if group_wd_params and wd > 0:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # use the Lion optimizer if requested
    if use_lion:
        return Lion(params, lr = lr, betas = betas, weight_decay = wd)

    # plain Adam when no weight decay is needed
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # otherwise AdamW with weight decay
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
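
As a quick illustration, the helper above can be driven directly from a model's parameter iterator. This is a minimal usage sketch; the model and hyperparameter values below are placeholders, not values taken from the repository.

# hypothetical usage sketch for get_optimizer (illustrative values)
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

optim = get_optimizer(
    model.parameters(),
    lr = 3e-4,
    wd = 1e-2,
    use_lion = False   # with wd > 0 and Lion disabled, this returns AdamW with two param groups
)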

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\palm.py

# math utilities
import math
# deep copies
import copy
# filesystem paths
from pathlib import Path
# named tuples
from collections import namedtuple
# decorator helpers
from functools import wraps
# zip_longest, for pairing layers with optional finetune modules
from itertools import zip_longest

# progress bars
from tqdm import tqdm
# runtime type checking
from beartype import beartype
# Tuple and Optional from beartype
from beartype.typing import Tuple, Optional

# torch
import torch
# einsum and nn from torch
from torch import einsum, nn
# the functional API
import torch.nn.functional as F

# einops
from einops import rearrange, repeat, reduce, pack, unpack
# Rearrange and Reduce layers from einops
from einops.layers.torch import Rearrange, Reduce

# Attention from palm_rlhf_pytorch.attention
from palm_rlhf_pytorch.attention import Attention
# top_p, top_k, masked_mean, gumbel_sample, eval_decorator from palm_rlhf_pytorch.utils
from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
# LoRA from palm_rlhf_pytorch.lora
from palm_rlhf_pytorch.lora import LoRA
# functions and decorators

# check whether a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# return the input unchanged
def identity(t, *args, **kwargs):
    return t

# L2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim=-1)

# normalization
# they use layernorm without bias, something that PyTorch does not offer out of the box

# normalization layer
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual connection

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        y = self.fn(x, **kwargs)

        if not any([t.requires_grad for t in (x, y)]):
            return x.add_(y)

        return y + x

# rotary positional embedding with xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, scale_base=512, use_xpos=True):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.use_xpos = use_xpos
        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)

        if not self.use_xpos:
            return freqs, torch.ones(1, device=device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim=-1)

        return freqs, scale

# rotate half of the last dimension (for rotary embeddings)
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply rotary positional embedding
def apply_rotary_pos_emb(pos, t, scale=1.):
    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
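
As a small aside, here is how the two helpers above combine with RotaryEmbedding; the tensor shapes are illustrative, and the reciprocal scale applied to the keys mirrors what ParallelTransformerBlock.forward does further below.

# illustrative sketch only: rotate queries and keys for a toy batch
dim_head = 64
rotary = RotaryEmbedding(dim_head, use_xpos = True)

q = torch.randn(1, 8, 1024, dim_head)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 1024, dim_head)      # single key head (multi-query attention)

pos, scale = rotary(1024, device = q.device)

q = apply_rotary_pos_emb(pos, q, scale)
k = apply_rotary_pos_emb(pos, k, scale ** -1)   # keys receive the reciprocal xpos scale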

# the classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al. and EleutherAI's GPT-J

class ParallelTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        causal=True,
        heads=8,
        qk_rmsnorm=False,
        qk_scale=8,
        ff_mult=4,
        attn_dropout=0.,
        ff_dropout=0.,
        use_xpos=True,
        xpos_scale_base=512,
        flash_attn=False,
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化 LayerNorm 层
        self.norm = LayerNorm(dim)

        # 计算注意力内部维度
        attn_inner_dim = dim_head * heads
        # 计算前馈内部维度
        ff_inner_dim = dim * ff_mult
        # 定义融合维度
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        # 设置是否进行 qk rmsnorm
        self.qk_rmsnorm = qk_rmsnorm

        if qk_rmsnorm:
            # 初始化 q 的缩放参数
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            # 初始化 k 的缩放参数
            self.k_scale = nn.Parameter(torch.ones(dim_head))

        # 初始化注意力模块
        self.attend = Attention(
            causal = causal,
            dropout = attn_dropout,
            use_flash_attn = flash_attn
        )

        # 设置头数
        self.heads = heads
        # 设置缩放因子
        self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
        # 设置是否是因果关系
        self.causal = causal

        # 初始化旋转嵌入
        self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal)

        # 初始化融合的注意力和前馈投影
        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)

        # 设置是否使用 Flash Attention
        self.flash_attn = flash_attn
        # 初始化注意力输出层
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        # 初始化注意力的 Dropout 层
        self.attn_dropout = nn.Dropout(attn_dropout)
        # 设置 Flash Attention 的 Dropout
        self.flash_attn_dropout = attn_dropout

        # 并行前馈尾部

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Dropout(ff_dropout),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # 用于缓存因果掩码和旋转嵌入

        self.register_buffer("pos_emb", None, persistent=False)
        self.register_buffer("pos_emb_scale", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n], self.pos_emb_scale[:n]

        pos_emb, scale = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        self.register_buffer("pos_emb_scale", scale, persistent=False)
        return pos_emb, scale

    def forward(
        self,
        x,
        mask = None,
        finetune_modules = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # 预 Layernorm

        x = self.norm(x)

        # 注意力查询、键、值和前馈内部

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # 调整 LORAS

        lora_q = lora_k = lora_v = lora_o = None

        if exists(finetune_modules):
            lora_q, lora_k, lora_v, lora_o = finetune_modules
            q = q + lora_q(x)
            k = k + lora_k(x)
            v = v + lora_v(x)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and it is more efficient to decode
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # qk rmsnorm

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            q = q * self.q_scale
            k = k * self.k_scale

        # 使用 xpos 衰减的旋转嵌入以获得更好的长度外推

        positions, scale = self.get_rotary_embedding(n, device)

        q = apply_rotary_pos_emb(positions, q, scale)
        k = apply_rotary_pos_emb(positions, k, scale ** -1)

        # 注意力函数,常规或 Flash

        out = self.attend(q, k, v, mask = mask)

        # 合并头部

        out = rearrange(out, "b h n d -> b n (h d)")

        attn_out = self.attn_out(out)

        ff_out = self.ff_out(ff)

        if exists(lora_o):
            attn_out = attn_out + lora_o(out)

        return attn_out + ff_out
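
For a quick sanity check of the block above, one can wrap it in the Residual module and pass a random tensor through it; the dimensions below are illustrative, not values from the repository.

# illustrative shape check for ParallelTransformerBlock
block = Residual(ParallelTransformerBlock(dim = 512, dim_head = 64, heads = 8))

x = torch.randn(2, 256, 512)   # (batch, seq, dim)
out = block(x)                 # residual output, same shape (2, 256, 512)
assert out.shape == x.shape
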
# the PaLM class: a Transformer-based language model, implemented as an nn.Module
@beartype
class PaLM(nn.Module):
    # 初始化函数,接收多个参数用于配置模型的各种属性
    def __init__(
        self,
        *,
        dim,  # 模型的维度
        num_tokens,  # token 的数量
        depth,  # Transformer 的深度
        causal = True,  # 是否使用 causal attention
        dim_head = 64,  # 每个头的维度
        heads = 8,  # 头的数量
        ff_mult = 4,  # FeedForward 层的倍数
        attn_dropout = 0.,  # 注意力层的 dropout 概率
        ff_dropout = 0.,  # FeedForward 层的 dropout 概率
        qk_rmsnorm = False,  # 是否对 QK 矩阵进行 RMS 归一化
        lora_r = 8,  # LoRA 模块的参数 r
        rotary_xpos_scale_base = 512,  # 旋转位置编码的基数
        flash_attn = False,  # 是否使用 Flash Attention
        finetune_scopes = tuple(),  # 微调的范围
        cross_entropy_ignore_index = 0  # 交叉熵损失的忽略索引
    ):
        super().__init__()
        # 初始化模型的各种属性
        self.dim = dim
        self.dim_head = dim_head
        self.heads = heads
        self.causal = causal
        self.num_tokens = num_tokens

        # 创建 token 的嵌入层
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])

        # 根据深度循环创建多个 Transformer Block
        for _ in range(depth):
            block = Residual(ParallelTransformerBlock(
                dim = dim,
                causal = causal,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                ff_mult = ff_mult,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                xpos_scale_base = rotary_xpos_scale_base,
                flash_attn = flash_attn
            ))

            self.layers.append(block)

        # 创建 LayerNorm 层
        self.norm = LayerNorm(dim)
        # 创建输出层,用于将模型输出转换为 token 的概率分布
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)
        
        # 将输出层的权重与 token 嵌入层的权重共享
        self.to_logits.weight = self.token_emb.weight

        # 对 token 嵌入层的权重进行正态分布初始化
        nn.init.normal_(self.token_emb.weight, std=0.02)

        # 微调相关

        self.lora_r = lora_r
        self.finetune_modules = nn.ModuleDict({})

        # 根据微调范围添加微调参数
        for scope in finetune_scopes:
            self.add_finetune_params(scope)

        # 损失相关

        self.cross_entropy_ignore_index = cross_entropy_ignore_index

    # 定义 device 属性,用于获取模型参数所在的设备
    @property
    def device(self):
        return next(self.parameters()).device

    # 加载模型参数
    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    # 设置模型中的 Dropout 层的概率
    def set_dropout(self, dropout):
        for module in self.layers.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout
        return self

    # 添加微调参数
    def add_finetune_params(self, scope, lora_r = None):
        assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device

        q_inner_dim = heads * dim_head
        kv_inner_dim = dim_head

        lora_modules = nn.ModuleList([])

        for _ in range(len(self.layers)):
            lora_modules.append(nn.ModuleList([
                LoRA(dim, q_inner_dim, r = r),   # queries
                LoRA(dim, kv_inner_dim, r = r),  # keys
                LoRA(dim, kv_inner_dim, r = r),  # values
                LoRA(q_inner_dim, dim, r = r)    # wo
            ]))

        self.finetune_modules[scope] = lora_modules.to(device)

    # 移除微调参数
    def remove_finetune_params(self, scope):
        assert scope in self.finetune_modules, f'finetune scope {scope} not found'
        return self.finetune_modules.pop(scope)

    # 禁用梯度计算
    @torch.no_grad()
    # 合并微调的 actor LORA 参数,用于多轮不同奖励模型的微调
    def merge_finetune_params(self, scope):
        """ in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """

        # 确保指定的微调范围存在
        assert scope in self.finetune_modules, f'finetune scope {scope} not found'

        # 弹出指定范围的 LORA 模块
        lora_modules = self.finetune_modules.pop(scope)

        # 遍历每个层和对应的 LORA 模块
        for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules):
            block = layer.fn

            # 获取融合的注意力和前馈权重
            fused_attn_ff_weight = block.fused_attn_ff_proj.weight
            attn_out_weight = block.attn_out.weight

            # 获取融合后的投影输出维度
            fused_proj_out_dim = fused_attn_ff_weight.shape[0]

            # 打包 Q、K、V 权重
            lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *')
            lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1]))

            # 重排 QKV 权重
            lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i')
            lora_o_weight = rearrange(lora_o.weight, 'i o -> o i')

            # 更新融合的注意力和前馈权重
            fused_attn_ff_weight.add_(lora_qkv_weight)
            attn_out_weight.add_(lora_o_weight)

    # the researcher first trains the PaLM parameters, then finetunes

    # base PaLM parameters (excluding the finetune / LoRA modules)
    def palm_parameters(self):
        return set(self.parameters()) - set(self.finetune_modules.parameters())

    # 获取微调参数
    def finetune_parameters(self, scope = 'default'):
        assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found'
        return self.finetune_modules[scope].parameters()

    # 生成函数

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        seq_len,
        prompt = None,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_thres = 0.9,
        pad_value = 0.,
        eos_token = None,
        return_seq_without_prompt = True,
        use_tqdm = False,
        **kwargs
    ):
        # 如果没有指定提示,则随机生成一个
        if not exists(prompt):
            prompt = torch.randint(0, self.num_tokens, (1, 1))
            prompt = prompt.to(self.device)
            return_seq_without_prompt = False

        prompt, leading_dims = pack([prompt], '* n')

        n, out = prompt.shape[-1], prompt.clone()

        wrapper_fn = identity if not use_tqdm else tqdm
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        for _ in wrapper_fn(range(sample_num_times)):
            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
            logits, embeds = logits[:, -1], embeds[:, -1]

            if exists(filter_logits_fn):
                logits = filter_logits_fn(logits, thres = filter_thres)

            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
            out, _ = pack([out, sample], 'b *')

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # 掩盖掉 EOS 标记后的所有内容
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, pad_value)
                    break

        out, = unpack(out, leading_dims, '* n')

        if not return_seq_without_prompt:
            return out

        return out[..., n:]

    # 前向传播函数
    def forward(
        self,
        x,
        return_loss = False,
        disable_lora = False,
        finetune_scope = None,
        extra_embed = None,
        return_only_embedding = False,
        return_logits_with_embedding = False
        ):
        # 如果需要返回损失,则将输入数据 x 切片,分别作为输入和标签
        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # if not autoregressive, build an encoder mask
        # treat any negative token id as a token to be masked out - only needed in the non-autoregressive case
        if not self.causal:
            mask = x >= 0
            x = x.masked_fill(~mask, 0)
        else:
            mask = None

        # 获取标记嵌入
        x = self.token_emb(x)

        # 如果存在额外的嵌入,则将其加到标记嵌入中
        if exists(extra_embed):
            x = x + extra_embed

        # 微调模块
        finetune_modules = tuple()
        if exists(finetune_scope) and not disable_lora:
            assert finetune_scope in self.finetune_modules
            finetune_modules = self.finetune_modules[finetune_scope]

        # 并行注意力 / 前馈块,传入微调 lora
        for layer, finetune_modules in zip_longest(self.layers, finetune_modules):
            x = layer(x, mask = mask, finetune_modules = finetune_modules)

        # 最终规范化
        embeds = self.norm(x)

        # 如果只需要返回嵌入,则直接返回嵌入
        if return_only_embedding:
            return embeds

        # 转换为逻辑值
        logits = self.to_logits(embeds)

        # 返回结果,根据需要返回逻辑值和嵌入或仅逻辑值
        ret = (logits, embeds) if return_logits_with_embedding else logits

        # 如果不需要返回损失,则直接返回结果
        if not return_loss:
            return ret

        # 重新排列逻辑值的维度,以便计算交叉熵损失
        logits = rearrange(logits, 'b n c -> b c n')
        return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\ppo.py

import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

import torch
from torch import nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.optimizer import get_optimizer
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - PaLM with lora

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        palm: PaLM,
        critic_palm: Optional[PaLM] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        actor_lora_scope = 'actor',
        critic_lora_scope = 'critic',
        actor_dropout = 0.,
        critic_dropout = 0.
    ):
        super().__init__()
        self.actor_palm = palm

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
            self.critic_palm = copy.deepcopy(palm)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora

        self.actor_lora_scope = actor_lora_scope if actor_lora else None
        self.critic_lora_scope = critic_lora_scope if critic_lora else None

        if self.actor_lora:
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(palm.dim, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    def actor_parameters(self):
        # 返回 actor 参数,如果不使用 lora,则返回 actor_palm 的参数
        if not self.actor_lora:
            return self.actor_palm.parameters()

        return [
            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
        ]

    def critic_parameters(self):
        # return the critic parameters; if LoRA is not used, return the parameters of critic_palm plus the value head
        if not self.critic_lora:
            return [*self.critic_palm.parameters(), *self.value_head.parameters()]

        return [
            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        # generate an action sequence given the current state (prompt) and max sequence length
        actions = self.actor_palm.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )

        # concatenate the current state with the generated actions
        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        # mask marking which positions belong to the prompt (state)
        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        # mask marking the action positions
        action_mask = ~prompt_mask

        mask = None
        # if an EOS token is given, mask out everything after it
        if exists(eos_token):
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True) # include eos token
            action_mask &= mask

        # get the action logits and values
        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        # return the actions, sequence, masks, logits and values as a named tuple
        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        # actor logits over the full sequence
        action_logits = self.actor_palm(
            x,
            finetune_scope = self.actor_lora_scope
        )

        # if values are not needed, return only the action logits
        if not return_values:
            return action_logits, None

        # critic embeddings
        critic_embeds = self.critic_palm(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
        )

        # if using pooled values, shift and masked-average the critic embeddings
        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        # project to values
        values = self.value_head(critic_embeds)

        # return the action logits and values
        return action_logits, values

# a Memory named tuple holding the sequence, prompt mask, mask, action probabilities, action log probabilities, reward and value
Memory = namedtuple('Memory', [
    'sequence',
    'prompt_mask',
    'mask',
    'action_prob',
    'action_log_prob',
    'reward',
    'value'
])

# ExperienceDataset 类,继承自 Dataset 类,用于处理经验数据集
class ExperienceDataset(Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],  # 接受一个包含 torch.Tensor 的列表作为数据
        device = None  # 设备参数,默认为 None
    ):
        super().__init__()
        self.data = data  # 存储数据
        self.device = device  # 存储设备信息

    def __len__(self):
        return self.data[0].shape[0]  # 返回数据的第一个维度大小

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind].to(self.device), self.data))  # 返回指定索引的数据,并将其移动到指定设备上

# 创建数据加载器函数,接受数据、批量大小、是否打乱数据、设备等参数
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    ds = ExperienceDataset(data, device = device)  # 创建 ExperienceDataset 实例
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)  # 返回 DataLoader 实例

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 对张量进行归一化处理
def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))  # 获取维度信息
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)  # 计算均值
    mean_centered = t - mean  # 中心化
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)  # 计算方差

    return mean_centered * var.clamp(min = eps).rsqrt()  # 返回归一化后的结果

# 对序列进行固定填充
def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]  # 获取第一个元素
    has_no_dimension = first_el.ndim == 0  # 判断是否没有维度

    # 如果没有维度,添加一个维度
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)  # 使用 pad_sequence 进行填充

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')  # 重新排列维度

    return out

# 计算张量的对数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 计算对数概率
def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

# 对张量进行移位
def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

# 计算掩码熵
def masked_entropy(prob, dim = -1, mask = None):
    entropies = (prob * log(prob)).sum(dim = -1)
    return masked_mean(entropies, mask = mask).mean()

# 计算掩码 KL 散度
def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
    loss = masked_mean(kl_divs, mask)

    if reduce_batch:
        return loss.mean()

    return loss

# 计算截断值损失
def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))
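
The helpers above are the building blocks of the PPO update performed in RLHFTrainer.learn further below. The following sketch shows how a clipped surrogate policy loss and the clipped value loss could be assembled from them; it is a simplified illustration under assumed tensor shapes, not the repository's exact update step.

# hypothetical sketch of the PPO losses, built from the helpers above
def ppo_losses(
    action_log_probs,       # (batch, seq) log probs under the current policy
    old_action_log_probs,   # (batch, seq) log probs stored at rollout time
    advantages,             # (batch,) advantage estimates
    values,                 # (batch,) critic predictions
    old_values,             # (batch,) values stored at rollout time
    rewards,                # (batch,) rewards from the reward model
    mask = None,            # (batch, seq) True over action tokens
    eps_clip = 0.2,
    value_clip = 0.4
):
    # probability ratio between the new and old policy, per token
    ratios = (action_log_probs - old_action_log_probs).exp()

    # clipped surrogate objective, averaged over the action tokens
    advantages = rearrange(advantages, 'b -> b 1')
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    policy_loss = -masked_mean(torch.min(surr1, surr2), mask = mask)

    # clipped value loss, exactly as defined in clipped_value_loss above
    value_loss = clipped_value_loss(values, rewards, old_values, value_clip)

    return policy_loss.mean(), value_loss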

# RLHFTrainer 类,继承自 nn.Module
class RLHFTrainer(nn.Module):
    # 初始化函数,设置模型的各种参数和超参数
    def __init__(
        self,
        *,
        prompts: Optional[List[str]] = None,  # 提示语列表
        prompts_path: Optional[str] = None,  # 提示语文件路径
        prompt_token_ids: Optional[torch.Tensor] = None,  # 提示语的token ids
        tokenizer: Callable = None,  # 分词器
        palm: PaLM,  # 主模型
        reward_model: RewardModel,  # 奖励模型
        critic_palm: Optional[PaLM] = None,  # 评论者模型
        actor_critic: Optional[ActorCritic] = None,  # 演员评论者模型
        actor_lr = 1e-4,  # 演员学习率
        critic_lr = 1e-4,  # 评论者学习率
        actor_wd = 0.,  # 演员权重衰减
        critic_wd = 0.,  # 评论者权重衰减
        actor_adam_eps = 1e-7,  # 演员Adam优化器epsilon
        critic_adam_eps = 1e-7,  # 评论者Adam优化器epsilon
        actor_lora = True,  # 演员是否使用LoRA
        critic_lora = True,  # 评论者是否使用LoRA
        actor_lora_r = 8,  # 演员LoRA半径
        critic_lora_r = 8,  # 评论者LoRA半径
        critic_pooled_values = True,  # 评论者是否使用池化值
        actor_dropout = 0.,  # 演员Dropout
        critic_dropout = 0.,  # 评论者Dropout
        betas = (0.9, 0.999),  # Adam优化器betas
        max_norm = None,  # 梯度裁剪最大范数
        eps_clip = 0.2,  # PPO算法epsilon裁剪
        value_clip = 0.4,  # 值函数裁剪
        beta_s = .01,  # beta_s参数
        pad_value = 0.,  # token填充值
        minibatch_size = 16,  # 小批量大小
        epochs = 1,  # 训练轮数
        kl_div_loss_weight = 0.1,  # KL散度损失权重
        accelerate_kwargs: dict = {},  # 加速器参数
        use_lion = False  # 是否使用LION
    ):
        # 调用父类初始化函数
        super().__init__()

        # 初始化加速器
        self.accelerate = Accelerator(**accelerate_kwargs)

        # 处理提示语到token ids的转换
        assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1

        if exists(prompts_path):
            path = Path(prompts_path)
            prompts = path.read_text().split('\n')

        if exists(prompts):
            assert len(prompts) > 0, 'no prompts'
            assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
            prompt_token_ids = tokenizer(prompts)

        self.pad_value = pad_value  # token填充值
        self.num_prompts = prompt_token_ids.shape[0]  # 提示语数量
        self.register_buffer('prompt_token_ids', prompt_token_ids)  # 注册提示语token ids

        # 初始化模型
        self.palm = palm

        if not exists(actor_critic):
            actor_critic = ActorCritic(
                palm = palm,
                critic_palm = critic_palm,
                actor_lora = actor_lora,
                critic_lora = critic_lora,
                actor_lora_r = actor_lora_r,
                critic_lora_r = critic_lora_r,
                pooled_values = critic_pooled_values,
                actor_dropout = actor_dropout,
                critic_dropout = critic_dropout
            ).to(palm.device)

        self.actor_critic = actor_critic  # 演员评论者模型

        self.reward_model = reward_model.eval()  # 奖励模型

        # 训练超参数
        self.epochs = epochs
        self.minibatch_size = minibatch_size
        self.max_norm = max_norm
        self.kl_div_loss_weight = kl_div_loss_weight

        # 优化器
        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion)
        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion)

        # PPO算法超参数
        self.eps_clip = eps_clip
        self.value_clip = value_clip
        self.beta_s = beta_s

        # 准备加速器
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )

    # 打印函数
    def print(self, msg):
        return self.accelerate.print(msg)

    # 保存模型参数
    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    # 加载模型参数
    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    # 设备属性
    @property
    def device(self):
        return self.accelerate.device

    # 禁用梯度计算
    @torch.no_grad()
    # 定义一个生成器函数,用于生成文本序列
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # generate 4 samples per prompt and pick the one with the highest reward
        **kwargs
    ):
        # 断言只有一个提示允许在同一时间
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        # 复制提示以匹配生成的样本数量
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        # 获取未加速的 actor_critic 模型
        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        # 获取未加速的 reward_model 模型
        reward_model = self.accelerate.unwrap_model(self.reward_model)

        # 设置 actor_critic 模型为评估模式
        actor_critic.eval()

        # 生成动作、序列、掩码、提示掩码、动作概率等信息
        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        # 使用奖励模型计算奖励
        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        # 选择具有最高奖励的序列索引
        best_sequence_index = rewards.topk(1, dim = -1).indices

        # 获取最佳序列
        best_sequence = sequences[best_sequence_index]
        # 重新排列最佳序列的维度
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        # 返回最佳序列
        return best_sequence

    # learning function, runs the PPO updates from the stored memories
    def learn(
        self,
        memories: Deque[Memory]
    ):

    # training function for the RLHF loop
    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        max_seq_len = 2048,
        eos_token = None,
        temperature = 1.
        ):
        # 获取当前环境设备
        device = self.device

        # 初始化时间步长和记忆队列
        time = 0
        memories = deque([])

        # 循环执行一定数量的 episodes
        for eps in tqdm(range(num_episodes), desc='episodes'):
            # 在每个 episode 中执行一定数量的时间步长
            for timestep in range(max_timesteps):
                time += 1

                # 选择一组随机状态(提示)并获取动作(从 palm 中采样的序列以及动作概率)
                # 使用奖励模型计算奖励并存储

                # 随机选择一个提示的索引
                rand_prompt_index = randrange(0, self.num_prompts)

                # 获取状态(提示)的 token ID
                state = self.prompt_token_ids[rand_prompt_index]

                # 去除状态中的填充
                state_mask = state != self.pad_value
                state = state[state_mask]

                # 生成预测序列
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len=max_seq_len,
                    eos_token=eos_token,
                    temperature=temperature,
                    return_values=True
                )
                action_logits = shift(action_logits, shift=1, dim=-2)  # need to shift along the sequence dimension by 1, since the actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim=-1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # 使用经过监督训练的奖励模型获取奖励
                sequence = torch.cat((state, actions), dim=0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device=device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = default(mask, lambda: torch.ones(sequence.shape, dtype=torch.bool, device=device))

                reward = self.reward_model(
                    sequence,
                    prompt_mask=prompt_mask,
                    mask=mask,
                    sample=True
                )

                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store the memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # 从存储的记忆中学习
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\reward.py

# necessary imports
import copy
from pathlib import Path

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from palm_rlhf_pytorch.utils import masked_mean, gumbel_sample
from palm_rlhf_pytorch.palm import PaLM

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# reward model - PaLM with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        palm: PaLM,
        dropout = 0.1,
        num_binned_output = 0.,
        use_lora = True,
        lora_r = 8,
        reward_lora_scope = 'reward',
    ):
        super().__init__()

        # 深拷贝传入的 PaLM 模型
        self.palm = copy.deepcopy(palm)
        self.palm.set_dropout(dropout)

        # 根据 use_lora 参数决定是否使用 LORA
        self.reward_lora_scope = reward_lora_scope if use_lora else None

        # 如果启用了 LORA,则为奖励模型添加微调参数
        if exists(self.reward_lora_scope):
            self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)

        dim = palm.dim

        # 判断是否需要输出多个分箱
        self.binned_output = num_binned_output > 1

        # 初始化提示和响应的嵌入向量
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))

        # 根据是否需要多个分箱选择不同的输出层
        if self.binned_output:
            self.to_pred = nn.Linear(dim, num_binned_output)
        else:
            self.to_pred = nn.Sequential(
                nn.Linear(dim, 1, bias = False),
                Rearrange('... 1 -> ...')
            )

    # 加载模型参数
    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    # 获取需要微调的参数
    def finetune_parameters(self):
        return [
            *self.to_pred.parameters(),
            *(self.palm.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else self.palm.parameters())
        ]

    # 前向传播函数
    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None,
        labels = None,
        sample = False,
        sample_temperature = 1.,
        disable_lora = False
    ):

        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive the prompt mask from the prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # the reward model should know which part is the prompt and which part is the response

        extra_embed = None

        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # 从 PaLM 中获取嵌入向量
        embeds = self.palm(
            x,
            extra_embed = extra_embed,
            return_only_embedding = True,
            disable_lora = disable_lora,
            finetune_scope = self.reward_lora_scope
        )

        # 对嵌入向量进行平均池化
        pooled = masked_mean(embeds, mask, dim = 1)
        pred = self.to_pred(pooled)

        # 如果需要采样并且输出为多个分箱,则对输出进行 Gumbel 采样
        if sample and self.binned_output:
            assert not exists(labels)
            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)

        # 如果标签不存在,则直接返回预测值
        if not exists(labels):
            return pred

        # 如果输出不是多个分箱,则计算均方误差损失
        if not self.binned_output:
            return F.mse_loss(pred, labels)

        # 如果输出为多个分箱,则计算交叉熵损失
        return F.cross_entropy(pred, labels)

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\utils.py

# 导入 math、torch 模块,以及从 torch 模块中导入 einsum、nn 和 nn.functional 模块
import math
import torch
from torch import einsum, nn
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange 函数
from einops import rearrange

# 检查变量是否存在的函数
def exists(val):
    return val is not None

# 装饰器函数

# 评估装饰器函数,用于在执行函数时将模型设置为评估模式
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# 张量辅助函数

# 对张量取对数,避免取对数时出现负无穷
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# 计算带掩码的平均值,如果没有掩码则直接计算平均值
def masked_mean(seq, mask = None, dim = 1, keepdim = False):
    if not exists(mask):
        return seq.mean(dim = dim)

    if seq.ndim == 3:
        mask = rearrange(mask, 'b n -> b n 1')

    masked_seq = seq.masked_fill(~mask, 0.)
    numer = masked_seq.sum(dim = dim, keepdim = keepdim)
    denom = mask.sum(dim = dim, keepdim = keepdim)

    masked_mean = numer / denom.clamp(min = 1e-3)
    masked_mean = masked_mean.masked_fill(denom == 0, 0.)
    return masked_mean

# 采样辅助函数

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 使用 Gumbel 噪声对张量进行采样
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# Top-p 采样方法
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# Top-k 采样方法
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
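
Tying the sampling helpers together, generation filters the logits with top_k (or top_p) and then draws a token with gumbel_sample, which is essentially what PaLM.generate does one step at a time. A small illustrative sketch:

# illustrative usage of the sampling helpers (shapes and values are placeholders)
logits = torch.randn(2, 20000)                                   # (batch, num_tokens)

filtered = top_k(logits, thres = 0.9)                            # keep roughly the top 10% of logits
tokens = gumbel_sample(filtered, temperature = 0.8, dim = -1)    # (batch,) sampled token ids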

.\lucidrains\PaLM-rlhf-pytorch\palm_rlhf_pytorch\__init__.py

# 从 palm_rlhf_pytorch.palm 模块中导入 PaLM 类
from palm_rlhf_pytorch.palm import PaLM
# 从 palm_rlhf_pytorch.reward 模块中导入 RewardModel 类
from palm_rlhf_pytorch.reward import RewardModel
# 从 palm_rlhf_pytorch.ppo 模块中导入 RLHFTrainer, ActorCritic 类
from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic


PaLM + RLHF - Pytorch (wip)

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Maybe I'll add retrieval functionality too, à la RETRO

If you are interested in replicating something like ChatGPT out in the open, please consider joining Laion Join us on Discord

Potential successor: Direct Preference Optimization - all the code in this repo becomes ~ binary cross entropy loss, < 5 loc. So much for Reward models and PPO
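
For reference, the DPO objective alluded to above reduces to a binary cross entropy over the log-probability margins of a chosen versus a rejected completion. A minimal sketch (not part of this repository) looks roughly like this:

import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps,     # summed log probs of preferred completions under the policy
    policy_rejected_logps,   # summed log probs of dispreferred completions under the policy
    ref_chosen_logps,        # the same quantities under the frozen reference model
    ref_rejected_logps,
    beta = 0.1
):
    logits = beta * (
        (policy_chosen_logps - ref_chosen_logps)
        - (policy_rejected_logps - ref_rejected_logps)
    )
    # binary cross entropy with an implicit "chosen is preferred" label
    return -F.logsigmoid(logits).mean()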

FAQ

  • Does this contain a model for inference?

There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point.

Community

CarperAI had been working on an RLHF framework for large language models for many months prior to the release of ChatGPT.

Yannic Kilcher is also working on an open sourced implementation

AI Coffeebreak w/ Letitia | Code Emporium | Code Emporium Part 2

Appreciation

Install

$ pip install palm-rlhf-pytorch

Usage

First train PaLM, like any other autoregressive transformer

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True # https://arxiv.org/abs/2205.14135
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()

# after much training, you can now generate sequences

generated = palm.generate(2048) # (1, 2048)

Then train your reward model with the curated human feedback. In the original paper, they could not get the reward model to be finetuned from a pretrained transformer without overfitting, but I gave the option to finetune with LoRA anyways, since it is still open research.

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()

# after much training

reward = reward_model(seq, prompt_mask = prompt_mask)

Then you will pass your transformer and the reward model to the RLHFTrainer

import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer

# load your pretrained palm

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model

reward_model = RewardModel(
    palm,
    num_binned_output = 5
).cuda()

reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one

answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)

Todo

  • clone base transformer with separate lora for critic

  • also allow for non-LoRA based finetuning

  • redo normalize to be able to have a masked version, not sure if anyone will ever use per token rewards / values, but good practice to implement

  • equip with the best attention

  • add Hugging Face accelerate and test out wandb instrumentation

  • search literature to figure out what is the latest SOTA for PPO, assuming RL field is still making progress.

  • test the system using a pretrained sentiment network as reward model

  • write the memory in PPO to memmapped numpy file

  • get sampling with variable lengthed prompts working, even if it is not needed given bottleneck is human feedback

  • allow for finetuning penultimate N layers only in either actor or critic, assuming if pretrained

  • incorporate some learning points from Sparrow, given Letitia's video

  • simple web interface with django + htmx for collecting human feedback

  • consider RLAIF

Citations

@article{Stiennon2020LearningTS,
    title   = {Learning to summarize from human feedback},
    author  = {Nisan Stiennon and Long Ouyang and Jeff Wu and Daniel M. Ziegler and Ryan J. Lowe and Chelsea Voss and Alec Radford and Dario Amodei and Paul Christiano},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2009.01325}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
@article{Hu2021LoRALA,
    title   = {LoRA: Low-Rank Adaptation of Large Language Models},
    author  = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09685}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer, Andrea Schioppa, and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\PaLM-rlhf-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'PaLM-rlhf-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.2.1',  # 版本号
  license='MIT',  # 许可证
  description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/PaLM-rlhf-pytorch',  # URL
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'reinforcement learning',
    'human feedback'
  ],
  install_requires=[  # 安装依赖
    'accelerate',
    'beartype',
    'einops>=0.6',
    'lion-pytorch',
    'torch>=1.6',
    'tqdm'
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\PaLM-rlhf-pytorch\train.py

# 导入必要的库
import gzip
import random
import tqdm
import numpy as np

import torch
from lion_pytorch import Lion
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from palm_rlhf_pytorch import PaLM
from accelerate import Accelerator

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024

# helper functions

# cycle endlessly through a dataloader (used below for train_loader and val_loader)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token id into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of token ids into a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


# 初始化加速器
accelerator = Accelerator()
device = accelerator.device

# 实例化 PaLM 模型
model = PaLM(
    num_tokens=256,
    dim=512,
    depth=8,
    flash_attn=True
).to(device)

# 准备 enwik8 数据
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# 定义数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 创建训练集和验证集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# 初始化优化器
optim = Lion(model.palm_parameters(), lr = LEARNING_RATE)

# 准备模型、优化器、训练集加载器和验证集加载器
model, optim, train_loader, val_loader = accelerator.prepare(
    model, optim, train_loader, val_loader
)

# 训练过程
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    accelerator.print(f"training loss: {loss.item()}")
    accelerator.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            accelerator.print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        accelerator.print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(GENERATE_LENGTH, inp[None, ...])
        output_str = decode_tokens(sample[0])
        accelerator.print(output_str, "\n")

.\lucidrains\panoptic-transformer\panoptic_transformer\data.py

# 导入所需的库
from pathlib import Path
from random import choice
from PIL import Image
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision import transforms as T

# 定义一个循环生成器函数,用于循环遍历数据集
def cycle(dl):
    while True:
        for el in dl:
            yield el

# 定义 PathfinderXDataset 类,继承自 Dataset 类
class PathfinderXDataset(Dataset):
    def __init__(
        self,
        folder,
        augment = False
    ):
        super().__init__()
        # 获取文件夹中所有的 .npy 文件
        metadata_files = [*Path(folder).glob(f'**/*.npy')]
        # 断言确保找到了至少一个 metadata 文件
        assert len(metadata_files) > 0, 'not able to find more than 1 metadata file'

        # 获取第一个 metadata 文件
        metadata_file = metadata_files[0]
        # 加载 metadata 文件
        metadata = np.load(str(metadata_file))
        # 获取 metadata 文件的父目录
        root_path = metadata_file.parents[1]

        self.augment = augment
        # 将数据集的路径和标签存储为元组的列表
        self.data = [(str(root_path / m[0] / m[1]), int(m[3])) for m in metadata]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ind):
        # 获取指定索引的路径和标签
        path, label = self.data[ind]
        # 打开图像文件
        img = Image.open(path)

        # 对图像进行数据增强处理
        img = T.Compose([
            T.RandomHorizontalFlip() if self.augment else nn.Identity(),
            T.RandomVerticalFlip() if self.augment else nn.Identity(),
            T.PILToTensor()
        ])(img)

        # 将标签转换为 torch 张量
        label = torch.tensor(label, dtype = torch.float32)

        if self.augment:
            # 随机选择旋转角度
            rand_rotate = [0, 90, 180, 270]
            img = T.functional.rotate(img, choice(rand_rotate))
            # 随机选择填充方式
            rand_padding = [(0, 0, 0, 0), (1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1)]
            img = F.pad(img, choice(rand_padding))

        return img.float(), label

# 获取训练和验证数据加载器函数
def get_dataloaders(
    data_path,
    *,
    augment = True,
    frac_valids = 0.05,
    batch_size
):
    # 创建 PathfinderXDataset 实例
    ds = PathfinderXDataset(data_path, augment = augment)

    total_samples = len(ds)
    # 计算验证集样本数量
    num_valid = int(frac_valids * total_samples)
    # 计算训练集样本数量
    num_train = total_samples - num_valid

    print(f'training with {num_train} samples and validating with {num_valid} samples')

    # 随机划分数据集为训练集和验证集
    train_ds, valid_ds = random_split(ds, [num_train, num_valid])

    # 创建训练数据加载器和验证数据加载器
    train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle = True)
    valid_dl = DataLoader(valid_ds, batch_size = batch_size, shuffle = True)

    return cycle(train_dl), cycle(valid_dl)
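
A hypothetical usage sketch for the loader above; the folder path and batch size are placeholders and assume the Pathfinder-X metadata layout the dataset class expects.

# illustrative only: path and batch size are placeholders
train_dl, valid_dl = get_dataloaders(
    './data/pathfinder-x',
    augment = True,
    batch_size = 32
)

images, labels = next(train_dl)   # images: (32, 1, H, W) float, labels: (32,) float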

.\lucidrains\panoptic-transformer\panoptic_transformer\panoptic_transformer.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# the functional API from torch.nn
import torch.nn.functional as F

# 定义一个名为 Attention 的类,继承自 nn.Module 类
class Attention(nn.Module):
    # 初始化函数,接受参数 dim、dim_head 和 heads
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        # 计算内部维度
        inner_dim = heads * dim_head
        # 缩放因子
        self.scale = dim_head ** -0.5
        # 头数
        self.heads = heads

        # 定义一个线性层,用于将输入转换为查询向量
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # 定义一个线性层,用于将输入转换为键值对
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        # 定义一个线性层,用于将输出转换为指定维度
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 将输入 x 转换为查询向量 q,键向量 k 和值向量 v
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # 重排查询向量 q 的维度
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        # 缩放查询向量 q
        q = q * self.scale

        # 计算相似度矩阵 sim
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # 对相似度矩阵进行 softmax 操作,得到注意力矩阵 attn
        attn = sim.softmax(dim = -1)

        # 根据注意力矩阵计算输出 out
        out = einsum('b h i j, b j d -> b h i d', attn , v)

        # 重排输出 out 的维度
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 返回转换后的输出
        return self.to_out(out)

# 定义一个名为 PanopticTransformer 的类,继承自 nn.Module 类
class PanopticTransformer(nn.Module):
    # 初始化函数,接受参数 dim、dim_head 和 heads
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 直接返回输入 x,未进行任何操作
        return x