nn.CrossEntropyLoss()使用是label参数的注意点
遇到个离谱的事情,自定义数据集跑cross entropy loss的时候,
loss1 = w_loss.loss(log_ps1, source_batch_labels)
loss1.backward()
backward()
这里总是报错,搞了半天最后发现是数据集设定的时候,给labels是int32,但是实际上得设置成int64
# toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float64)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)
# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float64)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)
label部分这样设置,CrossEntropyLoss()就不报错了
然就又报错了,原来data只能是float1,不能是double,也就是说data只能是float32不能是float64
# toy_source数据类型转化
source_datas_t = torch.tensor(source_datas, dtype=torch.float32)
source_labels_t = torch.tensor(source_labels, dtype=torch.int64)
# toy_target数据类型转化
target_datas_t = torch.tensor(target_datas, dtype=torch.float32)
target_labels_t = torch.tensor(target_labels, dtype=torch.int64)
这样就对了