[RL Series] Markov Decision Processes: Gambler's Problem
The Gambler's Problem is a classic example of value iteration in dynamic programming.
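In a coin-flipping game, the gambler first places a stake. If the coin comes up heads, he wins as much as he staked; if it comes up tails, he loses the stake. The game ends when the gambler reaches his goal of 100 or runs out of money. The task is to find a policy, mapping the current capital to a stake, that maximizes the probability of reaching 100. The nonterminal states are the capital s ∈ {1, 2, ..., 99} (0 and 100 are terminal), and the actions are the stakes a ∈ {1, 2, ..., min(s, 100 - s)}. The reward is +1 on reaching 100 and 0 on every other transition. The probability of heads is p_h = 0.4 (hp in the code).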
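Putting these definitions together, with no discounting and the +1 reward folded into the terminal value V(100) = 1, the Bellman optimality equation for this problem can be written as below. Value iteration turns it into an update by replacing V* on the right-hand side with the current estimate V_k (the program updates V in place, which converges to the same fixed point).

```latex
V^*(s) = \max_{a \in \{1, \dots, \min(s,\, 100 - s)\}}
  \Big[\, p_h\, V^*(s + a) + (1 - p_h)\, V^*(s - a) \,\Big],
  \qquad s = 1, \dots, 99,

V^*(0) = 0, \qquad V^*(100) = 1, \qquad p_h = 0.4,

V_{k+1}(s) = \max_{a \in \{1, \dots, \min(s,\, 100 - s)\}}
  \Big[\, p_h\, V_k(s + a) + (1 - p_h)\, V_k(s - a) \,\Big].
```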
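The problem is not hard. Because p_h < 0.5, bold play, i.e., always staking min(s, 100 - s), is known to be optimal. It is not the only optimal policy, however, so the greedy policy extracted from value iteration can look different. We skip the analysis here and go straight to the program: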
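The optimality of bold play can be checked numerically. The following is a minimal sketch in Python rather than the post's MATLAB; the function names `value_iteration`, `evaluate_policy`, `greedy_stakes` and the threshold `THETA` are illustrative assumptions. It runs in-place value iteration, evaluates bold play separately by iterative policy evaluation, and compares the two value functions. The rounding in `greedy_stakes` is an assumed tie-breaking choice, so that near-equal stakes resolve to the smallest one.

```python
"""Value iteration for the Gambler's Problem (p_h = 0.4), plus a check that
bold play, stake = min(s, 100 - s), attains the same state values."""

GOAL = 100
P_HEADS = 0.4
THETA = 1e-12  # convergence threshold (an assumed value)


def value_iteration(p_h=P_HEADS, goal=GOAL, theta=THETA):
    """In-place value iteration; v[0] = 0 and v[goal] = 1 stay fixed."""
    v = [0.0] * (goal + 1)
    v[goal] = 1.0  # the +1 reward for reaching the goal, folded into the terminal value
    while True:
        delta = 0.0
        for s in range(1, goal):
            best = max(p_h * v[s + a] + (1 - p_h) * v[s - a]
                       for a in range(1, min(s, goal - s) + 1))
            delta = max(delta, abs(best - v[s]))
            v[s] = best
        if delta < theta:
            return v


def evaluate_policy(policy, p_h=P_HEADS, goal=GOAL, theta=THETA):
    """Iterative policy evaluation for a deterministic policy s -> stake."""
    v = [0.0] * (goal + 1)
    v[goal] = 1.0
    while True:
        delta = 0.0
        for s in range(1, goal):
            a = policy(s)
            new = p_h * v[s + a] + (1 - p_h) * v[s - a]
            delta = max(delta, abs(new - v[s]))
            v[s] = new
        if delta < theta:
            return v


def greedy_stakes(v, p_h=P_HEADS, goal=GOAL, decimals=9):
    """For s = 1..goal-1, the smallest stake whose rounded action value is maximal."""
    stakes = []
    for s in range(1, goal):
        q = [round(p_h * v[s + a] + (1 - p_h) * v[s - a], decimals)
             for a in range(1, min(s, goal - s) + 1)]
        stakes.append(q.index(max(q)) + 1)
    return stakes


if __name__ == "__main__":
    v_star = value_iteration()
    v_bold = evaluate_policy(lambda s: min(s, GOAL - s))
    gap = max(abs(x - y) for x, y in zip(v_star, v_bold))
    print(f"max |V*(s) - V_bold(s)| = {gap:.2e}")
    print("V*(50) =", round(v_star[50], 6))
    print("greedy stakes for s = 1..10:", greedy_stakes(v_star)[:10])
```

If bold play is optimal, the reported gap should be at the level of the convergence tolerance, while the greedy stakes can still differ from min(s, 100 - s) wherever several stakes tie.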
```matlab
clear
clc
%% Initialize
Q = zeros(101);          % Q(s + 1, a): value of staking a with capital s
V = zeros(1, 101);       % V(s + 1): value of capital s, s = 0..100
V(101) = 1;              % reaching 100 yields +1; V(1) (capital 0) stays 0
hp = 0.4;                % probability of heads
theta = 1e-9;            % convergence threshold
sweep = 0;
delta = inf;

%% Value Iteration (undiscounted, in place)
figure
hold on
grid on
while delta > theta
    delta = 0;
    for state = 1:99
        actions = 1:min(state, 100 - state);
        PossibleStateWin = state + actions + 1;   % indices into V after a win
        PossibleStateLose = state - actions + 1;  % indices into V after a loss
        Q(state + 1, actions) = hp*V(PossibleStateWin) + (1 - hp)*V(PossibleStateLose);
        newV = max(Q(state + 1, actions));
        % Softmax (Boltzmann) backup, an unused alternative to the max:
        % p = exp(Q(state + 1, actions)/0.02); p = p/sum(p);
        % newV = p*Q(state + 1, actions)';
        delta = max(delta, abs(newV - V(state + 1)));
        V(state + 1) = newV;
    end
    sweep = sweep + 1;
    if sweep < 10
        plot(0:100, V, 'LineWidth', 2)   % value estimates after the first 9 sweeps
    end
end
plot(0:100, V, 'k', 'LineWidth', 2)      % converged values
xlabel('Capital'); ylabel('Value estimate')

%% Greedy policy extracted from Q
Map = zeros(1, 99);
figure
for state = 1:99
    maxStake = min(state, 100 - state);
    % Row state + 1 holds capital state; rounding removes floating-point
    % noise so that near-ties resolve to the smallest stake
    [~, index] = max(round(Q(state + 1, 1:maxStake)*1e8)/1e8);
    Map(state) = index;
end
plot(1:99, Map, 'bo')
grid on
xlabel('Capital'); ylabel('Stake')

%% Test Part: simulate the greedy policy
numTrials = 1000;
count = zeros(1, 100);   % non-ruinous flips per starting capital, summed over trials
flag = zeros(1, 100);    % number of ruins per starting capital
FT = zeros(1, numTrials);
ST = zeros(1, numTrials);
for iter = 1:numTrials
    Mflag = zeros(1, 100);
    Mcount = zeros(1, 100);
    for state = 1:100
        capital = state;
        while capital < 100
            stake = Map(capital);
            %stake = min(capital, 100 - capital);   % bold play
            if rand < hp
                capital = capital + stake;
            else
                capital = capital - stake;
            end
            if capital <= 0
                flag(state) = flag(state) + 1;
                Mflag(state) = 1;
                break
            end
            count(state) = count(state) + 1;
            Mcount(state) = Mcount(state) + 1;
        end
    end
    FT(iter) = mean(Mflag);                % fraction of starting capitals ruined
    ST(iter) = mean(Mcount(Mflag == 0));   % mean flips among games that reached 100
end
figure
plot(1:100, 1 - flag/numTrials, 'bo')    % empirical win rate vs. starting capital
figure
plot(1:100, count/numTrials)             % average number of flips
mean(1 - FT)
mean(ST)
```
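The test part of the program can be reproduced in a few lines of Python. This sketch simulates bold play (the commented-out alternative `stake = min(capital, 100 - capital)` in the MATLAB code) rather than the extracted `Map`, and it counts every flip, including the one that ends in ruin; the function names and the seeded `random.Random` generator are assumptions for illustration.

```python
"""Monte Carlo check of the bold-play policy: estimate the win probability
and the mean number of flips from a given starting capital."""
import random

GOAL = 100
P_HEADS = 0.4


def bold_stake(capital, goal=GOAL):
    """Bold play: stake what is needed to reach the goal, or everything you have."""
    return min(capital, goal - capital)


def play_episode(start, policy, rng, p_h=P_HEADS, goal=GOAL):
    """Play one game; return (won, number_of_flips)."""
    capital, flips = start, 0
    while 0 < capital < goal:
        stake = policy(capital)
        capital += stake if rng.random() < p_h else -stake
        flips += 1
    return capital >= goal, flips


def estimate(start, policy, trials, seed=0):
    """Empirical win rate and mean number of flips over `trials` independent games."""
    rng = random.Random(seed)
    wins = total_flips = 0
    for _ in range(trials):
        won, flips = play_episode(start, policy, rng)
        wins += won
        total_flips += flips
    return wins / trials, total_flips / trials


if __name__ == "__main__":
    for start in (25, 50, 75):
        rate, flips = estimate(start, bold_stake, trials=20000)
        print(f"start={start:3d}  win rate={rate:.3f}  mean flips={flips:.2f}")
```

Under bold play the win probabilities from 25, 50 and 75 can be worked out by hand: 0.4 × 0.4 = 0.16, 0.4, and 0.4 + 0.6 × 0.4 = 0.64. The estimates should land close to these values, which gives a quick sanity check on the simulated win-rate plot.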