世界的无名小卒,爸|

AAA建材王师傅

园龄:2年5个月  粉丝:4  关注:1

DeepSeek MoE 代码实现

import torch
from torch import nn

class ExpertNetwork(nn.Module):
    """Feed-forward expert: Linear -> ReLU -> Linear.

    Maps a tensor whose last dimension is ``hidden_size`` through a hidden
    layer of width ``intermediate_size`` and back to ``hidden_size`` features.

    Args:
        hidden_size: Feature size of the input and of the output.
        intermediate_size: Width of the hidden layer.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # linear1 is created before linear2 so parameters are initialised in
        # the same RNG order as before.
        self.linear1 = nn.Linear(hidden_size, intermediate_size)  # hidden_size -> intermediate_size
        self.linear2 = nn.Linear(intermediate_size, hidden_size)  # intermediate_size -> hidden_size

    def forward(self, x):
        """Run the expert MLP; the output has the same shape as ``x``."""
        activated = nn.functional.relu(self.linear1(x))
        return self.linear2(activated)

class Router(nn.Module):
    """Top-k gating network that selects experts for every token.

    Args:
        hidden_size: Feature size of each token.
        expert_num: Number of experts scored by the gate.
        top_k: Number of experts selected per token.
    """

    def __init__(self, hidden_size, expert_num, top_k):
        super().__init__()
        self.router = nn.Linear(hidden_size, expert_num)  # one logit per expert
        self.top_k = top_k
        self.hidden_size = hidden_size

    def forward(self, x):
        """Score all experts for each token and keep the top_k of them.

        Args:
            x: Tensor whose last dimension is ``hidden_size``; all leading
                dimensions are flattened into a single token dimension.

        Returns:
            ``(topk_weight, topk_idx)``, each of shape (num_tokens, top_k).
            ``topk_weight`` is renormalised so every row sums to 1.
            ``topk_idx`` holds the chosen expert indices. Order within a row
            is unspecified (``sorted=False``).
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. after a transpose); reshape copies only when it has to.
        tokens = x.reshape(-1, self.hidden_size)
        probs = nn.functional.softmax(self.router(tokens), dim=-1)  # (num_tokens, expert_num)
        topk_weight, topk_idx = torch.topk(probs, k=self.top_k, dim=-1, sorted=False)

        # Renormalise over the selected experts only so the gate weights sum to 1.
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

        return topk_weight, topk_idx

class MOELayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k token routing.

    A router picks ``top_k`` of ``expert_num`` feed-forward experts for every
    token. The layer output for a token is the router-weighted sum of the
    selected experts' outputs.

    Args:
        hidden_size: Feature size of the input and output.
        intermediate_size: Hidden width of each expert network.
        expert_num: Number of experts.
        top_k: Number of experts each token is routed to.
    """

    def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.expert_num = expert_num
        self.top_k = top_k

        # One independent feed-forward network per expert.
        self.experts = nn.ModuleList(
            [ExpertNetwork(self.hidden_size, self.intermediate_size) for _ in range(self.expert_num)]
        )

        # Gate that selects top_k experts per token.
        self.router = Router(self.hidden_size, self.expert_num, self.top_k)

    def forward(self, x):
        """Route every token of ``x`` through its selected experts.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = x.size()
        token_num = batch_size * seq_len
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(token_num, self.hidden_size)

        # Both (token_num, top_k): per-token gate weights and chosen expert ids.
        topk_weight, topk_idx = self.router(x_flat)

        output = torch.zeros_like(x_flat)

        # Dispatch per expert instead of per (token, slot). Each expert runs
        # once on the batch of tokens routed to it, rather than being called
        # token_num * top_k times on single rows from a Python loop.
        for expert_id, expert in enumerate(self.experts):
            token_pos, slot_pos = torch.where(topk_idx == expert_id)
            if token_pos.numel() == 0:
                continue  # no token selected this expert
            expert_out = expert(x_flat[token_pos])                # (n_routed, hidden_size)
            gate = topk_weight[token_pos, slot_pos].unsqueeze(-1)  # (n_routed, 1)
            # Scatter-add each weighted contribution back into its token's row.
            output.index_add_(0, token_pos, (gate * expert_out).to(output.dtype))

        # Restore the original (batch_size, seq_len, hidden_size) layout.
        return output.reshape(batch_size, seq_len, self.hidden_size)

# Hyperparameters for the demo run.
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2

# Demo input of shape (batch_size, seq_len, hidden_size). The last dimension
# is taken from HIDDEN_SIZE rather than hard-coded, so changing the constant
# cannot leave the input and the layer out of sync.
inputs = torch.randn((2, 11, HIDDEN_SIZE))

# Build the mixture-of-experts layer.
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)

# Forward pass.
outputs = moe_layer(inputs)

# The layer preserves the input shape: (batch_size, seq_len, hidden_size).
print(outputs.size())

本文作者:AAA建材王师傅

本文链接:https://www.cnblogs.com/zz-w/p/18749057

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   AAA建材王师傅  阅读(12)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起