pytorch CrossEntropyLoss() 默认转换one-hot编码
摘要:import torchpredict = torch.randn((4,3))# crossentropyloss不需要predict的概率为1,predict为logits# predict = torch.nn.functional.softmax(predict,dim = 1)target
阅读全文
posted @
2022-03-24 22:46
一点飞鸿
阅读(454)
推荐(0) 编辑
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
摘要:代码: cate_ids=np.unique(gt_box_array[:,-1]) for tmp_cateid in cate_ids: conf_matrix[tmp_cateid,0]+=1 原因:numpy里面不指定类型的话,默认是float64位,无法作为索引 修改:强转为int ,即c
阅读全文
posted @
2022-03-09 18:06
一点飞鸿
阅读(3427)
推荐(1) 编辑
mxnet的broadcast_power() 注释错误
摘要:用relationnet时,发现broadcast_power()的源码中的注释如下: 官方文档中的注释如下: 怎么算都算不出它这个结果。。。 自己用mxnet实验了一把,发现是注释错了,代码如下:
阅读全文
posted @
2022-03-08 09:54
一点飞鸿
阅读(47)
推荐(0) 编辑