MCTS(蒙特卡洛树搜索)

蒙特卡洛树搜索(Monte Carlo Tree Search,MCTS)是一种在决策过程中进行最优选择的算法,尤其在博弈类游戏和具有复杂状态空间的问题中表现出色。

基本概念

蒙特卡洛树搜索结合了蒙特卡洛方法的随机采样特性和树搜索的结构,用于在大规模的状态空间中寻找最优策略。它通过模拟大量的随机游戏来评估每个可能的行动,随着模拟次数的增加,算法能够逐渐收敛到最优解。其核心思想是在有限的时间内,通过对状态空间的部分采样和评估,构建一棵搜索树,从而找到当前状态下的最优行动。

算法步骤

蒙特卡洛树搜索主要包含四个阶段,并且会不断循环这四个阶段,直到达到停止条件(如达到最大模拟次数或时间限制)。

1. 选择(Selection)

从搜索树的根节点开始,根据一定的策略(如 UCB1 公式)选择子节点,直到到达一个未被完全扩展的节点。UCB1 公式如下: UCBi(t)=X¯i(t)+c2lntNi(t) 其中,X¯i(t) 是第 i 个节点的平均奖励,c 是一个超参数,用于控制探索和利用的平衡,t 是当前的总模拟次数,Ni(t) 是第 i 个节点的访问次数。

该阶段的目的是在已经探索过的节点中,选择最有潜力的路径进行进一步探索。

2. 扩展(Expansion)

当到达一个未被完全扩展的节点时,随机选择一个未被访问过的子节点并将其添加到搜索树中。

这个阶段增加了搜索树的规模,使得算法能够探索更多的可能行动。

3. 模拟(Simulation)

从新扩展的节点开始,进行一次随机模拟游戏,直到游戏结束。在模拟过程中,每个行动都是随机选择的。

通过模拟游戏,得到一个最终的奖励结果(如胜利、失败或平局对应的奖励值)。

4. 反向传播(Backpropagation)

将模拟得到的奖励结果沿着从新扩展的节点到根节点的路径进行反向传播,更新路径上每个节点的访问次数和平均奖励。

访问次数加 1,平均奖励根据新的奖励结果进行更新,使得搜索树中的节点信息能够反映出该节点的潜在价值。 

 

复制代码
import math
import random

# 定义井字棋游戏状态类
class TicTacToeState:
    def __init__(self, board=None, player=1):
        # 如果没有传入棋盘状态,初始化为全 0 的 9 个元素的列表,代表空棋盘
        if board is None:
            self.board = [0] * 9
        else:
            # 复制传入的棋盘状态,避免修改原始对象
            self.board = board.copy()
        # 当前玩家,1 或 2
        self.player = player

    def get_possible_actions(self):
        # 遍历棋盘,找出值为 0 的位置,这些位置表示可以落子的地方
        return [i for i, val in enumerate(self.board) if val == 0]

    def take_action(self, action):
        # 复制当前棋盘状态,避免修改原始棋盘
        new_board = self.board.copy()
        # 在指定位置落子,落子的值为当前玩家编号
        new_board[action] = self.player
        # 切换玩家,1 变为 2,2 变为 1
        new_player = 3 - self.player
        # 返回新的游戏状态对象
        return TicTacToeState(new_board, new_player)

    def is_terminal(self):
        # 定义所有可能的获胜组合,包括行、列和对角线
        winning_combinations = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],  #
            [0, 3, 6], [1, 4, 7], [2, 5, 8],  #
            [0, 4, 8], [2, 4, 6]  # 对角线
        ]
        # 遍历所有获胜组合
        for combination in winning_combinations:
            # 如果某一获胜组合的三个位置的值相同且不为 0,说明有玩家获胜,游戏结束
            if self.board[combination[0]] == self.board[combination[1]] == self.board[combination[2]] != 0:
                return True
        # 如果棋盘上没有 0 了,说明棋盘已满,游戏结束
        if 0 not in self.board:
            return True
        # 否则游戏未结束
        return False

    def get_reward(self):
        # 定义所有可能的获胜组合,包括行、列和对角线
        winning_combinations = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],  #
            [0, 3, 6], [1, 4, 7], [2, 5, 8],  #
            [0, 4, 8], [2, 4, 6]  # 对角线
        ]
        # 遍历所有获胜组合
        for combination in winning_combinations:
            # 如果玩家 1 在某一获胜组合的三个位置都落子,玩家 1 获胜,返回 1
            if self.board[combination[0]] == self.board[combination[1]] == self.board[combination[2]] == 1:
                return 1
            # 如果玩家 2 在某一获胜组合的三个位置都落子,玩家 2 获胜,返回 -1
            elif self.board[combination[0]] == self.board[combination[1]] == self.board[combination[2]] == 2:
                return -1
        # 平局或未结束游戏,返回 0
        return 0


# 定义蒙特卡洛树搜索节点类
class MCTSNode:
    def __init__(self, state, parent=None):
        # 当前节点对应的游戏状态
        self.state = state
        # 当前节点的父节点
        self.parent = parent
        # 当前节点的子节点列表
        self.children = []
        # 当前节点的访问次数
        self.visits = 0
        # 当前节点累计的奖励值
        self.reward = 0
        # 当前状态下还未尝试过的行动列表
        self.untried_actions = state.get_possible_actions()

    def expand(self):
        # 从还未尝试过的行动中取出一个
        action = self.untried_actions.pop()
        # 执行该行动,得到新的游戏状态
        next_state = self.state.take_action(action)
        # 创建新的节点,其父节点为当前节点
        child_node = MCTSNode(next_state, self)
        # 将新节点添加到当前节点的子节点列表中
        self.children.append(child_node)
        # 返回新创建的子节点
        return child_node

    def is_fully_expanded(self):
        # 判断是否还有未尝试过的行动,如果没有则表示该节点已完全扩展
        return len(self.untried_actions) == 0

    def is_terminal_node(self):
        # 判断当前节点对应的游戏状态是否为终止状态
        return self.state.is_terminal()

    def rollout(self):
        # 获取当前节点的游戏状态
        current_state = self.state
        # 当游戏未结束时,继续模拟
        while not current_state.is_terminal():
            # 获取当前状态下所有可能的行动
            possible_actions = current_state.get_possible_actions()
            # 随机选择一个行动
            action = random.choice(possible_actions)
            # 执行该行动,得到新的游戏状态
            current_state = current_state.take_action(action)
        # 返回最终游戏状态的奖励值
        return current_state.get_reward()

    def backpropagate(self, result):
        # 当前节点的访问次数加 1
        self.visits += 1
        # 当前节点的累计奖励值加上本次模拟的结果
        self.reward += result
        # 如果当前节点有父节点,将结果反向传播给父节点
        if self.parent:
            self.parent.backpropagate(result)

    def best_child(self, c_param=1.4):
        # 计算每个子节点的 UCB 值
        choices_weights = [
            (child.reward / child.visits) + c_param * math.sqrt((2 * math.log(self.visits) / child.visits))
            for child in self.children
        ]
        # 找到 UCB 值最大的子节点的索引
        index = choices_weights.index(max(choices_weights))
        # 返回 UCB 值最大的子节点
        return self.children[index]


# 定义蒙特卡洛树搜索类
class MCTS:
    def __init__(self, state):
        # 创建根节点,对应初始游戏状态
        self.root = MCTSNode(state)

    def search(self, num_simulations):
        # 进行指定次数的模拟
        for _ in range(num_simulations):
            # 选择一个节点进行扩展或模拟
            node = self.select_node()
            # 如果该节点不是终止节点
            if not node.is_terminal_node():
                # 对该节点进行扩展,创建一个新的子节点
                node = node.expand()
            # 从该节点开始进行模拟,得到模拟结果
            reward = node.rollout()
            # 将模拟结果反向传播到根节点
            node.backpropagate(reward)
        # 返回根节点 UCB 值最大的子节点,即当前认为的最优行动对应的节点
        return self.root.best_child(c_param=0)

    def select_node(self):
        # 从根节点开始选择节点
        current_node = self.root
        # 当当前节点已完全扩展且不是终止节点时,继续选择子节点
        while current_node.is_fully_expanded() and not current_node.is_terminal_node():
            current_node = current_node.best_child()
        # 返回最终选择的节点
        return current_node


# 主函数,测试 MCTS
if __name__ == "__main__":
    # 初始化井字棋的初始状态
    initial_state = TicTacToeState()
    # 创建蒙特卡洛树搜索对象
    mcts = MCTS(initial_state)
    # 进行 1000 次模拟搜索,得到最优行动对应的节点
    best_move_node = mcts.search(num_simulations=1000)
    # 找出从初始状态到最优状态发生变化的位置,即为最优行动
    best_move = [i for i, val in enumerate(initial_state.board) if val != best_move_node.state.board[i]][0]
    # 打印最优行动
    print(f"Best move: {best_move}")
复制代码

 

posted @   xd_xumaomao  阅读(45)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示