浅谈TD3:从算法原理到代码实现
浅谈TD3:从算法原理到代码实现
作者:Xingzhe.AI
来自:行者AI
引言
众所周知,在基于价值学习的强化学习算法中,如DQN,函数近似误差是导致Q值高估和次优策略的原因。我们表明这个问题依然在AC框架中存在,并提出了新的机制去最小化它对演员(策略函数)和评论家(估值函数)的影响。我们的算法建立在双Q学习的基础上,通过选取两个估值函数中的较小值,从而限制它对Q值的过高估计。(出自TD3论文摘要)
1. 什么是TD3
TD3是Twin Delayed Deep Deterministic policy gradient algorithm的全称。TD3全称中Deep Deterministic policy gradient algorithm就是DDPG的全称。那么DDPG和TD3有何渊源呢?其实简单的说,TD3是DDPG的一个优化版本。
1.1 TD3为什么被提出
在强化学习中,对于离散化的动作的学习,都是以DQN为基础的,DQN则是通过的的方式去选择动作,往往都会过大的估计价值函数,从而造成误差。在连续的动作控制的AC框架中,如果每一步都采用这种方式去估计,导致误差一步一步的累加,导致不能找到最优策略,最终使算法不能得到收敛。
1.2 TD3在DDPG的基础上都做了些什么
-
使用两个Critic网络。使用两个网络对动作价值函数进行估计,(这Double DQN 的思想差不多)。在训练的时候选择作为估计值。
-
使用软更新的方式 。不再采用直接复制,而是使用 的方式更新网络参数。
-
使用策略噪音。使用Epsilon-Greedy在探索的时候使用了探索噪音。(还是用了策略噪声,在更新参数的时候,用于平滑策略期望)
-
使用延迟学习。Critic网络更新的频率要比Actor网络更新的频率要大。
-
使用梯度截取。将Actor的参数更新的梯度截取到某个范围内。
2. TD3算法思路
图1. TD3算法流程
TD3算法的大致思路,首先初始化3个网络,分别为 ,参数为,在初始化3个Target网络,分别将开始初始化的3个网络参数分别对应的复制给target网络。 。初始化Replay Buffer 。
然后通过循环迭代,一次次找到最优策略。每次迭代,在选择action的值的时候加入了噪音,使,,然后将放入,当达到一定的值时候。
然后随机从中Sample出Mini-Batch个数据,通过,,计算出状态下对应的Action的值,通过,计算出,获取,为的值。
通过贝尔曼方程计算的值,通过两个Current网络根据分别计算出当前的值,在将两个当前网络的值和值通过MSE计算Loss,更新参数。
Critic网络更新之后,Actor网络则采用了延时更新,(一般采用Critic更新2次,Actor更新1次)。通过梯度上升的方式更新Actor网络。通过软更新的方式,更新target网络。
-
为什么在更新Critic网络时,在计算Action值的时候加入噪音,是为了平滑前面加入的噪音。
-
贝尔曼方程:针对一个连续的MRP(Markov Reward Process)的过程(连续的状态奖励过程),状态转移到下一个状态 的概率的固定的,与前面的几轮状态无关。其中,表示一个对当前状态state 进行估值的函数。一般为趋近于1,但是小于1。
图2. 贝尔曼方程
3. 代码实现
代码主要是根据DDPG的代码以及TD3的论文复现的,使用的是Pytorch1.7实现的。
3.1 搭建网络结构
Q1网络结构主要是用于更新Actor网络
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.f1 = nn.Linear(state_dim, 256)
self.f2 = nn.Linear(256, 128)
self.f3 = nn.Linear(128, action_dim)
self.max_action = max_action
def forward(self,x):
x = self.f1(x)
x = F.relu(x)
x = self.f2(x)
x = F.relu(x)
x = self.f3(x)
return torch.tanh(x) * self.max_action
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic,self).__init__()
self.f11 = nn.Linear(state_dim+action_dim, 256)
self.f12 = nn.Linear(256, 128)
self.f13 = nn.Linear(128, 1)
self.f21 = nn.Linear(state_dim + action_dim, 256)
self.f22 = nn.Linear(256, 128)
self.f23 = nn.Linear(128, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
x = self.f11(sa)
x = F.relu(x)
x = self.f12(x)
x = F.relu(x)
Q1 = self.f13(x)
x = self.f21(sa)
x = F.relu(x)
x = self.f22(x)
x = F.relu(x)
Q2 = self.f23(x)
return Q1, Q2
3.2 定义网络
self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
self.target_actor = copy.deepcopy(self.actor)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
#定义critic网络
self.critic = Critic(self.state_dim, self.action_dim)
self.target_critic = copy.deepcopy(self.critic)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
3.3 更新网络
更新网络采用软更新,延迟更新等方式
def learn(self):
self.total_it += 1
data = self.buffer.smaple(size=128)
state, action, done, state_next, reward = data
with torch.no_grad:
noise = (torch.rand_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_action = (self.target_actor(state_next) + noise).clamp(-self.max_action, self.max_action)
target_Q1,target_Q2 = self.target_critic(state_next, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + done * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
critic_loss.backward()
self.critic_optimizer.step()
if self.total_it % self.policy_freq == 0:
q1,q2 = self.critic(state, self.actor(state))
actor_loss = -torch.min(q1, q2).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
4. 总结
TD3是DDPG的一个升级版,在解决很多的问题上,效果要比DDPG的效果好的多,无论是训练速度,还是结果都有显著的提高。
图3. 算法效果对比
5. 资料
行者AI(成都潜在人工智能科技有限公司,xingzhe.ai)致力于使用人工智能和机器学习技术提高游戏和文娱行业的生产力,并持续改善行业的用户体验。我们有内容安全团队、游戏机器人团队、数据平台团队、智能音乐团队和自动化测试团队。 > >如果您对世界拥有强烈的好奇心,不畏惧挑战性问题;能够容忍摸索过程中的各种不确定性、并且坚持下去;能够寻找创新的方式来应对挑战,并同时拥有事无巨细的责任心以确保解决方案的有效执行。那么请将您的个人简历、相关的工作成果及您具体感兴趣的职位提交给我们。我们欢迎拥抱挑战、并具有创新思维的人才加入我们的团队。请联系:hr@xingzhe.ai > >如果您有任何关于内容安全、游戏机器人、数据平台、智能音乐和自动化测试方面的需求,我们也非常荣幸能为您服务。可以联系:contact@xingzhe.ai
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· 展开说说关于C#中ORM框架的用法!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?