MMoE核心代码
class MMoE_Layer(tf.keras.layers.Layer): def __init__(self,expert_dim,n_expert,n_task): super(MMoE_Layer, self).__init__() self.n_task = n_task # 专家个数 self.expert_layer = [Dense(expert_dim,activation = 'relu') for i in range(n_expert)] # 任务个数,和目标对应 self.gate_layers = [Dense(n_expert,activation = 'softmax') for i in range(n_task)] def call(self,x): # x表示为向量 # 构建多个专家网络 E_net = [expert(x) for expert in self.expert_layer] E_net = Concatenate(axis = 1)([e[:,tf.newaxis,:] for e in E_net]) # 维度 (bs,n_expert,n_dims) # 构建多个门网络 gate_net = [gate(x) for gate in self.gate_layers] # 维度 n_task个(bs,n_expert) # towers计算:对应的门网络乘上所有的专家网络 towers = [] for i in range(self.n_task): g = tf.expand_dims(gate_net[i],axis = -1) # 维度(bs,n_expert,1) _tower = tf.matmul(E_net, g, transpose_a=True) towers.append(Flatten()(_tower)) # 维度(bs,expert_dim) return towers x = tf.ones([3,5]) # 前向转播 m = MMoE_Layer(20, 5, 2) m(x)
返回每个任务的表征:
时刻记着自己要成为什么样的人!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)