井字棋小游戏AI(蒙特卡洛树搜索)
刚把《强化学习》的第一部分写完,突发奇想,想写一个井字棋小游戏AI。算法采用蒙特卡洛树搜索(MCTS):树内策略(tree policy)使用UCT算法,树外的模拟策略(rollout policy)则是等概率随机落子。
代码:
// Tic-tac-toe AI using Monte Carlo Tree Search: UCT as the in-tree policy and
// uniformly random moves as the rollout policy.
//
// A board is encoded as a base-3 integer: digit i (row i / 3, column i % 3) is
// 0 for an empty cell, 1 for black ('b') and 2 for white ('w'). Black always
// moves first, so the side to move follows from the number of pieces.
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <limits>
#include <random>
#include <vector>

const int maxn = 20010;     // capacity of the per-state tables (must exceed kStates)
const int kStates = 19683;  // 3^9 possible board encodings
double UCT_C = 2.0;         // exploration weight in the UCT score

// Rollout statistics of one state: x = accumulated reward, y = accumulated weight.
struct node {
    double x, y;
    // Mean reward; NaN while y == 0, so callers check y first.
    double to_double(void) { return x / y; }
};

// All tables are zero-initialised globals indexed by encoded state.
node V[maxn];                 // statistics, shared by every path that reaches a state
std::vector<int> Next[maxn];  // legal successor states (left empty for terminal states)
std::vector<int> Tree[maxn];  // successors already expanded below a state, in Next order
bool ed[maxn];                // true for terminal states (a completed line, or a full board)
char table[5][5];             // scratch board shared by build(), rbuild() and print_table()

// Named `rng` rather than `random`: a global called `random` collides with the
// POSIX ::random() declared by <cstdlib> on glibc and fails to compile there.
std::mt19937 rng(static_cast<unsigned>(std::time(nullptr)));

// Number of pieces on the board in encoded state x.
int dep(int x) {
    int pieces = 0;
    for (int i = 0; i < 9; i++, x /= 3)
        if (x % 3 != 0) pieces++;
    return pieces;
}

// Encodes the current contents of `table` as a state.
int rbuild(void) {
    int res = 0;
    for (int i = 0, p = 1; i < 9; i++, p *= 3) {
        char c = table[i / 3][i % 3];
        int digit = (c == 0) ? 0 : (c == 'b' ? 1 : 2);
        res += p * digit;
    }
    return res;
}

// Decodes state st into `table`, overwriting its previous contents.
void build(int st) {
    for (int i = 0; i < 9; i++, st /= 3) {
        int digit = st % 3;
        table[i / 3][i % 3] = (digit == 0) ? '\0' : (digit == 1 ? 'b' : 'w');
    }
}

// All states reachable from st in one move by the side to move.
// Overwrites `table`.
std::vector<int> find_next(int st) {
    build(st);
    // An even piece count means black (digit 1) moves next, otherwise white (digit 2).
    int piece = (dep(st) % 2) + 1;
    std::vector<int> ret;
    for (int i = 0, p = 1; i < 9; i++, p *= 3)
        if (table[i / 3][i % 3] == 0) ret.push_back(st + p * piece);
    return ret;
}

// True if `who` owns a full row, column or diagonal in state st. Overwrites `table`.
static bool has_line(int st, char who) {
    build(st);
    for (int i = 0; i < 3; i++) {
        if (table[i][0] == who && table[i][1] == who && table[i][2] == who) return true;
        if (table[0][i] == who && table[1][i] == who && table[2][i] == who) return true;
    }
    if (table[0][0] == who && table[1][1] == who && table[2][2] == who) return true;
    return table[2][0] == who && table[1][1] == who && table[0][2] == who;
}

// True if white has three in a row in state st.
bool lose(int st) { return has_line(st, 'w'); }

// True if black has three in a row in state st.
bool vectory(int st) { return has_line(st, 'b'); }

// Plays uniformly random moves from state x until the game ends.
// Returns 1 if black wins, 2 if white wins, 3 for a draw. Requires init().
int dfs(int x) {
    while (!ed[x]) x = Next[x][rng() % Next[x].size()];
    if (vectory(x)) return 1;
    if (lose(x)) return 2;
    return 3;
}

// UCT score of child x whose parent has accumulated weight tot.
double UCT(int x, double tot) {
    return V[x].to_double() + UCT_C * std::sqrt(std::log(tot) / V[x].y);
}

// Runs one MCTS iteration (select, expand, roll out, back up) from `root`.
// flag is 0 when black is to move at root and 1 when white is. The reward
// stored in a node is from the viewpoint of the player who moved into it.
void MCTS(int root, int flag) {
    std::vector<int> path{root};
    int now = root;

    // Selection: descend while every child is expanded, taking the best UCT score.
    while (!ed[now] && Tree[now].size() == Next[now].size()) {
        double best = -std::numeric_limits<double>::infinity();
        int best_child = Tree[now].front();  // non-terminal => at least one child
        for (int child : Tree[now]) {
            double score = UCT(child, V[now].y);
            if (score > best) {
                best = score;
                best_child = child;
            }
        }
        now = best_child;
        flag ^= 1;
        path.push_back(now);
    }

    // Expansion: add the next unexpanded child. Its statistics are deliberately
    // not reset, so results gathered through other move orders are kept.
    if (!ed[now]) {
        int child = Next[now][Tree[now].size()];
        Tree[now].push_back(child);
        flag ^= 1;
        path.push_back(child);
        now = child;
    }

    int res = dfs(now);

    // Backup from leaf to root. Weights are scaled by 2 so that a draw earns
    // half credit (1 of 2) and is preferred less than a win (2 of 2).
    for (auto it = path.rbegin(); it != path.rend(); ++it) {
        node &stat = V[*it];
        if (res == 3)
            stat.x += 1;
        else if ((res == 1 && flag) || (res == 2 && !flag))
            stat.x += 2;
        stat.y += 2;
        flag ^= 1;
    }
}

// Searches from `root` for `iterations` MCTS iterations and returns the
// successor state with the best mean reward for the side to move, or -1 if no
// successor has statistics (e.g. root is terminal or iterations <= 0).
int solve(int root, bool flag, int iterations = 500) {
    for (int i = 0; i < iterations; i++) MCTS(root, flag);
    int res = -1;
    double mx = -1;
    for (int x : Next[root]) {
        if (V[x].y > 0 && V[x].to_double() > mx) {
            mx = V[x].to_double();
            res = x;
        }
    }
    return res;
}

// Precomputes terminal flags and successor lists for every encoded state.
// Must run before any search.
void init() {
    for (int i = 0; i < kStates; i++) {
        bool over = vectory(i) || lose(i) || dep(i) == 9;
        if (over)
            ed[i] = true;
        else
            Next[i] = find_next(i);
    }
}

// Prints the move prompt followed by the board held in `table`.
void print_table() {
    std::printf("请落子(比如 0 0):\n");
    std::printf("----------\n");
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            if (table[i][j] == 0)
                std::printf(" ");
            else
                std::printf("%c", table[i][j]);
            if (j < 2) std::printf("-");
        }
        if (i < 2) {
            std::printf("\n");
            for (int j = 0; j < 5; j++) std::printf(j % 2 == 0 ? "|" : " ");
        }
        std::printf("\n");
    }
    std::printf("----------\n");
}

// Reads the human's move into (x, y), re-prompting until it names an empty
// cell of the 3x3 board held in `table`. Returns false once stdin is exhausted.
static bool read_move(int &x, int &y) {
    for (;;) {
        int got = std::scanf("%d %d", &x, &y);
        if (got == EOF) return false;
        if (got == 2 && x >= 0 && x < 3 && y >= 0 && y < 3 && table[x][y] == 0) return true;
        if (got != 2) {
            // Discard the rest of a malformed line so scanf does not spin on it.
            int c;
            while ((c = std::getchar()) != '\n' && c != EOF) {
            }
        }
        std::printf("输入无效,请重新落子:\n");
    }
}

// Plays 10 rounds between the human (stdin) and the AI, then prints the tally.
// Stops early if stdin is exhausted.
void play() {
    int e = 0, l = 0, a = 0;  // draws, player wins, AI wins
    for (int round_no = 1; round_no <= 10; round_no++) {
        std::printf("第%d回合:\n", round_no);
        std::memset(table, 0, sizeof(table));
        int p = 0;
        std::printf("请决定执黑还是执白:\n0: 黑棋; 1: 白棋\n");
        if (std::scanf("%d", &p) != 1) break;
        const bool player_black = (p == 0);  // any non-zero choice means white

        int s = 0;
        bool flag = false;  // false: black to move, true: white to move
        bool aborted = false;
        print_table();
        while (!ed[s]) {
            int x, y;

            // Black's move.
            if (player_black) {
                if (!read_move(x, y)) {
                    aborted = true;
                    break;
                }
                table[x][y] = 'b';
                s = rbuild();
            } else {
                s = solve(s, flag);
                build(s);  // the search clobbers `table`; restore the real board
            }
            print_table();
            if (ed[s]) {
                if (vectory(s)) {
                    std::printf("黑方胜利!\n");
                    if (player_black) l++; else a++;
                } else {
                    std::printf("平局!\n");
                    e++;
                }
                break;
            }
            flag = !flag;

            // White's move.
            if (!player_black) {
                if (!read_move(x, y)) {
                    aborted = true;
                    break;
                }
                table[x][y] = 'w';
                s = rbuild();
            } else {
                s = solve(s, flag);
                build(s);
            }
            print_table();
            if (ed[s]) {
                if (lose(s)) {
                    std::printf("白方胜利!\n");
                    if (player_black) a++; else l++;
                } else {
                    std::printf("平局!\n");
                    e++;
                }
                break;
            }
            flag = !flag;
        }
        if (aborted) break;
    }
    std::printf("AI wins: %d\nplayer wins: %d\nequals: %d\n", a, l, e);
}

int main() {
    init();
    play();
}
每步的搜索迭代次数取500时,AI基本就能走出最优解了,可见MCTS比暴力搜索高效得多。