增强学习--Q-learning
import numpy as np
import random
from collections import defaultdict


class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Q-values live in ``q_table``, a ``defaultdict`` that maps a state key to a
    list holding one value per action; unseen states start at all zeros.
    Actions are used directly as list indices into those rows, so ``actions``
    is expected to be ``[0, 1, ..., n_actions - 1]``.
    """

    def __init__(self, actions, learning_rate=0.01, discount_factor=0.9,
                 epsilon=0.1):
        """Create an agent with an empty (all-zero) Q-table.

        Args:
            actions: action indices, e.g. ``[0, 1, 2, 3]``.
            learning_rate: step size (alpha) of the Q-learning update.
            discount_factor: discount (gamma) applied to the next state's value.
            epsilon: probability of taking a uniformly random action.
        """
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        # Q-table to be learned. Rows are sized from ``actions`` rather than a
        # hard-coded 4, so environments with a different action count work.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state):
        """Update Q(state, action) from one sample <s, a, r, s'>.

        Applies the Bellman-optimality update
        ``Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))``.
        The target uses the greedy value of ``next_state`` regardless of which
        action the epsilon-greedy policy will actually take (off-policy).
        """
        current_q = self.q_table[state][action]
        target_q = reward + self.discount_factor * max(self.q_table[next_state])
        self.q_table[state][action] += self.learning_rate * (target_q - current_q)

    def get_action(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        With probability ``epsilon``, return a uniformly random element of
        ``actions`` (exploration). Otherwise, return the index of the highest
        Q-value for ``state``, breaking ties at random (exploitation).
        """
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.actions)
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return the index of the largest value, choosing randomly among ties.

        Raises ValueError if ``state_action`` is empty.
        """
        max_value = max(state_action)
        max_index_list = [index for index, value in enumerate(state_action)
                          if value == max_value]
        return random.choice(max_index_list)


if __name__ == "__main__":
    # Imported here so QLearningAgent can be imported (e.g. by tests) without
    # requiring the project's environment module.
    from environment import Env

    env = Env()
    agent = QLearningAgent(actions=list(range(env.n_actions)))

    for _episode in range(1000):
        state = env.reset()

        while True:
            env.render()

            # Take an action and advance the environment by one step.
            # States are stringified for use as q_table keys (presumably the env
            # returns unhashable lists -- confirm against environment.py).
            action = agent.get_action(str(state))
            next_state, reward, done = env.step(action)

            # Learn from the sample <s, a, r, s'>.
            agent.learn(str(state), action, reward, str(next_state))

            state = next_state
            env.print_value_all(agent.q_table)

            # Stop when the episode ends.
            if done:
                break
桔桔桔桔桔桔桔桔桔桔