golang蒙特卡洛树算法实现五子棋AI
已经实现蒙特卡洛树算法的通用逻辑,只需要对应结构体实现相关接口就可以直接使用该算法。
优化算法主要优化GetActions
生成下一步动作,要尽可能少,去掉无意义的动作。
以及优化ActionPolicy
从众多动作挑选比较优秀的动作。对应五子棋就是执行该动作后当前局面评分最高。
package main
import (
"fmt"
"math"
"math/rand"
"strings"
"time"
)
func main() {
var (
board = NewQuZiQi(15)
x, y int
)
board.Print()
for board.IsTerminal() == 0 {
board = Search(time.Second*10, board).(*WuZiQi)
board.Print()
if board.IsTerminal() == 1 {
fmt.Println("电脑赢了")
return
}
for {
fmt.Print("轮到您执棋,请输入坐标: ")
_, _ = fmt.Scanln(&x, &y)
x--
y--
if x < 0 || y < 0 || x >= board.size || y >= board.size {
fmt.Println("您输入的数据超出棋盘范围")
} else if board.board[x][y] > 0 {
fmt.Println("该位置已有棋子")
} else {
board.board[x][y] = 2
board.player = 1 // 下一步该电脑下
break
}
}
board.Print()
if board.IsTerminal() == 2 {
fmt.Println("你赢了")
return
}
}
}
// WuZiQi 五子棋游戏
type WuZiQi struct {
size int // 棋盘大小
board [][]int // 棋盘状态
player int // 1: 电脑落子,2: 玩家落子
}
func NewQuZiQi(size int) *WuZiQi {
w := &WuZiQi{
size: size,
board: make([][]int, size),
player: 1,
}
for i := 0; i < size; i++ {
w.board[i] = make([]int, size)
}
size /= 2
// 默认中间落一个棋子
// 0: 表示没有落子,1: 表示电脑,2: 表示玩家
w.board[size][size] = 2
return w
}
func (w *WuZiQi) Print() {
var (
str strings.Builder
num = func(n int) {
a, b := n/10, n%10
if a > 0 {
str.WriteByte(byte(a + '0'))
} else {
str.WriteByte(' ') // 1位数前面加空格
}
str.WriteByte(byte(b + '0'))
}
)
str.WriteString(" ")
for i := 1; i <= w.size; i++ {
str.WriteByte(' ')
num(i)
}
str.WriteByte('\n')
for i := 0; i < w.size; i++ {
str.WriteString(" ")
for j := 0; j < w.size; j++ {
str.WriteString(" __")
}
str.WriteByte('\n')
num(i + 1)
str.WriteByte(' ')
for j := 0; j < w.size; j++ {
str.WriteByte('|')
switch w.board[i][j] {
case 0:
str.WriteByte(' ')
case 1:
str.WriteByte('O')
case 2:
str.WriteByte('X')
}
str.WriteByte(' ')
}
str.WriteString("|\n")
}
str.WriteString(" ")
for i := 0; i < w.size; i++ {
str.WriteString(" __")
}
fmt.Println(str.String())
}
func (w *WuZiQi) IsTerminal() int {
full := -1 // 没有空位且都没赢
for i := 0; i < w.size; i++ {
for j := 0; j < w.size; j++ {
if wc := w.board[i][j]; wc == 0 {
full = 0 // 还有空位,没结束
} else {
// 向右
cnt, x, y := 1, 0, j+1
for ; y < w.size && w.board[i][y] == wc; y++ {
cnt++
}
if cnt >= 5 {
return wc
}
// 向下
cnt, x = 1, i+1
for ; x < w.size && w.board[x][j] == wc; x++ {
cnt++
}
if cnt >= 5 {
return wc
}
// 向右下
cnt, x, y = 1, i+1, j+1
for ; x < w.size && y < w.size && w.board[x][y] == wc; x, y = x+1, y+1 {
cnt++
}
if cnt >= 5 {
return wc
}
// 向左下
cnt, x, y = 1, i+1, j-1
for ; x < w.size && y >= 0 && w.board[x][y] == wc; x, y = x+1, y-1 {
cnt++
}
if cnt >= 5 {
return wc
}
}
}
}
return full
}
func (w *WuZiQi) Result(state int) float64 {
switch state {
case -1:
return 0 // 都没赢且没空位
case 1:
return -1 // 电脑赢了
case 2:
return +1 // 玩家赢了
default:
return 0 // 都没赢且有空位
}
}
func (w *WuZiQi) GetActions() (res []any) {
// todo 敌方上一步落子附近才是最优搜索范围
// 某个落子必胜,则直接落子,如果某个落子让对手所有落子都必败则直接落子
// 因此后续动作进一步缩小范围
// 可以使用hash判断棋盘状态
m := map[[2]int]struct{}{} // 用于去重
for i := 0; i < w.size; i++ {
for j := 0; j < w.size; j++ {
if w.board[i][j] == 0 || w.board[i][j] == w.player {
continue // 跳过空位和己方棋子
}
x0, x1, y0, y1 := i-2, i+2, j-2, j+2
for ii := x0; ii < x1; ii++ {
for jj := y0; jj < y1; jj++ {
if ii >= 0 && jj >= 0 && ii < w.size && jj < w.size &&
w.board[ii][jj] == 0 {
p := [2]int{ii, jj}
_, ok := m[p]
if !ok {
// 在棋子周围2格范围的空位加到结果中
// 超过2格的空位落子的意义不大
res = append(res, p)
m[p] = struct{}{}
}
}
}
}
}
}
return
}
func (w *WuZiQi) ActionPolicy(action []any) any {
// 目前随机选一个动作,应该是好方案先选出来
return action[rand.Intn(len(action))]
}
func (w *WuZiQi) Action(action any) TreeState {
wn := &WuZiQi{
size: w.size,
board: make([][]int, w.size),
player: 3 - w.player, // 切换电脑和玩家
}
for i := 0; i < w.size; i++ {
wn.board[i] = make([]int, w.size)
for j := 0; j < w.size; j++ {
wn.board[i][j] = w.board[i][j]
}
}
ac := action.([2]int) // 在该位置落子
wn.board[ac[0]][ac[1]] = w.player
return wn
}
// MonteCarloTree 下面是算法部分
// 你的对象只需要提供TreeState所有接口,就可以直接使用
// https://github.com/int8/monte-carlo-tree-search
// https://blog.csdn.net/masterhero666/article/details/126325506
type (
TreeState interface {
IsTerminal() int // 0: 未结束,其他为自定义状态
Result(int) float64 // 计算分数,传入IsTerminal结果
GetActions() []any // 获取所有合法动作, todo 考虑获取不到动作时如何处理
ActionPolicy([]any) any // 按策略挑选一个动作
Action(any) TreeState // 执行动作生成子节点
}
McTreeNode struct {
parent *McTreeNode
children []*McTreeNode
score float64
visitCount float64
untriedActions []any
nodeState TreeState
}
)
func Search(simulate any, state TreeState, discount ...float64) TreeState {
var (
root = &McTreeNode{nodeState: state}
leaf *McTreeNode
dp = 1.4 // 折扣参数默认值
)
if len(discount) > 0 {
dp = discount[0]
}
var loop func() bool
switch s := simulate.(type) {
case int:
loop = func() bool {
s-- // 模拟指定次数后退出
return s >= 0
}
case time.Duration:
ts := time.Now().Add(s) // 超过指定时间后退出
loop = func() bool { return time.Now().Before(ts) }
case func() bool:
loop = s // 或者由外部指定模拟结束方案
default:
panic(simulate)
}
for loop() {
leaf = root.treePolicy(dp)
result, curState := 0, leaf.nodeState
for {
if result = curState.IsTerminal(); result != 0 {
break // 结束状态
}
// 根据该节点状态生成所有合法动作
all := curState.GetActions()
// 按照某种策略选出1个动作,不同于expand的顺序取出
one := curState.ActionPolicy(all)
// 执行该动作,重复该过程,直到结束
curState = curState.Action(one)
}
// 根据结束状态计算结果,将该结果反向传播
leaf.backPropagate(curState.Result(result))
}
return root.chooseBestChild(dp).nodeState // 选择最优子节点
}
func (cur *McTreeNode) chooseBestChild(c float64) *McTreeNode {
var (
idx = 0
maxValue = -math.MaxFloat64
childValue float64
)
for i, child := range cur.children {
childValue = (child.score / child.visitCount) +
c*math.Sqrt(math.Log(cur.visitCount)/child.visitCount)
if childValue > maxValue {
maxValue = childValue
idx = i // 选择分值最高的子节点
}
}
return cur.children[idx]
}
func (cur *McTreeNode) backPropagate(result float64) {
nodeCursor := cur
for nodeCursor.parent != nil {
nodeCursor.score += result
nodeCursor.visitCount++ // 反向传播,增加访问次数,更新分数
nodeCursor = nodeCursor.parent
}
nodeCursor.visitCount++
}
func (cur *McTreeNode) expand() *McTreeNode {
res := cur.untriedActions[0] // 返回1个未经尝试动作
cur.untriedActions = cur.untriedActions[1:]
child := &McTreeNode{
parent: cur, // 当前节点按顺序弹出1个动作,执行动作生成子节点
nodeState: cur.nodeState.Action(res),
}
cur.children = append(cur.children, child)
return child
}
func (cur *McTreeNode) treePolicy(discountParamC float64) *McTreeNode {
nodeCursor := cur // 一直循环直到结束
for nodeCursor.nodeState.IsTerminal() == 0 {
if nodeCursor.untriedActions == nil {
// 只会初始化1次,找出该节点所有动作
nodeCursor.untriedActions = nodeCursor.nodeState.GetActions()
}
if len(nodeCursor.untriedActions) > 0 {
return nodeCursor.expand() // 存在未处理动作则添加子节点
}
// 处理完动作,选择最好子节点继续往下处理
nodeCursor = nodeCursor.chooseBestChild(discountParamC)
}
return nodeCursor
}