pytorch inception v3 KeyError: <class 'tuple'>解决方法

刚入门深度学习,试着跑了一下resnet18和resnet50都没有问题,但是在运行inception v3的时候遇到一个问题怎么都解决不了

错误输出如下:

Traceback (most recent call last):
  File "transfer_learning_tutorial.py", line 277, in <module>
    num_epochs=25)
  File "transfer_learning_tutorial.py", line 179, in train_model
    _, preds = torch.max(outputs, 1)
TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, Tensor out)
 * (Tensor input, int dim, bool keepdim, tuple of Tensors out)

在github上找到答案

inception源码连接第125行,在train模式下并且aux_logits打开的情况下,返回x, aux。

所以解决方法有以下两个:

方法1.创建inception模型的时候,关闭aux_logits。设置关键字参数aux_logits=False

方法2.接收返回的aux。

output, aux = model(input_var)#注意只有在训练模式下接收两个参数
out=model(input_var)#在求值模式下,仍然只返回一个参数
                    if phase=='train':
                        outputs,aux= model(inputs)
                    else:
                        outputs=model(inputs)

 

 
posted @ 2018-09-07 13:30  MalcolmMeng  阅读(924)  评论(0编辑  收藏  举报