etf分类器
feature, logit, out = node.model(data) # PyTorch,模型的.model(data)调用通常涉及到两个主要步骤:前向传播和激活输出的计算。这里没有直接涉及到反向传播,因为反向传播是在损失函数计算之后,通过调用损失函数的.backward()方法来执行的。
output_local = torch.matmul(feature, node.model.proto_classifier.proto) # μ * V_ETF
output_local = node.model.scaling_train * output_local # 再✖️β
if opt == 'balancedloss':
loss_local = balanced_softmax_loss(output_local, target, node.sample_per_class) # output_local相当于logit
解释一下,feature是fedetf投影后的特征,目的是让ETF为n x n 的方阵,这也其实就是网络的最后一层输出了。
分类器其实就是权重:
logit = w_c*H+b
logits就是一个全连接层,得到各个类的分数(但不是0~1),所以对于etf中,要固定的就是W_c,而H就是feature,原始的logts是nn.linerline(out)