强化学习基础:拒绝采样

4 阅读1分钟

拒绝采样

拒绝采样(Rejection Sampling) 是一种通过从一个容易采样的分布(建议分布)中抽取样本,并按一定比例随机拒绝,从而获得目标分布(通常是难以直接采样的复杂分布)样本的方法。

核心原理

  1. 目标分布 q(x)q(x):我们想要从中采样的分布(如beta分布,正态分布)。但直接采样q(x)q(x)很难,所以找一个容易采样的建议分布 p(x)p(x)
  2. 建议分布 p(x)p(x):一个容易采样的分布(如均匀分布或正态分布),且必须满足存在一个常数 MM,使得对所有 xx 都有 Mp(x)q(x)M\cdot p(x)\ge q(x)。()
  3. 采样过程:从 p(x)p(x) 中抽取一个样本 x0x_{0}
    1. 从均匀分布 U(0,1)U(0,1) 中抽取一个随机数 uu
    2. 如果 uq(x0)Mp(x0)u\le \frac{q(x_{0})}{M\cdot p(x_{0})},则接受 x0x_{0}
    3. 否则拒绝并重新尝试。

直观几何解释

想象你想要在一条形状复杂的曲线(目标分布 p(x)p(x))下均匀地撒点。  • 覆盖区域:由于 q(x)q(x) 很难直接撒点,我们先找一个更高的、简单的箱子(建议分布 Mp(x)M\cdot p(x))把整个 q(x)q(x) 罩住。

• 均匀撒点:我们在整个“箱子”范围内均匀地撒下随机点。

• 筛选过滤:如果一个点落在了曲线 q(x)q(x) 的下方,我们就保留它;如果落在曲线上面、箱子里面,我们就扔掉它。

• 结果:最终留下来的点在水平方向上的疏密程度,会完美契合曲线 q(x)q(x) 的高低变化。曲线高的地方,落下的点就多,采样概率就大。

基于一维的样本简单解释 • 水平方向(X 轴):从建议分布p(x)p(x)中随机选一个 x(对应采样点的水平位置);

• 垂直方向(Y 轴):在00M×p(x)M×p(x)这个区间内,随机选一个 YY 值(对应采样点的垂直位置);

• 判定规则:如果这个随机选的 YY 值 ≤ 目标分布在 xx 处的 PDF 值p(x)p(x),就 “接受” 这个采样点;否则 “拒绝”。

对于垂直方向的随机值,一般拆解成三个部分

  1. 均匀分布U(0,1):把 0-1 的随机数映射到「0 到M×p(x)M \times p(x)」的区间
  2. 提议分布p(x):直接根据模型或者pdf计算
  3. 常数M:保证M×p(x)M \times p(x)可以罩住q(x)q(x)M=argmaxx(q(x)/p(x))M = argmax_{x}(q(x) / p(x))

最终的y轴随机值就是Y=u×M×p(s)Y=u \times M \times p(s)

• 接受采样:Y=u×M×p(s)<=q(x)Y=u \times M \times p(s) <= q(x),采样点落在目标分布下方,加入

• 拒绝采样:Y=u×M×p(s)>q(x)Y=u \times M \times p(s) > q(x),采样点落在目标分布上方,丢弃

数据概率证明

抽到 xx 的概率:正比于建议分布 g(x)g(x)。 接受该点 xx 的概率:算法规定接受率为 f(x)Mg(x)\frac{f(x)}{M\cdot g(x)}。 最终样本点的分布: P(样本为 x)g(x)×f(x)Mg(x)=f(x)MP(\text{样本为\ }x)\propto g(x)\times \frac{f(x)}{M\cdot g(x)}=\frac{f(x)}{M} 因为 MM 是一个常数,所以最终采样到的 xx 的概率密度正比于 f(x)f(x)。 

证明也非常的简单,

拒绝采样举例

代码

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.stats import beta, norm
from abc import ABC, abstractmethod

# ====================== 1. 抽象层(设计模式核心) ======================
class DistributionStrategy(ABC):
    """分布策略抽象接口:所有概率分布必须实现该接口"""
    @abstractmethod
    def pdf(self, x):
        """计算x处的概率密度值"""
        pass

class SamplerInterface(ABC):
    """采样器抽象接口:所有采样算法必须实现该接口"""
    @abstractmethod
    def sample_batch(self):
        """生成一批采样点"""
        pass

    @abstractmethod
    def get_sampling_data(self):
        """返回格式化的采样数据(供可视化使用)"""
        pass

    @property
    @abstractmethod
    def is_finished(self):
        """判断是否完成目标采样数"""
        pass

# ====================== 2. 策略实现层(具体分布) ======================
class BetaDistribution(DistributionStrategy):
    """Beta分布策略实现"""
    def __init__(self, a=2, b=5):
        self.a = a
        self.b = b

    def pdf(self, x):
        return beta.pdf(x, self.a, self.b)

class NormDistribution(DistributionStrategy):
    """正态分布策略实现(修复注释和参数使用)"""
    def __init__(self, u=0, s=1):
        self.u = u  # 均值
        self.s = s  # 标准差

    def pdf(self, x):
        # 修复:使用自身的均值和标准差参数
        return norm.pdf(x, loc=self.u, scale=self.s)

class UniformDistribution(DistributionStrategy):
    """均匀分布策略实现(建议分布)- 修复数组输入支持"""
    def __init__(self, low=0, high=1):
        self.low = low
        self.high = high
        self.pdf_value = 1.0 / (self.high - self.low)  # 预计算均匀分布的PDF值(标量)

    def pdf(self, x):
        """
        支持标量/数组输入的均匀分布PDF计算
        :param x: 标量或NumPy数组
        :return: 对应x的PDF值(标量或数组)
        """
        # 将输入转为NumPy数组(兼容标量输入)
        x_arr = np.asarray(x)
        # 向量化条件判断:生成布尔数组(True表示x在区间内)
        in_range = (self.low <= x_arr) & (x_arr <= self.high)
        # 初始化结果数组:默认0,区间内的位置设为pdf_value
        result = np.zeros_like(x_arr, dtype=np.float64)
        result[in_range] = self.pdf_value
        # 如果输入是标量,返回标量;否则返回数组
        return result.item() if np.isscalar(x) else result

# ====================== 3. 采样器实现层(依赖分布策略) ======================
class RejectionSampler(SamplerInterface):
    """拒绝采样器:实现采样器接口,依赖分布策略接口(而非具体类)"""
    def __init__(self, target_dist: DistributionStrategy, proposal_dist: DistributionStrategy, config):
        self.target_dist = target_dist
        self.proposal_dist = proposal_dist
        self.config = config

        # 采样参数
        self.M = config["M"]
        self.target_accepted = config["target_accepted"]
        self.x_lim = config["x_lim"]

        # 采样过程数据
        self.all_proposed_x = []
        self.all_u_values = []
        self.acceptance_mask = []
        self.accepted_count = 0

    def get_dynamic_batch_size(self):
        """动态计算每帧生成点数(前慢后快)"""
        return 1 + self.accepted_count // 10  # 线性递增规则

    def sample_batch(self):
        """生成一批采样点(实现SamplerInterface接口)"""
        if self.is_finished:
            return

        batch_size = self.get_dynamic_batch_size()
        for _ in range(batch_size):
            # 1. 从建议分布采样x
            x_proposed = np.random.uniform(*self.x_lim)
            # 2. 生成0-1的判定值u
            u = np.random.uniform(0, 1)

            # 3. 计算接受概率(调用策略接口的pdf方法)
            proposal_pdf = self.proposal_dist.pdf(x_proposed)
            if proposal_pdf == 0:  # 避免除以0
                is_accepted = False
            else:
                acceptance_prob = self.target_dist.pdf(x_proposed) / (self.M * proposal_pdf)
                is_accepted = u <= acceptance_prob

            # 4. 记录数据
            self.all_proposed_x.append(x_proposed)
            self.all_u_values.append(u)
            self.acceptance_mask.append(is_accepted)

            if is_accepted:
                self.accepted_count += 1

    def get_sampling_data(self):
        """返回格式化采样数据(修复Y值计算逻辑)"""
        mask_array = np.array(self.acceptance_mask)
        all_x = np.array(self.all_proposed_x)
        all_u = np.array(self.all_u_values)

        # 核心修复:Y值 = u * M * 建议分布的PDF值(匹配拒绝采样的判定逻辑)
        proposal_pdf_vals = np.array([self.proposal_dist.pdf(x) for x in all_x])
        y_vals = all_u * self.M * proposal_pdf_vals

        return {
            "accepted_x": all_x[mask_array],
            "accepted_y": y_vals[mask_array],  # 使用正确的Y值
            "rejected_x": all_x[~mask_array],
            "rejected_y": y_vals[~mask_array],  # 使用正确的Y值
            "accepted_count": self.accepted_count,
            "total_proposed": len(self.all_proposed_x),
            "batch_size": self.get_dynamic_batch_size()
        }

    @property
    def is_finished(self):
        """判断是否完成采样(实现SamplerInterface接口)"""
        return self.accepted_count >= self.target_accepted

# ====================== 4. 可视化层(依赖采样器接口) ======================
class SamplingVisualizer:
    """可视化器:依赖SamplerInterface,支持任何实现该接口的采样器"""
    def __init__(self, config, sampler: SamplerInterface):
        self.config = config
        self.sampler = sampler
        self.fig, self.ax = self._init_figure()
        self.scatter_rejected, self.scatter_accepted = self._init_scatters()
        self._draw_static_elements()

    def _init_figure(self):
        """初始化画布"""
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        fig, ax = plt.subplots(figsize=self.config["fig_size"])
        ax.set_xlim(*self.config["x_lim"])
        ax.set_ylim(*self.config["y_lim"])
        ax.set_xlabel('x 值', fontsize=12)
        ax.set_ylabel('概率密度', fontsize=12)
        ax.grid(True, alpha=0.3)
        return fig, ax

    def _init_scatters(self):
        """初始化接受/拒绝点散点"""
        scatter_rejected = self.ax.scatter([], [], color='red', alpha=0.5, s=5, label='拒绝的点')
        scatter_accepted = self.ax.scatter([], [], color='black', alpha=0.8, s=8, label='接受的点')
        return scatter_rejected, scatter_accepted

    def _draw_static_elements(self):
        """绘制静态参考元素(修复标签错误)"""
        x_range = np.linspace(*self.config["x_lim"], 1000)
        # 绘制目标分布曲线(修复标签:匹配正态分布参数)
        target_pdf_vals = self.sampler.target_dist.pdf(x_range)
        self.ax.plot(
            x_range, target_pdf_vals,
            'blue', linewidth=2, label=f'目标分布 正态({self.sampler.target_dist.u},{self.sampler.target_dist.s})'
        )
        # 绘制M*建议分布曲线(而非单纯的M水平线,这是另一个关键错误)
        proposal_pdf_vals = self.sampler.proposal_dist.pdf(x_range)
        self.ax.plot(
            x_range, self.config["M"] * proposal_pdf_vals,
            'green', linestyle='--', linewidth=2, label=f'M * 建议分布 (M={self.config["M"]})'
        )
        self.ax.legend(loc='upper right')

    def _update_title(self, sampling_data):
        """更新动画标题"""
        efficiency = (sampling_data["accepted_count"] / sampling_data["total_proposed"] * 100) if sampling_data["total_proposed"] > 0 else 0
        batch_size = sampling_data["batch_size"]
        speed_label = "慢" if batch_size <= 5 else "中" if batch_size <= 20 else "快"

        self.ax.set_title(
            f'拒绝采样动态过程 | 接受数: {sampling_data["accepted_count"]}/{self.config["target_accepted"]} | 效率: {efficiency:.1f}% | 速度: {speed_label}',
            fontsize=14
        )

    def animation_init(self):
        """动画初始化"""
        self.scatter_rejected.set_offsets(np.empty((0, 2)))
        self.scatter_accepted.set_offsets(np.empty((0, 2)))
        return self.scatter_rejected, self.scatter_accepted

    def animation_update(self, frame):
        """动画更新(与具体采样器解耦)"""
        self.sampler.sample_batch()
        sampling_data = self.sampler.get_sampling_data()

        # 更新散点
        if len(sampling_data["accepted_x"]) > 0:
            self.scatter_accepted.set_offsets(np.column_stack((sampling_data["accepted_x"], sampling_data["accepted_y"])))
        if len(sampling_data["rejected_x"]) > 0:
            self.scatter_rejected.set_offsets(np.column_stack((sampling_data["rejected_x"], sampling_data["rejected_y"])))

        # 更新标题
        self._update_title(sampling_data)
        return self.scatter_rejected, self.scatter_accepted

    def run_animation(self):
        """启动动画"""
        ani = FuncAnimation(
            self.fig, self.animation_update, init_func=self.animation_init,
            frames=self.config["max_frames"],
            interval=self.config["frame_interval"],
            blit=True,
            repeat=False
        )
        plt.show()
        return ani

# ====================== 5. 工厂层(简化实例创建) ======================
class SamplerFactory:
    """采样器工厂:封装复杂的初始化逻辑"""
    @staticmethod
    def create_rejection_sampler(config):
        """创建拒绝采样器实例(默认正态分布+均匀分布)"""
        target_dist = NormDistribution(u=0, s=1)
        proposal_dist = UniformDistribution(low=-4, high=4)
        return RejectionSampler(target_dist, proposal_dist, config)

# ====================== 6. 配置+主程序 ======================
def setup_global_config():
    """全局配置(调整y_lim为合理值)"""
    return {
        "M": 10,
        "target_accepted": 1000,
        "frame_interval": 50,
        "max_frames": 2000,
        "fig_size": (12, 8),
        "x_lim": (-4, 4),
        "y_lim": (0, 1.2),  # 增大y_lim后也能正确显示
    }

def main():
    # 1. 初始化配置
    config = setup_global_config()

    # 2. 工厂创建采样器(可替换为其他采样器)
    sampler = SamplerFactory.create_rejection_sampler(config)

    # 3. 初始化可视化器(依赖采样器接口,无需修改可视化器代码)
    visualizer = SamplingVisualizer(config, sampler)

    # 4. 运行动画
    visualizer.run_animation()

    # 5. 输出统计
    sampling_data = sampler.get_sampling_data()
    final_efficiency = (sampling_data["accepted_count"] / sampling_data["total_proposed"] * 100) if sampling_data["total_proposed"] > 0 else 0
    print("\n=== 采样完成 ===")
    print(f"目标接受样本数: {config['target_accepted']}")
    print(f"实际接受样本数: {sampling_data['accepted_count']}")
    print(f"总提议样本数: {sampling_data['total_proposed']}")
    print(f"最终采样效率: {final_efficiency:.2f}%")

# ====================== 运行入口 ======================
if __name__ == "__main__":
    main()

case1:生成正态(beta)样本分布

使用均匀分布作为提议分布来生成正态分布样本是拒绝采样的一个典型应用。由于正态分布的定义域是 (,+)(-\infty ,+\infty ),而均匀分布通常定义在有限区间内,因此在实际操作中,我们通常会在一个包含正态分布绝大部分概率质量的截断区间(如 [4,4][-4,4])内进行采样。 

核心参数设定
  1. 目标分布 f(x)f(x):标准正态分布 N(0,1)N(0,1),其概率密度函数为 f(x)=12πex22f(x)=\frac{1}{\sqrt{2\pi }}e^{-\frac{x^{2}}{2}}
  2. 提议分布 g(x)g(x):区间 [4,4][-4,4] 上的均匀分布 U(4,4)U(-4,4),概率密度为常数 g(x)=14(4)=0.125g(x)=\frac{1}{4-(-4)}=0.125
  3. 比较常数 MM:必须满足 Mg(x)f(x)M\cdot g(x)\ge f(x)
    1. 标准正态分布在 x=0x=0 处取得最大值 f(0)0.3989f(0)\approx 0.3989
    2. M×0.1250.3989M\times 0.125\ge 0.3989,解得 M3.19M\ge 3.19。为保险起见,我们取 M=3.2M=3.2

case2:大语言模型(LLM)的对齐与微调

在大模型对齐(LLM Alignment)的语境下,拒绝采样(Rejection Sampling) 是一种将生成模型(建议分布)推向符合人类偏好(目标分布)的有效策略。:

1. 核心组件定义

• 建议分布 p(x)p(x): 当前训练中的 LLM 策略(Policy)。我们利用该模型针对同一个 Prompt 生成 NN 个不同的候选响应。

• 评估指标: 奖励模型(Reward Model, RM)。它作为一个判别器,为这 NN 个样本分别打分 r(x,y)r(x,y),分值高低代表了样本符合人类价值观或逻辑要求的程度。

• 目标分布 1(x)1(x): 这是一个理想的“对齐分布”,即那些奖励得分极高、符合偏好约束的分布。 

2. 实际操作流程(从采样到训练)

在实践中,拒绝采样通常不只是单纯的“拒绝”,而是作为一个数据筛选器来优化模型:

  1. 大规模采样: 使用 LLM 对每个 Prompt 生成 NN 个结果(如 N=1030N=10\sim 30)。
  2. 奖励打分: 利用 RM 对所有结果进行评分。
  3. 样本筛选:
    1. 严格拒绝采样:根据 RM 的得分计算一个接受概率(通常与 exp(r)\exp (r) 成正比),通过随机试验决定是否保留该样本。
    2. Best-of-N (近似做法): 实际上大模型工程中常用其变体——从 NN 个样本中只保留得分最高(或前 KK 个)的样本。
  4. 模型更新(迭代): 将这些被筛选出的“高质量”样本作为新的训练集,通过 有监督微调(SFT) 或 DPO(直接偏好优化) 更新 LLM,使其分布向高分区域偏移。

3. 为什么选择这种方式?

• 避开复杂的强化学习: 拒绝采样能够通过简单的 SFT 达到类似 PPO(强化学习)的效果,但训练更加稳定,不易发生训练崩溃。

• 解决分布偏移: 直接在模型自己生成的样本上筛选,可以确保训练数据处于模型自身的分布内,减少了微调时的“灾难性遗忘”。

• 提升推理上限: 例如 DeepSeek-R1 等模型在研发中使用了拒绝采样来筛选复杂的推理链(CoT),从而显著增强了模型的思维能力。

简而言之,在这种场景下,LLM 负责“想出各种主意”,RM 负责“毙掉烂主意”,最终留下的“好主意”被用来教导模型变得更好。