AI | 强化学习 | qlearning

AI | 强化学习 | qlearning

之前跟着莫烦 Python 的教程,用 numpy 和 pandas 实现了强化学习中的 Q 表(Q-table),但感觉 pandas 的写法太反人类了。这次把他课上的例子改用 Python 原生的字典来实现 Q 表,重新写了一份,便于理解。

代码如下:

import time
import random

N_STATES = 10       # length of the 1-D world in cells; the last cell is the goal 'T'
MAX_EPISODES = 15   # number of training episodes
FRESH_TIME = 0.01       # seconds to pause after drawing each frame

class QLearning:
    """Tabular Q-learning agent whose Q-table is a plain dict of dicts.

    The table maps ``state -> {action: q_value}``. A row is created lazily,
    with every action initialised to 0.0, the first time a state is seen.

    Args:
        actions: the actions available in every state.
        learning_rate: step size (alpha) of the Q-value update.
        reward_decay: discount factor (gamma) applied to the next state's value.
        e_greedy: probability of acting greedily; otherwise a uniformly random
            action is taken.
        terminal_state: the state value that ends an episode. Its future value
            is taken as 0, so no Q-table row is ever created for it. Defaults
            to ``'win'``, the terminal state used by this file's environment.
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9,
                 e_greedy=0.9, terminal_state='win'):
        self.actions = actions
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.terminal_state = terminal_state
        # Q-table layout (values are Q-values, not raw rewards):
        # {
        #     state1: {action1: q_value, action2: q_value},
        #     ...
        # }
        self.q_table = {}

    def check_state_exist(self, state):
        """Add a zero-initialised row for *state* if it is not in the table yet."""
        if state not in self.q_table:
            self.q_table[state] = {action: 0.0 for action in self.actions}

    def choose_action(self, observation):
        """Pick an action for *observation* using an epsilon-greedy policy.

        With probability ``self.epsilon``, pick one of the highest-valued
        actions, breaking ties uniformly at random. Otherwise, pick uniformly
        among all actions.
        """
        self.check_state_exist(observation)
        if random.random() < self.epsilon:
            state_action = self.q_table[observation]
            best = max(state_action.values())
            # Random tie-breaking keeps an all-zero row from always picking
            # the first action.
            candidates = [action for action, q in state_action.items() if q == best]
            return random.choice(candidates)
        return random.choice(self.actions)

    def learn(self, s, a, r, s_):
        """Apply one Q-learning update for the transition ``(s, a, r, s_)``.

        ``Q(s, a) += lr * (target - Q(s, a))``. Here ``target`` is ``r`` when
        ``s_`` is the terminal state, and ``r + gamma * max_a' Q(s_, a')``
        otherwise.
        """
        # Ensure the row being updated exists, even if *s* never went
        # through choose_action (previously a KeyError).
        self.check_state_exist(s)
        q_predict = self.q_table[s][a]
        if s_ == self.terminal_state:
            q_target = r  # terminal: no future value, and no row is created
        else:
            self.check_state_exist(s_)
            q_target = r + self.gamma * max(self.q_table[s_].values())
        self.q_table[s][a] += self.lr * (q_target - q_predict)


# Environment step: apply action A in state S.
def get_env_feedback(S, A):
    """Return ``(next_state, reward)`` after taking action *A* in state *S*.

    Moving 'right' from cell ``N_STATES - 2`` reaches the goal: the next state
    is ``'win'`` and the reward is 1. Any other 'right' moves one cell right.
    Any action other than 'right' moves one cell left, except at cell 0 where
    the agent stays put. All non-goal moves give reward 0.
    """
    if A != 'right':
        # Left move, blocked by the wall at cell 0.
        return (S if S == 0 else S - 1), 0
    if S == N_STATES - 2:
        # The next cell is the goal: the episode ends.
        return 'win', 1
    return S + 1, 0

# Redraw the 1-D world on a single console line.
def update_env(S, episode, step_counter):
    """Render the corridor in place, or show an episode summary when S is 'win'.

    For a non-terminal *S*, the line shows ``N_STATES`` cells: '-' for empty,
    'T' for the goal in the last cell, and 'o' for the agent at index *S*.
    It then pauses ``FRESH_TIME`` seconds.
    For ``'win'``, it prints the 1-based episode number and *step_counter*,
    waits 2 seconds, then blanks the line.
    """
    if S != 'win':
        cells = ['-'] * (N_STATES - 1) + ['T']
        cells[S] = 'o'
        print('\r{}'.format(''.join(cells)), end='')
        time.sleep(FRESH_TIME)
        return
    summary = 'Episode %s: total_step = %s' % (episode + 1, step_counter)
    print('\r{}'.format(summary), end='')
    time.sleep(2)
    print('\r                               ', end='')

def run():
    """Train a QLearning agent on the 1-D corridor and return it.

    Runs ``MAX_EPISODES`` episodes, each starting at cell 0 and ending when
    the environment reports ``'win'``. The current Q-table is printed at the
    start of every episode, and the corridor is redrawn after every step.
    """
    agent = QLearning(actions=['left', 'right'])
    for episode in range(MAX_EPISODES):
        print(agent.q_table)
        state = 0
        steps = 0
        update_env(state, episode, steps)
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward = get_env_feedback(state, action)
            agent.learn(state, action, reward, next_state)
            done = next_state == 'win'
            state = next_state
            steps += 1
            update_env(state, episode, steps)
    return agent


if __name__ == '__main__':
    # Train only when executed as a script, so that importing this module
    # (e.g. to reuse QLearning) does not start the console demo and its
    # multi-second sleeps.
    RL = run()
    print(RL.q_table)



posted @ 2023-01-01 16:04  Mz1  阅读(55)  评论(0)  编辑  收藏  举报