Double DQN (Double Deep Q-Network) 是对原始 DQN 的改进,旨在减轻 Q-learning 中的过高估计偏差(overestimation bias)。要理解 Double DQN 使用两个不同的模型来计算值(value)和目标(target)并缓解过高估计,需要从 Q-learning 和 DQN 的更新方式说起。
Q-learning 和 DQN 的过高估计问题
在 Q-learning 和 DQN 中,更新 Q 值时使用的目标是通过最大化未来状态的 Q 值来计算的。这种方法容易导致过高估计,因为 Q 值的最大化过程可能会因为噪声或不准确的 Q 值估计而放大值。
import torch
import torch.nn as nn
import torch.optim as optim
# 定义 Q 网络
class QNetwork(nn.Module):
def __init__(self, input_dim, output_dim):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(input_dim, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, output_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# 初始化主网络和目标网络
input_dim = 4 # 状态维度
output_dim = 2 # 动作维度
online_net = QNetwork(input_dim, output_dim)
target_net = QNetwork(input_dim, output_dim)
# 同步目标网络参数
target_net.load_state_dict(online_net.state_dict())
# 定义优化器
optimizer = optim.Adam(online_net.parameters(), lr=0.001)
# 计算 Double DQN 的目标值
def compute_td_target(reward, next_state, done):
with torch.no_grad():
# 使用主网络选择动作
next_action = torch.argmax(online_net(next_state), dim=1)
# 使用目标网络评估动作
next_q_value = target_net(next_state).gather(1, next_action.unsqueeze(-1)).squeeze(-1)
td_target = reward + (1 - done) * gamma * next_q_value
return td_target
# 假设有一些经验样本
state = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
action = torch.tensor([1])
reward = torch.tensor([1.0])
next_state = torch.tensor([[2.0, 3.0, 4.0, 5.0]])
done = torch.tensor([0])
# 计算当前 Q 值
current_q_value = online_net(state).gather(1, action.unsqueeze(-1)).squeeze(-1)
# 计算目标 Q 值
gamma = 0.99
td_target = compute_td_target(reward, next_state, done)
# 计算损失
loss = nn.MSELoss()(current_q_value, td_target)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每隔一定步数同步目标网络
sync_interval = 1000
step = 0 # 假设这是一个循环内的步数计数器
if step % sync_interval == 0:
target_net.load_state_dict(online_net.state_dict())
-
网络定义:
QNetwork类定义了一个简单的前馈神经网络结构。 -
网络初始化:
online_net和target_net使用相同的网络结构进行初始化。 -
同步参数:使用
target_net.load_state_dict(online_net.state_dict())将主网络的参数复制到目标网络。这通常在训练过程中每隔一定步数进行一次,以保持目标网络参数的稳定性。 -
目标值计算:
- 选择动作:使用主网络
online_net计算当前状态next_state下的最优动作next_action。 - 评估动作:使用目标网络
target_net计算最优动作的 Q 值next_q_value。
- 选择动作:使用主网络
-
损失计算和优化:计算当前 Q 值和目标 Q 值之间的均方误差,并进行反向传播和优化。
-
同步目标网络参数:每隔一定步数(
sync_interval)将主网络的参数同步到目标网络。