(Double Deep Q-Network) ,避免DQN 自举的一种改进方式

104 阅读2分钟

Double DQN (Double Deep Q-Network) 是对原始 DQN 的改进,旨在减轻 Q-learning 中的过高估计偏差(overestimation bias)。要理解 Double DQN 使用两个不同的模型来计算值(value)和目标(target)并缓解过高估计,需要从 Q-learning 和 DQN 的更新方式说起。

Q-learning 和 DQN 的过高估计问题

在 Q-learning 和 DQN 中,更新 Q 值时使用的目标是通过最大化未来状态的 Q 值来计算的。这种方法容易导致过高估计,因为 Q 值的最大化过程可能会因为噪声或不准确的 Q 值估计而放大值。

import torch
import torch.nn as nn
import torch.optim as optim

# 定义 Q 网络
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# 初始化主网络和目标网络
input_dim = 4  # 状态维度
output_dim = 2  # 动作维度
online_net = QNetwork(input_dim, output_dim)
target_net = QNetwork(input_dim, output_dim)

# 同步目标网络参数
target_net.load_state_dict(online_net.state_dict())

# 定义优化器
optimizer = optim.Adam(online_net.parameters(), lr=0.001)

# 计算 Double DQN 的目标值
def compute_td_target(reward, next_state, done):
    with torch.no_grad():
        # 使用主网络选择动作
        next_action = torch.argmax(online_net(next_state), dim=1)
        # 使用目标网络评估动作
        next_q_value = target_net(next_state).gather(1, next_action.unsqueeze(-1)).squeeze(-1)
        td_target = reward + (1 - done) * gamma * next_q_value
    return td_target

# 假设有一些经验样本
state = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
action = torch.tensor([1])
reward = torch.tensor([1.0])
next_state = torch.tensor([[2.0, 3.0, 4.0, 5.0]])
done = torch.tensor([0])

# 计算当前 Q 值
current_q_value = online_net(state).gather(1, action.unsqueeze(-1)).squeeze(-1)

# 计算目标 Q 值
gamma = 0.99
td_target = compute_td_target(reward, next_state, done)

# 计算损失
loss = nn.MSELoss()(current_q_value, td_target)

# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 每隔一定步数同步目标网络
sync_interval = 1000
step = 0  # 假设这是一个循环内的步数计数器
if step % sync_interval == 0:
    target_net.load_state_dict(online_net.state_dict())

  • 网络定义QNetwork 类定义了一个简单的前馈神经网络结构。

  • 网络初始化online_nettarget_net 使用相同的网络结构进行初始化。

  • 同步参数:使用 target_net.load_state_dict(online_net.state_dict()) 将主网络的参数复制到目标网络。这通常在训练过程中每隔一定步数进行一次,以保持目标网络参数的稳定性。

  • 目标值计算

    • 选择动作:使用主网络 online_net 计算当前状态 next_state 下的最优动作 next_action
    • 评估动作:使用目标网络 target_net 计算最优动作的 Q 值 next_q_value
  • 损失计算和优化:计算当前 Q 值和目标 Q 值之间的均方误差,并进行反向传播和优化。

  • 同步目标网络参数:每隔一定步数(sync_interval)将主网络的参数同步到目标网络。