代码实现-小样本-RN-问题合集
此篇为《Learning to Compare Relation Network for Few-Shot Learning》
1.结构
只实现了基于Omniglot数据集的小样本代码
datas为数据集
models为训练好的模型
venv为配置文件
下面的py文件是具体实现代码
2.问题:KeyError: '..\datas\omniglot_resized'
报错信息:
File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'
由于linux和window路径的转换,需要把把'/'改成'\'即可。
修改一:
def get_class(self, sample):
return os.path.join(*sample.split('\\')[:-1])
修改二:
def omniglot_character_folders():
data_folder = '.\\datas\\omniglot_resized\\'
3.问题:IndexError: invalid index of a 0-dim tensor.
报错信息:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
按要求改成
if (episode + 1) % 100 == 0:
print("episode:", episode + 1, "loss", loss.item())
4.问题:RuntimeError: output with shape [1, 28, 28]
报错信息:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
image = self.transform(image)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
img = t(img)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这个是使用Omniglot数据集时的报错,主要原因在于使用 torch.transforms 中 normalize 用了 3 通道,而实际使用的数据集Omniglot 图片大小是 [1, 28, 28],只需要把
normalize =transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
改成
normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])
dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation)
5.问题:AttributeError: module 'torch. nn’has no attribute
报错信息:
Traceback (most recent call last):
File"D:/Omnight/omniglot/omniglot_train few shot.py",line 263, in <modulemainO
File"D:/Omnight/omniglot/omniglot train few shot.py",line 140,in main
if mn.path. exists(str("./models omniglot_feature_encoder_"+ str(CLASS _AM)way_"+ str(SANPLE_ANM_PER_CLASS) +'shot. phk1')):AttributeError: module 'torch. nn’has no attribute 'path'
torch.nn模块是PyTorch中用于神经网络构建和操作的核心模块,它包含了各种层、损失函数和激活函数等。并不包含文件或目录的处理函数,所以没有path函数。因此,更改为使用os.path.exists是一个正确的解决方法,以检查文件或目录是否存在。所以将nn改成os就行了
6.问题:IndexErrorscatter_(: Expected dtype int64 for index.
报错信息
"" vpTraceback (most recent call last) :
File "D;:/Omnight/omniglot/omniglot_train_few shot.py",line 264, in <module>
main()
File"D:/Omnight/omniglot/omniglot_train_few_shot.py",line 188, in main
one_hot_labels = Variable(torch. zeros (BATCA_NMM_PER_CLASS*CLASS_NM,CLASS _NUM .scatter_(l, batch_labels.vier-1,1),..ua(lFb)IndexErrorscatter_(: Expected dtype int64 for index.
scatter_()函数内的索引错误,此函数内部的参数必须是64位整数,改为
one_hot_labels = Variable(
torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1).long(), 1)).cuda(GPU)
loss = mse(relations, one_hot_labels)
7.问题:NotADirectoryBrror:目录名称无效。:\ .DS_Store
报错信息:
NotADirectoryBrror:[WinError 267]目录名称无效。:'.\ |datas\ omniglot_resized Alphabet_of_the_Magi\ .DS_Store
只要删除这个文件就行。
解释:
.DS_Store是Mac OS系统在文件夹中生成的隐藏文件,是特定于 Mac OS 系统的文件,是用于存储文件夹的元数据和自定义显示属性。如果出现与 .DS_Store 相关的报错,可能是因为程序在处理文件夹时意外地尝试读取或操作了 .DS_Store 文件,这可能是由于编程代码中没有正确处理隐藏文件的情况,或者没有明确地指定忽略这些文件。在数据集文件夹中删除这个文件就行。
以下问题 本次没有遇到,但是以后可能会遇到,but希望以后不会遇到很大的bug,遇到bug不可怕,一定要能解决
8.问题:'cp' + os.system
报错信息:
/LearningToCompare_FSL-master/datas/miniImagenet/proc_images.py
'cp' �����ڲ����ⲿ���Ҳ���ǿ����еij������������ļ���
用procs_images.py处理 miniImangenet 数据集的时候会报错,因为这个‘cp’是linux环境运行的。用windows系统的话要改成:
os.rename('images/' + image_name, cur_dir + image_name)
除此之外,所有的 os.system('mkdir ' + filename)
也要改成 os.mkdir(filename),虽然不一定会报错。
os.mkdir(filename)
9.问题:UserWarning: nn.functional.sigmoid is deprecated.
报错信息
UserWarning : torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
修改一:
torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
修改二:
out = F.torch.sigmoid(self.fc2(out))
10.问题:在使用Pytorch加载模型
在使用Pytorch加载模型时出现这个错误语句。
原因:原本的模型是用两个GPU训练的,而你的电脑只有一个,所以会出错。
解决:
model = torch.load(model_path)
改为:
model = torch.load(model_path, map_location='cuda:0')
如果是4块到两块:就把map_location改为:map_location={'cuda:1': 'cuda:0'}
model = torch.load(model_path, map_location='cuda:0')
11.问题:python把路径中反斜杠''变为'/'
windows文件的路径是按反斜杠’'分开的
例如:C:\ProgramData\Microsoft\Windows\Start Menu\Programs\Xmanager 5
linux总文件路径是使用”/’分开。
例如:/home/username/anaconda3/envs/tensorflow/lib/python3.6/
反斜杠’\‘的路径,linux中无法识别
需要把反斜杠"",转为正斜杠“/”
python 中 字符串的replace方法进行替换
windows_path='C:\ProgramData\Microsoft\Windows\Start Menu\Programs\Xmanager 5'
linux_path=windows_path.replace('\\','/')
#'C:/ProgramData/Microsoft/Windows/Start Menu/Programs/Xmanager 5'