【强化学习】A grid world 值迭代算法 (value iterator algorithm)
强化学习——值迭代算法
代码是在
jupyter notebook
环境下编写只需要
numpy
和matplotlib
包。
此代码为学习赵世钰老师强化学习课程之后,按照公式写出来的代码,对应第四章第一节 value iterator algorithm
可以做的实验:
- 调整
gama
值观察策略的变化 - 调整惩罚值(
fa
)的大小观察策略的变化
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
class grid_world(object):
##
# isRand 是否要每次都随机生成地图
# gama 就是 gama
# fa n x n 的地图默认生成 n 个 forbidden area , 如果觉得不够可以手动指定
# 取 max(n,fa) 作为 number of forbidden area
# fval 为惩罚值
def __init__(self,n,isRand=False,gama = 0.9,fa=0,fval=-1):
self.n = n
self.fval = fval
# 生成 n x n 的网格
self.grid = np.zeros((n,n))
# 设置终点
self.grid[n-1,n-1] = 1
# 便于 a 的遍历
self.a = np.array([ [-1,0],[0,1],[1,0],[0,-1],[0,0] ])
# 便于 s 的遍历
self.s = np.array(range(n*n))
# 便于 r 的遍历
self.r = np.array([fval,0,1])
# qk(s,a)
self.q_sa = np.zeros((n*n,5))
# gama
self.gama = gama
# v 初始化 为 0
self.v = np.zeros(n*n)
# pi(s|a) 策略
self.pi_sa = np.zeros((n*n,5))
#随机种子处理
if isRand is False:
np.random.seed(8)
# 生成 n 个 forbidden area
# size = n 生成 n 个
# replace = False 生成的 n 个数字不能重复
forbidden = np.random.choice(range(n*n-1), size=max(n,fa), replace=False)
#print(forbidden)
for i in forbidden:
row,col = self.one2two(i)
self.grid[row,col] = fval
# 查看生成的矩阵
self.show_grid_default()
def train(self,k):
# 训练 k 轮
for l in range(k):
# v_k+1 的值,等到一轮结束以后,再赋值给 self.v[]
# 在一轮中暂存在这里
tev = np.zeros(self.n*self.n)
for i in range(len(self.s)):
for j in range(len(self.a)):
sum_r = 0
# 遍历 r
for r in self.r:
sum_r = sum_r + r * self.p_rsa(r,i,j)
self.q_sa[i,j] = sum_r + self.gama * self.p_ssa(i,j)
# a_star 为当前 s action value 最大的值的下标
a_star = np.argmax(self.q_sa[i])
# pi(s|a) 存储策略,先把之前存储的策略清零,再把新的策略给赋值
self.pi_sa[i,:] = 0
self.pi_sa[i,a_star] = 1
# 存储 v_k+1(s)
tev[i] = self.q_sa[i,a_star]
# 更新 state value
self.v[:] = tev[:]
# print(self.q_sa[:,:])
self.showPi()
self.showV()
# 在每个方格中显示当前的策略
def showPi(self):
data = self.v.reshape(self.n,self.n)
# 创建图像和轴对象
fig, ax = plt.subplots()
# 使用 matshow
colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
# 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)
# 在每个单元格中添加文本
for (i, j), val in np.ndenumerate(data):
teval = '↓'
for k in range(len(self.a)):
a_star = np.argmax(self.q_sa[i*self.n + j])
if a_star == 0 :
teval = '↑'
elif a_star == 1 :
teval = '→'
elif a_star == 2 :
teval = '↓'
elif a_star == 3 :
teval = '←'
else:
teval = 'o'
ax.text(j, i, f'{teval}', ha='center', va='center', color='black')
# 设置网格线
ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
# 显示图像
plt.show()
# 显示每一个方格中的 state value
def showV(self):
data = self.v.reshape(self.n,self.n)
# 创建图像和轴对象
fig, ax = plt.subplots()
# 使用 matshow
colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
# 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)
# 在每个单元格中添加文本
for (i, j), val in np.ndenumerate(data):
ax.text(j, i, f'{val:.1f}', ha='center', va='center', color='black')
# 设置网格线
ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
# 显示图像
plt.show()
# p(r|s,a)
def p_rsa(self,r,s,a):
#print(r,s,a)
row,col = self.one2two(s)
tx,ty = self.a[a]
row = row + tx
col = col + ty
if self.checkInWorld(row,col) and self.grid[row,col] == r:
return True
elif self.checkInWorld(row,col) == False and r == self.fval:
return True
else:
return False
# p(s'|s,a)
# 这里我没有遍历,因为 s_i 与 a 已经确定 那么只有唯一的一个 s' 与之对应
# 注意,越过了边界的话 v 是自己
def p_ssa(self,s,a):
row,col = self.one2two(s)
tx,ty = self.a[a]
tr = row + tx
tc = col + ty
if self.checkInWorld(tr,tc) is False:
return self.v[s]
else:
return self.v[tr * self.n + tc]
# 查看是否超出了边界
def checkInWorld(self,x,y):
if x < 0 or x >= self.n or y < 0 or y >= self.n:
return False
else:
return True
# 由一维下标,转换为二维的坐标
def one2two(self,x):
row = x // self.n
col = x % self.n
return row,col
# isShowWord 是否要在矩阵中写出数值
# isShowBar 是否显示颜色条
def show_grid_default(self,isShowWord=False,isShowBar=False):
# 创建图像和轴对象
fig, ax = plt.subplots()
# 使用 matshow
colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
# 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)
if isShowWord:
# 在每个单元格中添加文本
for (i, j), val in np.ndenumerate(self.grid):
ax.text(j, i, f'{val}', ha='center', va='center', color='black')
# 添加颜色条
if isShowBar:
fig.colorbar(cax)
# 设置网格线
ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
# 显示图像
plt.show()
grid = grid_world(10,isRand=False,fa=30,gama=0.9,fval = -100)
grid.train(500)
运行截图