pytorch中CrossEntropyLoss的使用

csdn
CrossEntropyLoss 等价于 softmax+log+NLLLoss

LogSoftmax等价于softmax+log

# 首先定义该类
loss = torch.nn.CrossEntropyLoss()
#然后传参进去
loss(input, target)

input维度为N*C,是网络生成的值,N为batch_size,C为类别数;

target维度为N,是标注值,非one-hot类型的值

posted @ 2022-02-23 00:15  zae  阅读(130)  评论(0编辑  收藏  举报