AAE对抗自编码器源码运行报错:File "create_datasets.py", line 31, in <module> AttributeError: can't set attribute
按照README的说明,先创建数据集。在PyCharm中运行一下create_datasets.py。之前11月份好像完整跑过一次源码,这次换了台电脑重新跑,结果出现了报错。一般来说警告都可以忽略不计,但这次抛出的是AttributeError异常,程序直接被卡停,所以只能看看是怎么回事。
重点:在第34行导致的错误:trainset_new.train_data = train_data_sub.clone()
为什么呢?在mnist文件中的源码为:
1 @property
2 def train_labels(self):
3 warnings.warn("train_labels has been renamed targets")
4 return self.targets
5
6 @property
7 def test_labels(self):
8 warnings.warn("test_labels has been renamed targets")
9 return self.targets
10
11 @property
12 def train_data(self):
13 warnings.warn("train_data has been renamed data")
14 return self.data
15
16 @property
17 def test_data(self):
18 warnings.warn("test_data has been renamed data")
19 return self.data
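也就是说,不管我们调用的是test_data还是train_data,都统一更名为data;label同理,统一更名为targets。原来的名字又臭又长,所以换了个简单的名字。需要注意的是,这些旧名字现在只是用@property定义的只读属性(没有定义setter):读取时只会给出更名警告,但对它们赋值就会抛出AttributeError: can't set attribute。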
因此,解决方法是:将原文中的所有test_data和train_data替换为data,train_labels和test_labels替换为targets。
替换后的源码如下:
from __future__ import print_function
import pickle
import numpy as np
import torch
from torchvision import datasets, transforms
from sub import subMNIST
# Build three pickled subsets of the MNIST training split:
#   * train_labeled.p   - the first 300 samples of every digit class
#   * validation.p      - the next 1000 samples of every digit class
#   * train_unlabeled.p - every remaining sample, with targets set to None
#
# NOTE: torchvision renamed `train_data` / `test_data` to `data` and
# `train_labels` / `test_labels` to `targets`.  The old names are read-only
# properties, so assigning to them raises
# "AttributeError: can't set attribute"; only the new names are used here.

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset_original = datasets.MNIST('../data', train=True, download=True,
                                   transform=transform)

# Convert to NumPy once; every subset below indexes into these two arrays.
trainset_np = trainset_original.data.numpy()
trainset_label_np = trainset_original.targets.numpy()

# Per digit class: first 300 occurrences -> labeled training set,
# following 1000 occurrences -> validation set.
train_label_index = []
valid_label_index = []
for i in range(10):
    label_index = np.where(trainset_label_np == i)[0]
    train_label_index += list(label_index[:300])
    valid_label_index += list(label_index[300: 1000 + 300])

# --- Labeled training subset -------------------------------------------------
train_data_sub = torch.from_numpy(trainset_np[train_label_index])
train_labels_sub = torch.from_numpy(trainset_label_np[train_label_index])

trainset_new = subMNIST(root='./../data', train=True, download=True,
                        transform=transform, k=3000)
trainset_new.data = train_data_sub.clone()
trainset_new.targets = train_labels_sub.clone()

# Use a context manager so the file handle is flushed and closed.
with open("./../data/train_labeled.p", "wb") as f:
    pickle.dump(trainset_new, f)

# --- Validation subset -------------------------------------------------------
valid_data_sub = torch.from_numpy(trainset_np[valid_label_index])
valid_labels_sub = torch.from_numpy(trainset_label_np[valid_label_index])

validset = subMNIST(root='./../data', train=False, download=True,
                    transform=transform, k=10000)
validset.data = valid_data_sub.clone()
validset.targets = valid_labels_sub.clone()

with open("./../data/validation.p", "wb") as f:
    pickle.dump(validset, f)

# --- Unlabeled training subset -----------------------------------------------
# Everything not already used above.  A set makes each membership test O(1);
# scanning the two index lists for every sample was quadratic.
used_index = {int(idx) for idx in train_label_index}
used_index.update(int(idx) for idx in valid_label_index)
train_unlabel_index = [
    i for i in range(len(trainset_label_np)) if i not in used_index
]

train_data_sub_unl = torch.from_numpy(trainset_np[train_unlabel_index])
train_labels_sub_unl = torch.from_numpy(trainset_label_np[train_unlabel_index])

trainset_new_unl = subMNIST(root='./../data', train=True, download=True,
                            transform=transform, k=47000)
trainset_new_unl.data = train_data_sub_unl.clone()
trainset_new_unl.targets = None  # Unlabeled

with open("./../data/train_unlabeled.p", "wb") as f:
    pickle.dump(trainset_new_unl, f)
替换后重新运行源码,流畅跑完,数据集创建成功。