关于ppo中针对MLP和RNN两种不同网络结构的数据处理与采样方法

在RL中,需要对数据进行采样,因此如何构造可采样的数据或数据块,则是需要关注的问题:

 

if self.actor_critic.is_recurrent:
data_generator = rollouts.recurrent_generator(
advantages, self.num_mini_batch)
else:
data_generator = rollouts.feed_forward_generator(
advantages, self.num_mini_batch)
for sample in data_generator:
obs_batch, recurrent_hidden_states_batch, actions_batch, \
value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \
adv_targ = sample

# Reshape to do in a single forward pass for all steps
values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
obs_batch, recurrent_hidden_states_batch, masks_batch,
actions_batch)
(以上代码来自于ppo_atari,pytorch-a2c-ppo-acktr-gail-master,这两个项目中的代码都是一致的,涉及到针对MLP\CNN\RNN的数据的生成和采样)
如上,我们对数据进行重新生成和采样,其中 data_generator 是重新生成的数据,且其针对 self.actor_critic 是否是 network,有不同的生成策略,但采样方法是一致的。

PPO中存储的其实也是8个process的128个step的数据块,即buffer中存储的是8*128个数据。如果说self.actor_critic不是RNN,则数据是可以被打乱采样,不需要考虑每个tuple数据之间的时序
依赖关系,可如果self.actor_critic是RNN,则我们需要考虑数据之间的时序关联性,因此,我们需要在处理、采样数据时保留数据之间的时序关系。

(1)在self.actor_critic不是RNN的情况下:
    data_generator = rollouts.feed_forward_generator(
advantages, self.num_mini_batch)

其中:
def feed_forward_generator(self,
                           advantages,
num_mini_batch=None,
mini_batch_size=None):

num_steps, num_processes = self.rewards.size()[0:2]
batch_size = num_processes * num_steps

if mini_batch_size is None:
assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
"".format(num_processes, num_steps, num_processes * num_steps,
num_mini_batch))

mini_batch_size = batch_size // num_mini_batch
sampler = BatchSampler(
SubsetRandomSampler(range(batch_size)),
mini_batch_size,
drop_last=True)

for indices in sampler:
obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(
-1, self.recurrent_hidden_states.size(-1))[indices]
actions_batch = self.actions.view(-1,
self.actions.size(-1))[indices]
value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
return_batch = self.returns[:-1].view(-1, 1)[indices]
masks_batch = self.masks[:-1].view(-1, 1)[indices]
old_action_log_probs_batch = self.action_log_probs.view(-1,
1)[indices]

if advantages is None:
adv_targ = None
else:
adv_targ = advantages.view(-1, 1)[indices]

yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ

其中,函数 sampler = BatchSampler(
SubsetRandomSampler(range(batch_size)),
mini_batch_size,
drop_last=True)
  
  则是把通过把数据分成 mini_batch_size(文中是4) 份,每份有batch_size(文中是256)个数据。在随后的训练中,则是每个data_generator都有4次数据训练,每次的数据都是256个,都不一样。


(2)在self.actor_critic是RNN的情况下:
    data_generator = rollouts.feed_forward_generator(
advantages, self.num_mini_batch)

def recurrent_generator(self, advantages, num_mini_batch):
num_processes = self.rewards.size(1)
assert num_processes >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"to be greater than or equal to the number of "
"PPO mini batches ({}).".format(num_processes, num_mini_batch))
num_envs_per_batch = num_processes // num_mini_batch
perm = torch.randperm(num_processes)

for start_ind in range(0, num_processes, num_envs_per_batch):
obs_batch = []
recurrent_hidden_states_batch = []
actions_batch = []
value_preds_batch = []
return_batch = []
masks_batch = []
old_action_log_probs_batch = []
adv_targ = []

for offset in range(num_envs_per_batch):
ind = perm[start_ind + offset]

obs_batch.append(self.obs[:-1, ind])
recurrent_hidden_states_batch.append(
self.recurrent_hidden_states[0:1, ind])
actions_batch.append(self.actions[:, ind])
value_preds_batch.append(self.value_preds[:-1, ind])
return_batch.append(self.returns[:-1, ind])
masks_batch.append(self.masks[:-1, ind])
old_action_log_probs_batch.append(
self.action_log_probs[:, ind])
adv_targ.append(advantages[:, ind])

T, N = self.num_steps, num_envs_per_batch

# These are all tensors of size (T, N, -1)
obs_batch = torch.stack(obs_batch, 1)
actions_batch = torch.stack(actions_batch, 1)
value_preds_batch = torch.stack(value_preds_batch, 1)
return_batch = torch.stack(return_batch, 1)
masks_batch = torch.stack(masks_batch, 1)
old_action_log_probs_batch = torch.stack(
old_action_log_probs_batch, 1)
adv_targ = torch.stack(adv_targ, 1)

# States is just a (N, -1) tensor
recurrent_hidden_states_batch = torch.stack(
recurrent_hidden_states_batch, 1).view(N, -1)

# Flatten the (T, N, ...) tensors to (T * N, ...)
obs_batch = _flatten_helper(T, N, obs_batch)
actions_batch = _flatten_helper(T, N, actions_batch)
value_preds_batch = _flatten_helper(T, N, value_preds_batch)
return_batch = _flatten_helper(T, N, return_batch)
masks_batch = _flatten_helper(T, N, masks_batch)
old_action_log_probs_batch = _flatten_helper(T, N, \
old_action_log_probs_batch)
adv_targ = _flatten_helper(T, N, adv_targ)

yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ

其中,8个process的数据块,每个数据块有128个数据,我们首先以数据块的形式打乱这8个数据块,之后随机抽取2个数据块,组成128*2的数据块,这样就保持了数据之间的时序性,同时将数据分成4份,
进行4次采样和训练,确保每个ppo train_step都对当前buffer中的所有的数据都进行了训练。



posted @ 2022-08-27 10:22  呦呦南山  阅读(594)  评论(0编辑  收藏  举报