深度增强学习--Actor Critic

Actor Critic:value-based 方法与 policy-based 方法的结合。


import os
import sys

import gym
import numpy as np
import pylab
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam
  9 EPISODES = 1000
 12 # A2C(Advantage Actor-Critic) agent for the Cartpole
 13 # actor-critic算法结合了value-based和policy-based方法
 14 class A2CAgent:
 15     def __init__(self, state_size, action_size):
 16         # if you want to see Cartpole learning, then change to True
 17         self.render = True
 18         self.load_model = False
 19         # get size of state and action
 20         self.state_size = state_size
 21         self.action_size = action_size
 22         self.value_size = 1
 24         # These are hyper parameters for the Policy Gradient
 25         self.discount_factor = 0.99
 26         self.actor_lr = 0.001
 27         self.critic_lr = 0.005
 29         # create model for policy network
 30         self.actor = self.build_actor()
 31         self.critic = self.build_critic()
 33         if self.load_model:
 34             self.actor.load_weights("./save_model/cartpole_actor.h5")
 35             self.critic.load_weights("./save_model/cartpole_critic.h5")
 37     # approximate policy and value using Neural Network
 38     # actor: state is input and probability of each action is output of model
 39     def build_actor(self):#actor网络:state-->action
 40         actor = Sequential()
 41         actor.add(Dense(24, input_dim=self.state_size, activation='relu',
 42                         kernel_initializer='he_uniform'))
 43         actor.add(Dense(self.action_size, activation='softmax',
 44                         kernel_initializer='he_uniform'))
 45         actor.summary()
 46         # See note regarding crossentropy in cartpole_reinforce.py
 47         actor.compile(loss='categorical_crossentropy',
 48                       optimizer=Adam(lr=self.actor_lr))
 49         return actor
 51     # critic: state is input and value of state is output of model
 52     def build_critic(self):#critic网络:state-->value,Q值
 53         critic = Sequential()
 54         critic.add(Dense(24, input_dim=self.state_size, activation='relu',
 55                          kernel_initializer='he_uniform'))
 56         critic.add(Dense(self.value_size, activation='linear',
 57                          kernel_initializer='he_uniform'))
 58         critic.summary()
 59         critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
 60         return critic
 62     # using the output of policy network, pick action stochastically
 63     def get_action(self, state):
 64         policy = self.actor.predict(state, batch_size=1).flatten()#根据actor网络预测下一步动作
 65         return np.random.choice(self.action_size, 1, p=policy)[0]
 67     # update policy network every episode
 68     def train_model(self, state, action, reward, next_state, done):
 69         target = np.zeros((1, self.value_size))#(1,1)
 70         advantages = np.zeros((1, self.action_size))#(1, 2)
 72         value = self.critic.predict(state)[0]#critic网络预测的当前q值
 73         next_value = self.critic.predict(next_state)[0]#critic网络预测的下一个q值
 75         '''
 76         理解下面部分
 77         '''
 78         if done:
 79             advantages[0][action] = reward - value
 80             target[0][0] = reward
 81         else:
 82             advantages[0][action] = reward + self.discount_factor * (next_value) - value#acotr网络
 83             target[0][0] = reward + self.discount_factor * next_value#critic网络
 85         self.actor.fit(state, advantages, epochs=1, verbose=0)
 86         self.critic.fit(state, target, epochs=1, verbose=0)
 89 if __name__ == "__main__":
 90     # In case of CartPole-v1, maximum length of episode is 500
 91     env = gym.make('CartPole-v1')
 92     # get size of state and action from environment
 93     state_size = env.observation_space.shape[0]
 94     action_size = env.action_space.n
 96     # make A2C agent
 97     agent = A2CAgent(state_size, action_size)
 98     scores, episodes = [], []
100     for e in range(EPISODES):
101         done = False
102         score = 0
103         state = env.reset()
104         state = np.reshape(state, [1, state_size])
106         while not done:
107             if agent.render:
108                 env.render()
110             action = agent.get_action(state)
111             next_state, reward, done, info = env.step(action)
112             next_state = np.reshape(next_state, [1, state_size])
113             # if an action make the episode end, then gives penalty of -100
114             reward = reward if not done or score == 499 else -100
116             agent.train_model(state, action, reward, next_state, done)#每执行一次action训练一次
118             score += reward
119             state = next_state
121             if done:
122                 # every episode, plot the play time
123                 score = score if score == 500.0 else score + 100
124                 scores.append(score)
125                 episodes.append(e)
126                 pylab.plot(episodes, scores, 'b')
127                 pylab.savefig("./save_graph/cartpole_a2c.png")
128                 print("episode:", e, "  score:", score)
130                 # if the mean of scores of last 10 episode is bigger than 490
131                 # stop training
132                 if np.mean(scores[-min(10, len(scores)):]) > 490:
133                     sys.exit()
135         # save the model
136         if e % 50 == 0:
137             agent.actor.save_weights("./save_model/cartpole_actor.h5")
138             agent.critic.save_weights("./save_model/cartpole_critic.h5")


