DeepQNetwork --- TensorFlow 实现

 

https://github.com/zle1992/Reinforcement_Learning_Game

 

 

 

 

 

DeepQNetwork.py
  1 import numpy as np 
  2 import tensorflow as tf
  3 from abc import ABCMeta, abstractmethod
  4 np.random.seed(1)
  5 tf.set_random_seed(1)
  6 
  7 import logging  # 引入logging模块
  8 logging.basicConfig(level=logging.DEBUG,
  9                     format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # logging.basicConfig函数对日志的输出格式及方式做相关配置
 10 # 由于日志基本配置中级别设置为DEBUG,所以一下打印信息将会全部显示在控制台上
 11 
 12 tfconfig = tf.ConfigProto()
 13 tfconfig.gpu_options.allow_growth = True
 14 session = tf.Session(config=tfconfig)
 15 
 16 
 17 class DeepQNetwork(object):
 18     __metaclass__ = ABCMeta
 19     """docstring for DeepQNetwork"""
 20     def __init__(self, 
 21             n_actions,
 22             n_features,
 23             learning_rate,
 24             reward_decay,
 25             e_greedy,
 26             replace_target_iter,
 27             memory_size,
 28             e_greedy_increment,
 29             output_graph,
 30             log_dir,
 31             ):
 32         super(DeepQNetwork, self).__init__()
 33         
 34         self.n_actions = n_actions
 35         self.n_features = n_features
 36         self.learning_rate=learning_rate
 37         self.gamma=reward_decay
 38         self.epsilon_max=e_greedy
 39         self.replace_target_iter=replace_target_iter
 40         self.memory_size=memory_size
 41         self.epsilon_increment=e_greedy_increment
 42         self.output_graph=output_graph
 43         self.lr =learning_rate
 44         # total learning step
 45         self.learn_step_counter = 0
 46         self.log_dir = log_dir
 47        
 48  
 49 
 50         self.s = tf.placeholder(tf.float32,[None]+self.n_features,name='s')
 51         self.s_next = tf.placeholder(tf.float32,[None]+self.n_features,name='s_next')
 52 
 53         self.r = tf.placeholder(tf.float32,[None,],name='r')
 54         self.a = tf.placeholder(tf.int32,[None,],name='a')
 55 
 56 
 57         self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
 58         self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)
 59 
 60 
 61 
 62         with tf.variable_scope('q_target'):
 63             self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
 64         with tf.variable_scope('q_eval'):
 65             a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
 66             self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)    # shape=(None, )
 67         with tf.variable_scope('loss'):
 68             self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
 69         with tf.variable_scope('train'):
 70             self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
 71 
 72 
 73 
 74         t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
 75         e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
 76 
 77         with tf.variable_scope("hard_replacement"):
 78             self.target_replace_op=[tf.assign(t,e) for t,e in zip(t_params,e_params)]
 79 
 80 
 81        
 82         self.sess = tf.Session()
 83         if self.output_graph:
 84             tf.summary.FileWriter(self.log_dir,self.sess.graph)
 85 
 86         self.sess.run(tf.global_variables_initializer())
 87         
 88         self.cost_his =[]
 89 
 90     @abstractmethod
 91     def _build_q_net(self,x,scope,trainable):
 92         raise NotImplementedError
 93 
 94     def learn(self,data):
 95 
 96 
 97          # check to replace target parameters
 98         if self.learn_step_counter % self.replace_target_iter == 0:
 99             self.sess.run(self.target_replace_op)
100             print('\ntarget_params_replaced\n')
101 
102         batch_memory_s = data['s'], 
103         batch_memory_a =  data['a'], 
104         batch_memory_r = data['r'], 
105         batch_memory_s_ = data['s_'], 
106         _, cost = self.sess.run(
107             [self._train_op, self.loss],
108             feed_dict={
109                 self.s: batch_memory_s,
110                 self.a: batch_memory_a,
111                 self.r: batch_memory_r,
112                 self.s_next: batch_memory_s_,
113             })
114         self.cost_his.append(cost)
115 
116         # increasing epsilon
117         self.epsilon_max = self.epsilon_max + self.epsilon_increment if self.epsilon_max < self.epsilon_max else self.epsilon_max
118         self.learn_step_counter += 1
119 
120 
121 
122 
123     def choose_action(self,s): 
124         s = s[np.newaxis,:]
125         aa = np.random.uniform()
126         #print("epsilon_max",self.epsilon_max)
127         if aa < self.epsilon_max:
128             action_value = self.sess.run(self.q_eval,feed_dict={self.s:s})
129             action = np.argmax(action_value)
130         else:
131             action = np.random.randint(0,self.n_actions)
132         return action

 

Memory.py
 1 import numpy as np 
 2 np.random.seed(1)
 3 class Memory(object):
 4     """docstring for Memory"""
 5     def __init__(self,
 6             n_actions,
 7             n_features,
 8             memory_size):
 9         super(Memory, self).__init__()
10         self.memory_size = memory_size
11         self.cnt =0 
12 
13         self.s = np.zeros([memory_size]+n_features)
14         self.a = np.zeros([memory_size,])
15         self.r =  np.zeros([memory_size,])
16         self.s_ = np.zeros([memory_size]+n_features)
17         
18     def store_transition(self,s, a, r, s_):
19         #logging.info('store_transition')
20         index = self.cnt % self.memory_size
21         self.s[index] = s
22         self.a[index] = a
23         self.r[index] =  r
24         self.s_[index] =s_
25         self.cnt+=1
26 
27     def sample(self,n):
28         #logging.info('sample')
29         #assert self.cnt>=self.memory_size,'Memory has not been fulfilled'
30         N = min(self.memory_size,self.cnt)
31         indices = np.random.choice(N,size=n)
32         d ={}
33         d['s'] = self.s[indices][0]
34         d['s_'] = self.s_[indices][0]
35         d['r'] = self.r[indices][0]
36         d['a'] = self.a[indices][0]
37         return d

 

 

主函数

  1 import gym
  2 import numpy as np 
  3 import tensorflow as tf
  4 
  5 from Memory import Memory
  6 from DeepQNetwork import DeepQNetwork
  7 
  8 np.random.seed(1)
  9 tf.set_random_seed(1)
 10 
 11 import logging  # 引入logging模块
 12 logging.basicConfig(level=logging.DEBUG,
 13                     format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # logging.basicConfig函数对日志的输出格式及方式做相关配置
 14 # 由于日志基本配置中级别设置为DEBUG,所以一下打印信息将会全部显示在控制台上
 15 
 16 tfconfig = tf.ConfigProto()
 17 tfconfig.gpu_options.allow_growth = True
 18 session = tf.Session(config=tfconfig)
 19 
 20 class DeepQNetwork4CartPole(DeepQNetwork):
 21     """docstring for ClassName"""
 22     def __init__(self, **kwargs):
 23         super(DeepQNetwork4CartPole, self).__init__(**kwargs)
 24     
 25     def _build_q_net(self,x,scope,trainable):
 26         w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)
 27 
 28         with tf.variable_scope(scope):
 29             e1 = tf.layers.dense(inputs=x, 
 30                     units=32, 
 31                     bias_initializer = b_initializer,
 32                     kernel_initializer=w_initializer,
 33                     activation = tf.nn.relu,
 34                     trainable=trainable)  
 35             q = tf.layers.dense(inputs=e1, 
 36                     units=self.n_actions, 
 37                     bias_initializer = b_initializer,
 38                     kernel_initializer=w_initializer,
 39                     activation = tf.nn.sigmoid,
 40                     trainable=trainable) 
 41 
 42         return q  
 43         
 44 
 45 
 46 
 47 batch_size = 64
 48 
 49 memory_size  =2000
 50 #env = gym.make('Breakout-v0') #离散
 51 env = gym.make('CartPole-v0') #离散
 52 
 53 
 54 n_features= list(env.observation_space.shape)
 55 n_actions= env.action_space.n
 56 
 57 env = env.unwrapped
 58 
 59 def run():
 60    
 61     RL = DeepQNetwork4CartPole(
 62         n_actions=n_actions,
 63         n_features=n_features,
 64         learning_rate=0.01,
 65         reward_decay=0.9,
 66         e_greedy=0.9,
 67         replace_target_iter=200,
 68         memory_size=memory_size,
 69         e_greedy_increment=None,
 70         output_graph=True,
 71         log_dir = 'log/DeepQNetwork4CartPole/',
 72         )
 73 
 74     memory = Memory(n_actions,n_features,memory_size=memory_size)
 75   
 76 
 77     step = 0
 78     ep_r = 0
 79     for episode in range(2000):
 80         # initial observation
 81         observation = env.reset()
 82 
 83         while True:
 84             
 85 
 86             # RL choose action based on observation
 87             action = RL.choose_action(observation)
 88             # logging.debug('action')
 89             # print(action)
 90             # RL take action and get_collectiot next observation and reward
 91             observation_, reward, done, info=env.step(action) # take a random action
 92             
 93             # the smaller theta and closer to center the better
 94             x, x_dot, theta, theta_dot = observation_
 95             r1 = (env.x_threshold - abs(x))/env.x_threshold - 0.8
 96             r2 = (env.theta_threshold_radians - abs(theta))/env.theta_threshold_radians - 0.5
 97             reward = r1 + r2
 98 
 99 
100 
101 
102             memory.store_transition(observation, action, reward, observation_)
103             
104             
105             if (step > 200) and (step % 5 == 0):
106                
107                 data = memory.sample(batch_size)
108                 RL.learn(data)
109                 #print('step:%d----reward:%f---action:%d'%(step,reward,action))
110             # swap observation
111             observation = observation_
112             ep_r += reward
113             # break while loop when end of this episode
114             if(episode>700): 
115                 env.render()  # render on the screen
116             if done:
117                 print('episode: ', episode,
118                       'ep_r: ', round(ep_r, 2),
119                       ' epsilon: ', round(RL.epsilon_max, 2))
120                 ep_r = 0
121 
122                 break
123             step += 1
124 
125     # end of game
126     print('game over')
127     env.destroy()
128 
def main():
    """Script entry point: train the CartPole agent."""
    return run()


if __name__ == '__main__':
    main()

 

posted @ 2019-01-08 22:37  乐乐章  阅读(562)  评论(0)  编辑  收藏  举报