Double DQN 的 TensorFlow 实现

完整代码:https://github.com/zle1992/Reinforcement_Learning_Game

 

 

 

 

开山之作:《Playing Atari with Deep Reinforcement Learning》(NIPS 2013 Deep Learning Workshop)

http://export.arxiv.org/pdf/1312.5602

 

 

《Human-level control through deep reinforcement learning》(Nature, 2015) https://www.cs.swarthmore.edu/~meeden/cs63/s15/nature15b.pdf

该文引入了目标网络(target network):使用两个结构相同的网络,由目标网络计算 TD 目标值,以降低目标值与当前估计之间的相关性;每隔固定的训练步数,将评估网络(eval net)的参数复制到目标网络。

《Deep Reinforcement Learning with Double Q-learning》  https://arxiv.org/pdf/1509.06461.pdf

Double DQN 用评估网络在下一状态中选择动作,再用目标网络评估该动作的 Q 值,从而缓解 DQN 对 Q 值的过高估计。

 

 

 

 

 

  1 import os
  2 import numpy as np 
  3 import tensorflow as tf
  4 from abc import ABCMeta, abstractmethod
  5 np.random.seed(1)
  6 tf.set_random_seed(1)
  7 
  8 import logging  # 引入logging模块
  9 logging.basicConfig(level=logging.DEBUG,
 10                     format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # logging.basicConfig函数对日志的输出格式及方式做相关配置
 11 # 由于日志基本配置中级别设置为DEBUG,所以一下打印信息将会全部显示在控制台上
 12 
 13 tfconfig = tf.ConfigProto()
 14 tfconfig.gpu_options.allow_growth = True
 15 session = tf.Session(config=tfconfig)
 16 
 17 
 18 class DoubleDQNet(object):
 19     __metaclass__ = ABCMeta
 20     """docstring for DeepQNetwork"""
 21     def __init__(self, 
 22             n_actions,
 23             n_features,
 24             learning_rate,
 25             reward_decay,
 26             replace_target_iter,
 27             memory_size,
 28             e_greedy,
 29             e_greedy_increment,
 30             e_greedy_max,
 31             output_graph,
 32             log_dir,
 33             use_doubleQ ,
 34             model_dir,
 35             ):
 36         super(DoubleDQNet, self).__init__()
 37         
 38         self.n_actions = n_actions
 39         self.n_features = n_features
 40         self.learning_rate=learning_rate
 41         self.gamma=reward_decay
 42         self.replace_target_iter=replace_target_iter
 43         self.memory_size=memory_size
 44         self.epsilon=e_greedy
 45         self.epsilon_max=e_greedy_max
 46         self.epsilon_increment=e_greedy_increment
 47         self.output_graph=output_graph
 48         self.lr =learning_rate
 49         
 50         self.log_dir = log_dir
 51         self.use_doubleQ =use_doubleQ
 52         self.model_dir = model_dir 
 53         # total learning step
 54         self.learn_step_counter = 0
 55 
 56 
 57         self.s = tf.placeholder(tf.float32,[None]+self.n_features,name='s')
 58         self.s_next = tf.placeholder(tf.float32,[None]+self.n_features,name='s_next')
 59 
 60 
 61 
 62 
 63 
 64         self.r = tf.placeholder(tf.float32,[None,],name='r')
 65         self.a = tf.placeholder(tf.int32,[None,],name='a')
 66 
 67 
 68         self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
 69         self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)
 70         #self.q_eval4next  = tf.stop_gradient(self._build_q_net(self.s_next, scope='eval_net4next', trainable=True))
 71         self.q_eval4next  = self._build_q_net(self.s_next, scope='eval_net4next', trainable=False)
 72         
 73 
 74 
 75 
 76         
 77    
 78 
 79         if self.use_doubleQ:
 80 
 81            
 82             value_i = tf.to_int32(tf.argmax(self.q_eval4next,axis=1))
 83             range_i = tf.range(tf.shape(self.a)[0], dtype=tf.int32)
 84             index_a = tf.stack([range_i, value_i], axis=1)
 85 
 86 
 87             maxq =  tf.gather_nd(params=self.q_next,indices=index_a)
 88        
 89         else:
 90             maxq =  tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
 91 
 92 
 93         with tf.variable_scope('q_target'):
 94             #只更新最大的那一列
 95             self.q_target = self.r + self.gamma * maxq
 96         with tf.variable_scope('q_eval'):
 97             a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
 98             self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)    # shape=(None, )
 99         with tf.variable_scope('loss'):
100             self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
101         with tf.variable_scope('train'):
102             self._train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
103 
104 
105 
106         t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
107         e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
108         en_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net4next')
109 
110         with tf.variable_scope("hard_replacement"):
111             self.target_replace_op=[tf.assign(t,e) for t,e in zip(t_params,e_params)]
112 
113         with tf.variable_scope("hard_replacement2"):
114             self.target_replace_op2=[tf.assign(t,e) for t,e in zip(en_params,e_params)]
115        
116         self.sess = tf.Session()
117         if self.output_graph:
118             tf.summary.FileWriter(self.log_dir,self.sess.graph)
119 
120         self.sess.run(tf.global_variables_initializer())
121         
122         self.cost_his =[0]
123         self.cost = 0
124 
125         self.saver = tf.train.Saver()
126 
127         if not os.path.exists(self.model_dir):
128             os.mkdir(self.model_dir)
129 
130         checkpoint = tf.train.get_checkpoint_state(self.model_dir)
131         if checkpoint and checkpoint.model_checkpoint_path:
132             self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
133             print ("Loading Successfully")
134             self.learn_step_counter = int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1
135     @abstractmethod
136     def _build_q_net(self,x,scope,trainable):
137         raise NotImplementedError
138 
139     def learn(self,data):
140 
141         self.sess.run(self.target_replace_op2)
142          # check to replace target parameters
143         if self.learn_step_counter % self.replace_target_iter == 0:
144             self.sess.run(self.target_replace_op)
145             print('\ntarget_params_replaced\n')
146 
147         batch_memory_s = data['s']
148         batch_memory_a =  data['a']
149         batch_memory_r = data['r']
150         batch_memory_s_ = data['s_']
151       
152     
153         
154         _, cost = self.sess.run(
155             [self._train_op, self.loss],
156             feed_dict={
157                 self.s: batch_memory_s,
158                 self.a: batch_memory_a,
159                 self.r: batch_memory_r,
160                 self.s_next: batch_memory_s_,
161             
162             })
163         #self.cost_his.append(cost)
164         self.cost = cost
165         # increasing epsilon
166         if self.epsilon < self.epsilon_max:
167             self.epsilon += self.epsilon_increment 
168         else:
169             self.epsilon = self.epsilon_max
170 
171 
172 
173         self.learn_step_counter += 1
174             # save network every 100000 iteration
175         if self.learn_step_counter % 10000 == 0:
176             self.saver.save(self.sess,self.model_dir,global_step=self.learn_step_counter)
177 
178 
179 
180     def choose_action(self,s): 
181         s = s[np.newaxis,:]
182         aa = np.random.uniform()
183         #print("epsilon_max",self.epsilon_max)
184         if aa < self.epsilon:
185             action_value = self.sess.run(self.q_eval,feed_dict={self.s:s})
186             action = np.argmax(action_value)
187         else:
188             action = np.random.randint(0,self.n_actions)
189         return action

 

 

 

参考:

https://github.com/simoninithomas/Deep_reinforcement_learning_Course

https://github.com/spiglerg/DQN_DDQN_Dueling_and_DDPG_Tensorflow/blob/master/modules/dqn.py

posted @ 2019-01-18 13:29  乐乐章  阅读(1629)  评论(0)  编辑  收藏  举报