Reinforcement Learning: Mastering CartPole with DQN

1. The CartPole Environment and Reinforcement Learning Basics

1.1 Environment Overview

CartPole is a classic control problem in OpenAI Gym. The goal is to keep the pole upright by pushing the cart left or right:

graph TD
    A[State space] --> B[Cart position -2.4 to 2.4]
    A --> C[Cart velocity -∞ to ∞]
    A --> D[Pole angle about -24° to 24°]
    A --> E[Pole angular velocity -∞ to ∞]
    F[Action space] --> G[Push left 0]
    F --> H[Push right 1]
    style A fill:#9f9,stroke:#333
    style F fill:#f99,stroke:#333

1.2 Basic Reinforcement Learning Concepts

  • State: $s_t \in \mathbb{R}^4$
  • Action: $a_t \in \{0, 1\}$
  • Reward: +1 for every step the pole stays up
  • Objective: maximize the discounted return $\sum_{t=0}^{T} \gamma^t r_t$
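
To make these quantities concrete, here is a minimal interaction sketch (assuming the classic Gym API used throughout this article, where reset() returns only the observation and step() returns four values):

import gym

env = gym.make('CartPole-v1')
state = env.reset()                    # s_t: a 4-dimensional observation
print(env.observation_space.shape)     # (4,)
print(env.action_space.n)              # 2 discrete actions: 0 = push left, 1 = push right

action = env.action_space.sample()     # random a_t in {0, 1}
next_state, reward, done, info = env.step(action)
print(reward, done)                    # reward is +1 for every surviving step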

2. DQN Algorithm Principles

2.1 Q-Learning Update Rule

$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha\left[r_{t+1} + \gamma \max_a Q(s_{t+1},a) - Q(s_t,a_t)\right]$$
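
As a worked reference, the tabular form of this update is a one-liner. A minimal sketch (the table size, learning rate, and transition values are purely illustrative; CartPole's continuous state is exactly why the next section replaces the table with a network):

import numpy as np

Q = np.zeros((5, 2))                 # toy table: 5 discrete states x 2 actions
alpha, gamma = 0.1, 0.99             # learning rate and discount factor

s, a, r, s_next = 0, 1, 1.0, 2       # one illustrative transition (s_t, a_t, r_{t+1}, s_{t+1})
td_target = r + gamma * Q[s_next].max()
Q[s, a] += alpha * (td_target - Q[s, a])
print(Q[s, a])                       # 0.1 after the first update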

2.2 Deep Q-Network Improvements

graph LR
    A[Tabular Q-Learning] --> B[State-space explosion]
    B --> C[Deep network as Q-function approximator]
    C --> D[Experience replay]
    D --> E[Target network]
    style C fill:#99f,stroke:#333
    style E fill:#99f,stroke:#333

2.2.1 Key Technical Components

  1. Experience Replay: breaks the correlation between consecutive samples
  2. Target Network: stabilizes the training target
  3. ε-greedy policy: balances exploration and exploitation

3. Implementing DQN in PyTorch

3.1 Q-Network Definition

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)
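
A quick shape check of the network (a small usage sketch; the dimensions 4 and 2 match CartPole's state and action spaces):

net = DQN(input_dim=4, output_dim=2)
dummy_state = torch.randn(1, 4)        # a batch of one CartPole observation
q_values = net(dummy_state)            # shape (1, 2): one Q-value per action
print(q_values.shape, q_values.argmax(dim=1).item())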

3.2 Experience Replay Buffer

from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)
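
Usage is straightforward; a minimal sketch with dummy transitions:

buffer = ReplayBuffer(capacity=1000)
for _ in range(100):
    # push dummy (state, action, reward, next_state, done) tuples
    buffer.push([0.0, 0.0, 0.0, 0.0], 0, 1.0, [0.0, 0.0, 0.0, 0.0], False)

batch = buffer.sample(32)              # 32 transitions drawn uniformly at random
states, actions, rewards, next_states, dones = zip(*batch)
print(len(buffer), len(batch))         # 100 32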

3.3 DQN Agent Implementation

class DQNAgent:
    def __init__(self, env, gamma=0.99, lr=1e-3):
        self.env = env
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        
        self.policy_net = DQN(env.observation_space.shape[0], 
                             env.action_space.n)
        self.target_net = DQN(env.observation_space.shape[0],
                             env.action_space.n)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(10000)
    
    def select_action(self, state):
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        with torch.no_grad():
            return self.policy_net(state).argmax().item()
    
    def update_model(self, batch_size):
        if len(self.memory) < batch_size:
            return
        
        # Sample a batch of transitions from the replay buffer
        transitions = self.memory.sample(batch_size)
        batch = list(zip(*transitions))
        
        # Convert to tensors
        state_batch = torch.FloatTensor(batch[0])
        action_batch = torch.LongTensor(batch[1]).unsqueeze(1)
        reward_batch = torch.FloatTensor(batch[2])
        next_state_batch = torch.FloatTensor(batch[3])
        done_batch = torch.FloatTensor(batch[4])
        
        # Current Q-values for the actions actually taken
        q_values = self.policy_net(state_batch).gather(1, action_batch)
        
        # Target Q-values from the (frozen) target network
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        expected_q = reward_batch + (1 - done_batch) * self.gamma * next_q_values
        
        # Compute the loss
        loss = F.mse_loss(q_values, expected_q.unsqueeze(1))
        
        # Optimize the policy network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Decay ε
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

4. Training Procedure and Results

4.1 Training Loop

def train(env, agent, episodes=500, batch_size=64):
    rewards = []
    for ep in range(episodes):
        state = env.reset()
        total_reward = 0
        
        while True:
            state_tensor = torch.FloatTensor(state)
            action = agent.select_action(state_tensor)
            
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            
            agent.update_model(batch_size)
            
            state = next_state
            total_reward += reward
            
            if done:
                break
        
        # Periodically sync the target network with the policy network
        if ep % 10 == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
        
        rewards.append(total_reward)
        print(f"Episode {ep+1}/{episodes}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
    
    return rewards

4.2 Training Results Analysis

import gym
import matplotlib.pyplot as plt

env = gym.make('CartPole-v1')
agent = DQNAgent(env)
rewards = train(env, agent)

# Plot the learning curve
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training Progress')
plt.show()
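
The raw per-episode reward is noisy, so a moving-average overlay makes the trend easier to read (an optional sketch; the window size of 20 is arbitrary):

import numpy as np

window = 20
smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
plt.plot(rewards, alpha=0.3, label='raw')
plt.plot(range(window - 1, len(rewards)), smoothed, label=f'{window}-episode mean')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.legend()
plt.show()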

4.2.1 Typical Training Curve

graph LR
    A[Initial random exploration] --> B[Gradual stabilization]
    B --> C[Maximum reward reached]
    style A fill:#f99,stroke:#333
    style C fill:#9f9,stroke:#333

5. Advanced Improvement Techniques

5.1 Double DQN

Modify the target Q-value computation so the next action is selected by the policy network but evaluated by the target network:

next_actions = self.policy_net(next_state_batch).max(1)[1]
next_q_values = self.target_net(next_state_batch).gather(1, next_actions.unsqueeze(1)).squeeze(1)

5.2 Dueling DQN

Modify the network architecture to separate the state value from the action advantages:

class DuelingDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU()
        )
        self.advantage = nn.Linear(128, output_dim)
        self.value = nn.Linear(128, 1)
    
    def forward(self, x):
        x = self.feature(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean(dim=-1, keepdim=True)
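
A quick shape check (a usage sketch; note that the advantage mean is taken over the action dimension, so the advantages are centered per state):

net = DuelingDQN(input_dim=4, output_dim=2)
q = net(torch.randn(8, 4))             # batch of 8 CartPole states
print(q.shape)                          # torch.Size([8, 2])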

5.3 Performance Comparison

| Method | Average Reward | Convergence Speed | Stability |
| --- | --- | --- | --- |
| Vanilla DQN | 195 | ~300 episodes | — |
| Double DQN | 200 | ~250 episodes | — |
| Dueling DQN | 200+ | ~200 episodes | Very high |

6. Frequently Asked Questions

Q: Why is a target network needed?

  • It prevents rapidly shifting Q-value estimates from destabilizing training
  • It provides a relatively fixed target to learn against
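
Besides the hard copy every 10 episodes used in the training loop above, a common alternative (not part of this article's code; shown only as a hedged sketch) is a soft/Polyak update applied at every step:

TAU = 0.005   # soft-update coefficient (illustrative value)

def soft_update(target_net, policy_net, tau=TAU):
    # target <- tau * policy + (1 - tau) * target, applied parameter-wise
    for t_param, p_param in zip(target_net.parameters(), policy_net.parameters()):
        t_param.data.copy_(tau * p_param.data + (1 - tau) * t_param.data)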

Q: How should the ε decay rate be chosen?

  • Keep exploration high at the start (ε = 1.0)
  • Decay gradually down to a minimum (ε_min = 0.01)
  • Typical decay rates: 0.995 to 0.999
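
For intuition, with multiplicative decay the number of decay steps needed to go from 1.0 down to ε_min is roughly log(ε_min) / log(decay); a quick check:

import math

for decay in (0.995, 0.999):
    steps = math.log(0.01) / math.log(decay)
    print(f"decay={decay}: ~{steps:.0f} decay steps to reach epsilon=0.01")
# decay=0.995: ~919 steps; decay=0.999: ~4603 steps

Note that in the agent above ε decays once per update_model() call, i.e. once per environment step, so decay=0.995 reaches the floor within a few dozen episodes.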

Q: How can sparse-reward problems be handled?

  • Use Prioritized Experience Replay
  • Add an intrinsic curiosity module
  • Redesign the reward function (reward shaping)

Appendix: Core Mathematical Derivations

Bellman Optimality Equation

$$Q^*(s,a) = \mathbb{E}\left[r + \gamma \max_{a'} Q^*(s',a') \,\middle|\, s,a\right]$$

Loss Function

$$\mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s') \sim D}\left[\left(r + \gamma \max_{a'} Q_{\text{target}}(s',a') - Q(s,a;\theta)\right)^2\right]$$

Gradient Update Rule

$$\theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta)$$


Best Practice Recommendations

  1. Use frame stacking to handle partial observability (see the sketch after this list)
  2. Save model checkpoints regularly
  3. Monitor training with W&B or TensorBoard
  4. Experiment with different network architectures (CNN, LSTM, etc.)
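
For item 1, a minimal frame-stacking sketch using a deque (a generic pattern rather than something specific to this article's code; the stack size of 4 is the conventional Atari choice):

from collections import deque
import numpy as np

class FrameStack:
    """Keep the last k observations and return them concatenated as one state."""
    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, obs):
        # fill the stack with the first observation
        for _ in range(self.k):
            self.frames.append(obs)
        return np.concatenate(self.frames)

    def step(self, obs):
        # drop the oldest observation and append the newest
        self.frames.append(obs)
        return np.concatenate(self.frames)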

A complete code example is available in the GitHub repository, including a visualization interface and more advanced implementations. With some hyperparameter tuning, the same agent can be transferred to more complex environments such as Atari!