AAE (Adversarial Autoencoder) source code error: File "create_datasets.py", line 31, in <module> AttributeError: can't set attribute

Following the README, the first step is to create the datasets, so I ran create_datasets.py in PyCharm. I ran the complete source code once back in November, but this time, on a different machine, it threw what I first took for a warning. Warnings can usually be ignored, but this one stopped the script dead (it is in fact an AttributeError, i.e. an exception rather than a warning), so I had to look into it.

The key point: the error comes from line 34: trainset_new.train_data = train_data_sub.clone()
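With a recent torchvision, the failure reproduces in isolation (a minimal sketch; '../data' is just the download path the script already uses):

from torchvision import datasets

ds = datasets.MNIST('../data', train=True, download=True)
ds.data = ds.data[:100]          # fine: 'data' is an ordinary attribute
ds.train_data = ds.data[:100]    # AttributeError: can't set attribute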

Why? The relevant source in torchvision's mnist.py reads:

@property
def train_labels(self):
    warnings.warn("train_labels has been renamed targets")
    return self.targets

@property
def test_labels(self):
    warnings.warn("test_labels has been renamed targets")
    return self.targets

@property
def train_data(self):
    warnings.warn("train_data has been renamed data")
    return self.data

@property
def test_data(self):
    warnings.warn("test_data has been renamed data")
    return self.data


In other words, whether you access train_data or test_data, both now map to the single attribute data; likewise train_labels and test_labels both map to targets (the old names were long and clunky, so they were unified under shorter ones). Crucially, the old names survive only as @property getters with no setters: reading them still works (with a deprecation warning), but assigning to them raises AttributeError: can't set attribute.
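The mechanism is plain Python, nothing torchvision-specific (the Dataset class below is made up purely for illustration):

class Dataset:
    def __init__(self):
        self.data = [1, 2, 3]      # the real, writable attribute

    @property
    def train_data(self):          # old name kept as a read-only getter
        return self.data

ds = Dataset()
print(ds.train_data)               # reading via the old name still works
ds.data = [4, 5, 6]                # writing the real attribute works
ds.train_data = [7, 8, 9]          # no setter -> AttributeError ("can't set attribute" on Python < 3.11)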

The fix, therefore: replace every occurrence of test_data and train_data in the script with data, and every train_labels and test_labels with targets.
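If you would rather keep one script that works on both old and new torchvision, a hedged sketch of a small helper (the name set_mnist_arrays is mine, not from the repo; shown for the train=True case):

def set_mnist_arrays(ds, data, targets):
    # Recent torchvision stores tensors in .data/.targets; very old
    # versions only have writable train_data/train_labels attributes.
    if hasattr(ds, 'data'):
        ds.data = data
        ds.targets = targets
    else:
        ds.train_data = data
        ds.train_labels = targets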

The source after the replacement:

from __future__ import print_function
import pickle
import numpy as np
import torch
from torchvision import datasets, transforms

from sub import subMNIST

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

trainset_original = datasets.MNIST('../data', train=True, download=True,
                                   transform=transform)

# For each digit class, take the first 300 indices as labeled training
# data and the next 1000 indices as validation data.
train_label_index = []
valid_label_index = []
for i in range(10):
    train_label_list = trainset_original.targets.numpy()
    # train_label_list = trainset_original.train_labels.numpy()
    label_index = np.where(train_label_list == i)[0]
    label_subindex = list(label_index[:300])
    valid_subindex = list(label_index[300: 1000 + 300])
    train_label_index += label_subindex
    valid_label_index += valid_subindex

trainset_np = trainset_original.data.numpy()
trainset_label_np = trainset_original.targets.numpy()
# trainset_np = trainset_original.train_data.numpy()
# trainset_label_np = trainset_original.train_labels.numpy()

train_data_sub = torch.from_numpy(trainset_np[train_label_index])
train_labels_sub = torch.from_numpy(trainset_label_np[train_label_index])

trainset_new = subMNIST(root='./../data', train=True, download=True, transform=transform, k=3000)

trainset_new.data = train_data_sub.clone()
trainset_new.targets = train_labels_sub.clone()
# trainset_new.train_data = train_data_sub.clone()
# trainset_new.train_labels = train_labels_sub.clone()

pickle.dump(trainset_new, open("./../data/train_labeled.p", "wb"))

validset_np = trainset_original.data.numpy()
validset_label_np = trainset_original.targets.numpy()
# validset_np = trainset_original.train_data.numpy()
# validset_label_np = trainset_original.train_labels.numpy()

valid_data_sub = torch.from_numpy(validset_np[valid_label_index])
valid_labels_sub = torch.from_numpy(validset_label_np[valid_label_index])

validset = subMNIST(root='./../data', train=False, download=True, transform=transform, k=10000)

validset.data = valid_data_sub.clone()
validset.targets = valid_labels_sub.clone()
# validset.test_data = valid_data_sub.clone()
# validset.test_labels = valid_labels_sub.clone()

pickle.dump(validset, open("./../data/validation.p", "wb"))

# Every remaining training index (neither labeled nor validation) is
# treated as unlabeled. Note: 'in' on a plain list is O(n); converting the
# two index lists to set() would speed this loop up with identical results.
train_unlabel_index = []
for i in range(60000):
    if i in train_label_index or i in valid_label_index:
        pass
    else:
        train_unlabel_index.append(i)

trainset_np = trainset_original.data.numpy()
trainset_label_np = trainset_original.targets.numpy()
# trainset_np = trainset_original.train_data.numpy()
# trainset_label_np = trainset_original.train_labels.numpy()
train_data_sub_unl = torch.from_numpy(trainset_np[train_unlabel_index])
train_labels_sub_unl = torch.from_numpy(trainset_label_np[train_unlabel_index])

trainset_new_unl = subMNIST(root='./../data', train=True, download=True, transform=transform, k=47000)

trainset_new_unl.data = train_data_sub_unl.clone()
trainset_new_unl.targets = None      # Unlabeled
# trainset_new_unl.train_data = train_data_sub_unl.clone()
# trainset_new_unl.train_labels = None      # Unlabeled

trainset_new_unl.targets    # bare expression, has no effect (likely a notebook leftover)
# trainset_new_unl.train_labels

pickle.dump(trainset_new_unl, open("./../data/train_unlabeled.p", "wb"))
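To sanity-check the result, the pickles can be loaded back (a minimal sketch; subMNIST must be importable on the loading side, because pickle stores the class by reference, and the expected shape follows from 300 samples x 10 digits):

import pickle
from sub import subMNIST   # noqa: F401 -- needed so pickle can reconstruct subMNIST objects

with open("./../data/train_labeled.p", "rb") as f:
    trainset_labeled = pickle.load(f)
print(trainset_labeled.data.shape)    # expected: torch.Size([3000, 28, 28])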

After the replacement the script runs through cleanly and the datasets are created successfully.


posted @ 2021-01-04 21:08  achived