pytorch 分类问题用到的分类器(F.CROSS_ENTROPY和F.BINARY_CROSS_ENTROPY_WITH_LOGITS)

推荐参考:https://www.freesion.com/article/4488859249/

实际运用时注意:

F.binary_cross_entropy_with_logits()对应的类是torch.nn.BCEWithLogitsLoss,在使用时会自动添加sigmoid,然后计算loss。(其实就是nn.sigmoid和nn.BCELoss的合体)

total = model(xi, xv)  # 回到forward函数 , 返回 100*1维
loss = criterion(total, y)  # y是label,整型 0或1
preds = (F.sigmoid(total) > 0.5)  # 配合sigmoid使用
train_num_correct += (preds == y).sum()

 

想一想,其实交叉熵就是-sum(y_true * log(y_pred)),链接中的公式中,由于只有y_true等于1时计算才有效,所以可以化简,同时y_pred经过了softmax处理

posted @ 2020-12-08 11:13  qiezi_online  阅读(2022)  评论(0编辑  收藏  举报