【强化学习】A grid world 值迭代算法 (value iterator algorithm)

强化学习——值迭代算法

代码是在 jupyter notebook 环境下编写

只需要 numpymatplotlib 包。

此代码为学习赵世钰老师强化学习课程之后,按照公式写出来的代码,对应第四章第一节 value iterator algorithm

可以做的实验:

  • 调整 gama 值观察策略的变化
  • 调整惩罚值(fa)的大小观察策略的变化
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors



class grid_world(object):

    ##
    # isRand 是否要每次都随机生成地图
    # gama 就是 gama
    # fa n x n 的地图默认生成 n 个 forbidden area , 如果觉得不够可以手动指定
    #    取 max(n,fa) 作为 number of forbidden area
    # fval 为惩罚值
    
    def __init__(self,n,isRand=False,gama = 0.9,fa=0,fval=-1):

        self.n = n
        self.fval = fval
        # 生成 n x n 的网格
        self.grid = np.zeros((n,n))
            # 设置终点
        self.grid[n-1,n-1] = 1
        # 便于 a 的遍历
        self.a = np.array([ [-1,0],[0,1],[1,0],[0,-1],[0,0] ])
        # 便于 s 的遍历
        self.s = np.array(range(n*n))
        # 便于 r 的遍历
        self.r = np.array([fval,0,1])
        # qk(s,a) 
        self.q_sa = np.zeros((n*n,5))
        # gama
        self.gama = gama
        # v 初始化 为 0
        self.v = np.zeros(n*n)
        # pi(s|a) 策略
        self.pi_sa = np.zeros((n*n,5))

        
        
        #随机种子处理
        if isRand is False:
            np.random.seed(8)
        
        # 生成 n 个 forbidden area
            # size = n 生成 n 个
            # replace = False 生成的 n 个数字不能重复
        forbidden = np.random.choice(range(n*n-1), size=max(n,fa), replace=False)
        #print(forbidden)
        for i in forbidden:
            row,col = self.one2two(i)
            self.grid[row,col] = fval
        # 查看生成的矩阵
        self.show_grid_default()
       


    def train(self,k):

        # 训练 k 轮
        for l in range(k):

            # v_k+1 的值,等到一轮结束以后,再赋值给 self.v[]
            # 在一轮中暂存在这里
            tev = np.zeros(self.n*self.n)

            for i in range(len(self.s)):
                for j in range(len(self.a)):

                    sum_r = 0
                    # 遍历 r
                    for r in self.r:
                        sum_r = sum_r + r * self.p_rsa(r,i,j)

                    self.q_sa[i,j] = sum_r + self.gama * self.p_ssa(i,j)
                    
                # a_star 为当前 s action value 最大的值的下标
                a_star = np.argmax(self.q_sa[i])
                # pi(s|a) 存储策略,先把之前存储的策略清零,再把新的策略给赋值
                self.pi_sa[i,:] = 0
                self.pi_sa[i,a_star] = 1
                # 存储 v_k+1(s)
                tev[i] = self.q_sa[i,a_star]

            # 更新 state value
            self.v[:] = tev[:]
        #    print(self.q_sa[:,:])
        self.showPi()
        self.showV()

    
    # 在每个方格中显示当前的策略
    def showPi(self):
        data = self.v.reshape(self.n,self.n)
        # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        
        # 在每个单元格中添加文本
        for (i, j), val in np.ndenumerate(data):
            teval = '↓'
            for k in range(len(self.a)):
                a_star = np.argmax(self.q_sa[i*self.n + j])
                if a_star == 0 :
                    teval = '↑'
                elif a_star == 1 :
                    teval = '→'
                elif a_star == 2 :
                    teval = '↓'
                elif  a_star == 3 :
                    teval = '←'
                else:
                    teval = 'o'
                
            
            ax.text(j, i, f'{teval}', ha='center', va='center', color='black')
            


        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        

    # 显示每一个方格中的  state value    
    def showV(self):
        data = self.v.reshape(self.n,self.n)
         # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        
        # 在每个单元格中添加文本
        for (i, j), val in np.ndenumerate(data):
            ax.text(j, i, f'{val:.1f}', ha='center', va='center', color='black')
            


        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        
    

    # p(r|s,a)
    def p_rsa(self,r,s,a):
        #print(r,s,a)
        row,col = self.one2two(s)
        tx,ty = self.a[a] 
        row = row + tx
        col = col + ty
        if self.checkInWorld(row,col) and self.grid[row,col] == r:
            return True
        elif self.checkInWorld(row,col) == False and r == self.fval:
            return True
        else:
            return False

    # p(s'|s,a)
    # 这里我没有遍历,因为 s_i 与 a 已经确定 那么只有唯一的一个 s' 与之对应
    # 注意,越过了边界的话 v 是自己
    def p_ssa(self,s,a):
        row,col = self.one2two(s)
        tx,ty = self.a[a] 
        tr = row + tx
        tc = col + ty
        if self.checkInWorld(tr,tc) is False:
            return self.v[s]
        else:
            return self.v[tr * self.n + tc]

    # 查看是否超出了边界
    def checkInWorld(self,x,y):
        if x < 0 or x >= self.n or y < 0 or y >= self.n:
            return False
        else:
            return True

    # 由一维下标,转换为二维的坐标
    def one2two(self,x):
        row = x // self.n
        col = x % self.n
        return row,col
        

    # isShowWord 是否要在矩阵中写出数值
    # isShowBar 是否显示颜色条
    def show_grid_default(self,isShowWord=False,isShowBar=False):
        # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        if isShowWord:
            # 在每个单元格中添加文本
            for (i, j), val in np.ndenumerate(self.grid):
                ax.text(j, i, f'{val}', ha='center', va='center', color='black')
            
        # 添加颜色条
        if isShowBar:
            fig.colorbar(cax)

        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        
        
        
        
        
grid = grid_world(10,isRand=False,fa=30,gama=0.9,fval = -100)
grid.train(500)

运行截图

image-20240514155508036image-20240514155624859image-20240514155712520

posted @ 2024-05-14 16:02  Hoppz  阅读(41)  评论(0编辑  收藏  举报