// 强化学习——价值迭代算法 悬崖漫步为例 C++
// Reinforcement learning: value iteration, illustrated on the Cliff Walking task (C++).

#include<bits/stdc++.h>
using namespace std;
// Maximum grid dimension supported by the fixed-size tables.
const int N = 100;
// Short alias for cliff_map (declared further down).
#define cliff cliff_map
// Grid dimensions (number of rows and columns), read from stdin in main().
int row = 0, col = 0;
// One deterministic transition of the grid model: the cell reached
// (next_i, next_j), the reward collected on the way, and flag == 1 when the
// transition is terminal (no future value is bootstrapped from next_*).
// A default-constructed record is all zeros.
struct State{
    int next_i,next_j,flag;
    double reward;
    State():next_i(0),next_j(0),flag(0),reward(0.0){}
};
// Policy table: pi[i][j][k] is the probability of choosing action k in cell (i, j).
double pi[N][N][4];
// Transition model: P[i][j][k] is the outcome of taking action k in cell (i, j).
State P[N][N][4];
// Grid map read from stdin (0 = cliff, 1 = ordinary cell, 2 = goal).
int cliff_map[N][N];
// Row/column offsets of the four actions: 0 = down, 1 = up, 2 = right, 3 = left.
int direction[4][2]={{1,0},{-1,0},{0,1},{0,-1}};
// theta: convergence threshold on the largest per-sweep change of V.
// The discount factor is stored as `discount_factor` because glibc's <math.h>
// (pulled in by <bits/stdc++.h>) declares an obsolete function named `gamma`
// when _GNU_SOURCE is defined, which g++ does by default. A global variable
// with that name is then a "redeclared as different kind of entity" error.
// The macro keeps every existing use of `gamma` in this file working,
// mirroring the `cliff` alias for cliff_map.
double theta,discount_factor;
#define gamma discount_factor
// V: current value estimate per cell; V_new: scratch table for the next sweep.
double V[N][N],V_new[N][N];
// Returns the cell reached from `a` by one step along direction[oper].
// The result is NOT clamped to the grid; callers clamp it separately.
pair<int,int> change_position(pair<int,int> a,int oper){
    const int *step=direction[oper];
    pair<int,int> moved=a;
    moved.first+=step[0];
    moved.second+=step[1];
    return moved;
}
// Clamps (i, j) in place to [0, row-1] x [0, col-1], so a move that would
// leave the grid keeps the agent on the edge cell instead.
void whether_valid(int &i,int &j){
    if(i<0)i=0;
    if(i>row-1)i=row-1;
    if(j<0)j=0;
    if(j>col-1)j=col-1;
}
// Overwrites transition record x with destination (i, j), reward re and
// terminal flag f.
void input_value(State &x,int i,int j,double re,int f){
    x.reward=re;
    x.next_j=j;
    x.next_i=i;
    x.flag=f;
}
// Builds the deterministic transition table P for the current row x col map.
// Cells whose map value is 0 or 2 are absorbing: every action stays in place
// with reward 0 and flag 1. From any other cell, action k moves one step along
// direction[k], clamped to the grid, and the target cell decides the outcome:
//   1          -> reward -1,   episode continues (flag 0)
//   0          -> reward -100, episode ends      (flag 1)
//   otherwise  -> reward -1,   episode ends      (flag 1)
void initialization_P(){
    for(int r=0;r<row;r++){
        for(int c=0;c<col;c++){
            const int here=cliff_map[r][c];
            const bool absorbing=(here==0||here==2);
            for(int a=0;a<4;a++){
                if(absorbing){
                    input_value(P[r][c][a],r,c,0,1);
                    continue;
                }
                pair<int,int> to=change_position(make_pair(r,c),a);
                whether_valid(to.first,to.second);
                double gain;
                int done;
                switch(cliff_map[to.first][to.second]){
                    case 1:  gain=-1;   done=0; break;
                    case 0:  gain=-100; done=1; break;
                    default: gain=-1;   done=1; break;
                }
                input_value(P[r][c][a],to.first,to.second,gain,done);
            }
        }
    }
}
void policy_initialization(){
    for(int i=0;i<row;i++)
        for(int j=0;j<col;j++){
            for(int k=0;k<4;k++)
                pi[i][j][k]=0;
            V[i][j]=0;
        }
}
// Zeroes the scratch value table V_new over the active row x col region.
void clear_vnew(){
    int r=0;
    while(r<row){
        int c=0;
        while(c<col){
            V_new[r][c]=0;
            ++c;
        }
        ++r;
    }
}
// Commits one sweep: copies V_new into V, then prints the updated V with one
// grid row per line (each value followed by a space) and a trailing blank line.
void copy(){
    for(int r=0;r<row;++r){
        for(int c=0;c<col;++c)
            V[r][c]=V_new[r][c];
        for(int c=0;c<col;++c)
            cout<<V[r][c]<<' ';
        cout<<'\n';
    }
    cout<<'\n';
}
void policy_evaluation(){
    policy_initialization();
    int cnt=0;
    while(1){
        double max_diff=-9999999999;
        clear_vnew();
        for(int i=0;i<row;i++){
            for(int j=0;j<col;j++){
                double MAX_QSA=-999999999;
                for(int k=0;k<4;k++){
                    MAX_QSA=max(MAX_QSA,P[i][j][k].reward+\
                    gamma*V[P[i][j][k].next_i][P[i][j][k].next_j]*\
                    (1-P[i][j][k].flag));
                }
                V_new[i][j]=MAX_QSA;
                max_diff=max(max_diff,fabs(V_new[i][j]-V[i][j]));
            }
        }
        copy();
    //    cout<<max_diff<<' ';
        if(max_diff<theta)break;
        cnt++;
    }
    cout<<"\n经过"<<cnt<<"次迭代处理"<<'\n';
}
void get_policy(){
    for(int i=0;i<row;i++){
        for(int j=0;j<col;j++){
            double MAX_QSA=-9999999;
            for(int k=0;k<4;k++){
                MAX_QSA=max(MAX_QSA,P[i][j][k].reward+\
                gamma*V[P[i][j][k].next_i][P[i][j][k].next_j]*\
                (1-P[i][j][k].flag));
            }
            int cntq=0;
            for(int k=0;k<4;k++){
                double QSA=P[i][j][k].reward+\
                gamma*V[P[i][j][k].next_i][P[i][j][k].next_j]*\
                (1-P[i][j][k].flag);
                if(QSA==MAX_QSA)cntq++;
            }
            for(int k=0;k<4;k++){
                double QSA=P[i][j][k].reward+\
                gamma*V[P[i][j][k].next_i][P[i][j][k].next_j]*\
                (1-P[i][j][k].flag);
                if(QSA==MAX_QSA)cntq++;
            }
            for(int k=0;k<4;k++){
                double QSA=P[i][j][k].reward+\
                gamma*V[P[i][j][k].next_i][P[i][j][k].next_j]*\
                (1-P[i][j][k].flag);
                if(QSA==MAX_QSA)pi[i][j][k]=1.0/cntq;
            }
        }
    }
}
void print_policy(){
    for(int i=0;i<row;i++){
        for(int j=0;j<col;j++){
            if(cliff_map[i][j]==0){
                cout<<"**** ";
                continue;
            }
            if(cliff_map[i][j]==2){
                cout<<"eeee ";
                continue;
            }
            if(pi[i][j][1])cout<<'^';else cout<<'o'; 
            if(pi[i][j][0])cout<<'v';else cout<<'o';
            if(pi[i][j][3])cout<<'<';else cout<<'o';
            if(pi[i][j][2])cout<<'>';else cout<<'o';
            cout<<' ';
        }
        cout<<endl;
    }
}
int main(){
    cout<<"row=";
    cin>>row;
    cout<<"col=";
    cin>>col;
    cout<<"map\n";
    for(int i=0;i<row;i++){
        for(int j=0;j<col;j++){
            cin>>cliff_map[i][j];
        }
    }
    theta=0.0001;gamma=0.9;
    initialization_P();
//    for(int i=0;i<row;i++){
//        for(int j=0;j<col;j++){
//            for(int k=0;k<4;k++)
//                cout<<P[i][j][k].reward;
//            cout<<" ";            
//        }
//        cout<<'\n';
//    }
    policy_evaluation();
    get_policy();
    print_policy();
}
/*
Sample input: row, col, then the row x col map
(0 = cliff, 1 = ordinary cell, 2 = goal).
6
6
1 1 1 1 1 1
1 1 0 1 2 1
1 0 1 1 1 0
1 1 0 1 1 1
1 0 0 0 2 1
1 1 1 1 0 0
*/

 

// posted @ 2022-09-07 10:02  saionjisekai  阅读(216)  评论(0)  编辑  收藏  举报