# MADDPG learning exercise.
# Adapted from the PARL MADDPG example, rewritten as a PyTorch-based model.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MAModle(nn.Module):
    """Multi-agent actor-critic model (MADDPG-style) bundling one agent's networks.

    Holds an Actor (decentralised policy) and a Critic (centralised Q-function)
    and exposes convenience accessors for each network's output and parameters.

    Args:
        obs_dim: dimension of this agent's observation vector.
        act_dim: dimension of this agent's action vector.
        critic_dim: input dimension of the critic — the concatenated size of
            all agents' observations and actions.
        continuous_actions: if True, the actor additionally outputs a
            standard-deviation head (see Actor).

    NOTE(fix): the original class did not inherit from nn.Module even though
    it called super().__init__() as if it did; without nn.Module the actor and
    critic submodules are never registered, so model-level .parameters(),
    .state_dict(), .to(device) and train/eval propagation would not work.
    """

    def __init__(self,
                 obs_dim,
                 act_dim,
                 critic_dim,
                 continuous_actions=False
                 ):
        super(MAModle, self).__init__()
        self.actor = Actor(obs_dim, act_dim, continuous_actions)
        self.critic = Critic(critic_dim)

    def policy(self, obs):
        """Return the actor's output for observation batch `obs`."""
        return self.actor(obs)

    def value(self, obs, act):
        """Return the critic's Q-value for the joint (obs, act) batch."""
        return self.critic(obs, act)

    def get_actor_param(self):
        """Iterator over the actor's trainable parameters (for its optimizer)."""
        return self.actor.parameters()

    def get_critic_param(self):
        """Iterator over the critic's trainable parameters (for its optimizer)."""
        return self.critic.parameters()
# Per-agent policy network.
# input:  agent_i_obs_dim
# output: agent_i_action_dim
class Actor(nn.Module):
    """Policy network mapping a single agent's observation to its action.

    Architecture: two 64-unit hidden layers with ReLU. For continuous action
    spaces a second linear head emits a standard-deviation vector, and
    forward() returns a (means, std) pair; otherwise it returns means alone.
    """

    def __init__(self, obs_dim, act_dim, continuous_actions=False):
        super(Actor, self).__init__()
        self.continuous_actions = continuous_actions
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, act_dim)
        if continuous_actions:
            # Extra head producing the (raw, unconstrained) std vector.
            self.std = nn.Linear(64, act_dim)

    def forward(self, x):
        # Shared trunk: obs -> 64 -> 64 with ReLU activations.
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        means = self.fc3(hidden)
        if not self.continuous_actions:
            return means
        # Continuous case: also emit the std head, as a (means, std) tuple.
        return means, self.std(hidden)
# Centralised value network.
# input:  all_obs_dim + all_action_dim
# output: 1 (Q-value)
class Critic(nn.Module):
    """Centralised Q-network scoring a joint (state, action) pair.

    The input is the concatenation of state and action along the feature
    axis (critic_dim features in total); the output is one scalar Q-value
    per batch sample, with the trailing singleton dimension squeezed away.
    """

    def __init__(self, critic_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(critic_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        # Join state and action features into a single input vector.
        joint = torch.cat([state, action], dim=1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(joint))))
        # (batch, 1) -> (batch,) so callers get a flat Q-value per sample.
        return self.fc3(hidden).squeeze(dim=1)
if __name__ == "__main__":
    # Smoke test: build one model and run both heads on a random batch.
    model = MAModle(obs_dim=10, act_dim=2, critic_dim=12)
    batch_obs = torch.randn(4, 10)
    batch_act = torch.randn(4, 2)
    # Exercise the policy head; expect torch.Size([4, 2]).
    print(model.policy(batch_obs).shape)
    # Exercise the value head; expect torch.Size([4]).
    print(model.value(batch_obs, batch_act).shape)
    # Parameter accessors (uncomment to inspect):
    # print(list(model.get_actor_param()))
    # print(list(model.get_critic_param()))