nn.CrossEntropyLoss()使用是label参数的注意点

遇到个离谱的事情,自定义数据集跑cross entropy loss的时候,

    loss1 = w_loss.loss(log_ps1, source_batch_labels)
    loss1.backward()

backward()这里总是报错,搞了半天最后发现是数据集设定的时候,给labels是int32,但是实际上得设置成int64

#  toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float64)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)

# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float64)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)

label部分这样设置,CrossEntropyLoss()就不报错了

然就又报错了,原来data只能是float1,不能是double,也就是说data只能是float32不能是float64

#  toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float32)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)

# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float32)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)

这样就对了

posted @ 2022-06-27 01:36  TR_Goldfish  阅读(508)  评论(0编辑  收藏  举报