井字棋小游戏AI(蒙特卡洛树搜索)

刚把《强化学习》的第一部分写完,突发奇想,想写一个井字棋小游戏 AI:采用 MCTS(蒙特卡洛树搜索)算法,树内策略使用 UCT 算法,树外(模拟阶段)策略使用等概率随机落子。

代码:

#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-state array below. Board codes are 9 base-3 digits,
// so init() enumerates 3^9 = 19683 of them; 20010 leaves a little headroom.
const int maxn = 20010;
// Multiplier of the square-root exploration term in UCT().
double UCT_C = 2.0;
// Reward/visit accumulator kept for one board state (see the global V[]).
struct node {
	double x;  // accumulated reward credited by MCTS()
	double y;  // accumulated visit weight added by MCTS()

	// Mean reward x / y. There is no guard here: the caller must make sure
	// y is non-zero if it needs a finite result.
	double to_double(void) {
		const double mean = x / y;
		return mean;
	}

	// Clears both accumulators.
	void init() {
		x = y = 0;
	}

};
// V[s]: search statistics of board state s, indexed by its base-3 code.
node V[maxn];
// NOTE(review): not referenced anywhere in the visible source.
double eps = 1e-10;
// Next[s]: every successor of non-terminal state s; filled once by init().
vector<int> Next[maxn];
// Tree[s]: the successors of s that MCTS() has expanded so far. They are
// appended in the same order as Next[s], so Tree[s].size() is also the index
// of the next successor to expand.
vector<int> Tree[maxn];
// ed[s]: set by init() when s is terminal (a line is complete or the board is full).
bool ed[maxn];
// Scratch board shared by build()/rbuild() and the win checks:
// 0 = empty, 'b' = black, 'w' = white. Only rows/columns 0..2 are encoded.
char table[5][5];
// Mersenne Twister seeded from the wall clock; dfs() draws rollout moves from it.
// NOTE(review): a global named `random` may clash with the POSIX random()
// function declared by <stdlib.h> on some platforms — confirm on the target toolchain.
mt19937 random(time(0));

// Returns how many of the nine base-3 digits of board code `x` are non-zero,
// i.e. the number of stones already on the board.
int dep(int x) {
	int occupied = 0;
	int remaining = 9;
	while (remaining-- > 0) {
		if (x % 3) {
			++occupied;
		}
		x /= 3;
	}
	return occupied;
}
// Encodes the 3x3 part of the global `table` as a base-3 integer.
// Cell i (row i / 3, column i % 3) contributes digit * 3^i, where the digit is
// 0 for an empty cell, 1 for 'b' and 2 for anything else.
int rbuild(void) {
	int code = 0;
	// Horner evaluation from the most significant cell down to cell 0.
	for (int cell = 8; cell >= 0; --cell) {
		const char mark = table[cell / 3][cell % 3];
		int digit;
		if (mark == 0) {
			digit = 0;
		} else if (mark == 'b') {
			digit = 1;
		} else {
			digit = 2;
		}
		code = code * 3 + digit;
	}
	return code;
}

// Decodes board code `st` into the 3x3 part of the global `table`
// (the inverse of rbuild()): digit 0 -> empty (0), 1 -> 'b', otherwise 'w'.
// Cells are consumed least-significant digit first, in row-major order.
void build(int st) {
	for (int row = 0; row < 3; ++row) {
		for (int col = 0; col < 3; ++col) {
			switch (st % 3) {
				case 0:
					table[row][col] = 0;
					break;
				case 1:
					table[row][col] = 'b';
					break;
				default:
					table[row][col] = 'w';
					break;
			}
			st /= 3;
		}
	}
}

// Lists every state reachable from board code `x` by placing one stone.
// The stone is black (digit 1) when an even number of cells is occupied and
// white (digit 2) otherwise. Successors are returned in row-major cell order.
// Side effect: leaves the global `table` holding the decoded board of `x`.
vector<int> find_next(int x) {
	build(x);
	const int mark = dep(x) % 2 + 1;
	vector<int> successors;
	int weight = 1;  // 3^cell for the cell currently inspected
	for (int row = 0; row < 3; ++row) {
		for (int col = 0; col < 3; ++col, weight *= 3) {
			if (table[row][col] != 0) {
				continue;
			}
			successors.push_back(x + weight * mark);
		}
	}
	return successors;
}

// Returns true when white ('w') owns a full row, column or diagonal in board
// code `st`. Side effect: decodes `st` into the global `table`.
bool lose(int st) {
	build(st);
	const char stone = 'w';
	// True when the three given cells all hold `stone`.
	auto owns = [stone](int r0, int c0, int r1, int c1, int r2, int c2) {
		return table[r0][c0] == stone && table[r1][c1] == stone && table[r2][c2] == stone;
	};
	for (int k = 0; k < 3; ++k) {
		if (owns(k, 0, k, 1, k, 2) || owns(0, k, 1, k, 2, k)) {
			return true;
		}
	}
	return owns(0, 0, 1, 1, 2, 2) || owns(2, 0, 1, 1, 0, 2);
}

// Returns true when black ('b') owns a full row, column or diagonal in board
// code `st`. Side effect: decodes `st` into the global `table`.
bool vectory(int st) {
	build(st);
	const char stone = 'b';
	// True when the three given cells all hold `stone`.
	auto owns = [stone](int r0, int c0, int r1, int c1, int r2, int c2) {
		return table[r0][c0] == stone && table[r1][c1] == stone && table[r2][c2] == stone;
	};
	for (int k = 0; k < 3; ++k) {
		if (owns(k, 0, k, 1, k, 2) || owns(0, k, 1, k, 2, k)) {
			return true;
		}
	}
	return owns(0, 0, 1, 1, 2, 2) || owns(2, 0, 1, 1, 0, 2);
}

int dfs(int x) {
	if(ed[x]) {
		if(vectory(x)) {
			return 1;
		}
		else if(lose(x)) {
			return 2;
		}
		else return 3;
	}
	int t = random() % Next[x].size();
	return dfs(Next[x][t]);
}

double UCT(int x, double tot) {
	return V[x].to_double() + UCT_C * sqrt(log(tot) / V[x].y);
}

void MCTS(int root, int flag) {
	int now = root;
	stack<int> path;
	path.push(now);
	while(!ed[now] && Tree[now].size() == Next[now].size()) {
		double mx = 0;
		int mx_pos = 0;
		for (auto t : Tree[now]) {
			if(UCT(t, V[now].y) > mx) {
				mx = UCT(t, V[now].y);
				mx_pos = t;
			}
		}
		now = mx_pos;
		flag ^= 1;
		path.push(now);
	}
	if(!ed[now]) {
		int x = Next[now][Tree[now].size()];
		Tree[now].push_back(x);
		flag ^= 1;
		V[x].init();
		path.push(x);
		now = x;
	}
	int res = dfs(now);
	while(path.size()) {
		now = path.top();
		path.pop();
		if(res == 3) {
			V[now].x += 2;
			V[now].y += 2;
		}
		else if((res == 1 && flag) || (res == 2 && flag == 0)) {
			V[now].x += 2;
			V[now].y += 2;
		} else {
			V[now].y += 2;
		}
		flag ^= 1;
	}
}

int solve(int root, bool flag) {
	for (int i = 1; i <= 500; i++) {
		MCTS(root, flag); 
	}
	int res = -1;
	double mx = -1;
	for (auto x : Next[root]) {
		if(V[x].to_double() > mx) {
			mx = V[x].to_double();
			res = x;
		}
	}
	return res;
}

void init() {
	int tmp;
	for (int i = 0; i < 19683; i++) {
		bool x1 = vectory(i), x2 = lose(i);
		tmp = 9 - dep(i);
		if(x1 || x2 || (tmp == 0))
			ed[i] = true;
		else {
			Next[i] = find_next(i);
		}
	}
}

// Prints the move prompt followed by an ASCII rendering of the global `table`:
// cells separated by '-', rows separated by a "| | |" line, framed by dashes.
void print_table() {
	printf("请落子(比如 0 0):\n");
	printf("----------\n");
	for (int row = 0; row < 3; ++row) {
		for (int col = 0; col < 3; ++col) {
			const char mark = table[row][col];
			printf("%c", mark == 0 ? ' ' : mark);  // empty cells show as a blank
			if (col != 2) {
				printf("-");
			}
		}
		printf("\n");
		if (row != 2) {
			printf("| | |\n");
		}
	}
	printf("----------\n");
}

void play() {
	int s = 0;
	int e = 0, l = 0, a = 0;
	bool flag;
	int round = 0;
	int T = 10;
	while(T--) {
		round++;
		printf("第%d回合:\n", round);
//		printf("----------\n\n\n\n\nround %d\n\n\n\n--------\n", round);
		memset(table, 0, sizeof(table));
		int p = 0;
		printf("请决定执黑还是执白:\n0: 黑棋; 1: 白棋\n");
		scanf("%d", &p);
//		print_table();
		s = 0;
		flag = 0;
		print_table();
		while(!ed[s]) {
			int x, y;
			if(p == 0) {
				scanf("%d %d", &x, &y);
				table[x][y] = 'b';
				s = rbuild();
			} else {
				s = solve(s, flag);
				build(s);
			}
			print_table();
			if(ed[s]) {
				if(vectory(s)) {
					printf("黑方胜利!\n");
					a++;
				}
				else {
					printf("平局!\n");
					e++;
				}	
				break;
			}
			flag ^= 1;
			if(p) {
				scanf("%d %d", &x, &y);
				table[x][y] = 'w';
				s = rbuild();
			} else {
				s = solve(s, flag);
				build(s);
			}
			print_table();
			if(ed[s]) {
				if(lose(s)) {
					printf("白方胜利!\n");
					l++;
				}
				else {
					printf("平局!\n");
					e++;
				}
				break;
			}
			flag ^= 1;
		}
	}
	printf("AI wins: %d\nplayer wins: %d\nequals: %d\n", a, l, e);
}

// Entry point: seeds the C library generator, precomputes the state graph,
// then runs the interactive session.
int main() {
	const time_t now = time(0);
	srand(now);
	init();
	play();
	return 0;
}

  每步只做 500 次 MCTS 迭代时,就已经基本能走出最优解了,可见 MCTS 比暴力搜索高效得多。

posted @ 2021-01-08 00:09  维和战艇机  阅读(496)  评论(0)  编辑  收藏  举报