代码学习
看到了这样一段代码:
# build models
# Student model: the network updated by back-propagation.
model = build_model(args)
# Teacher model: built with ema=True (EMA = exponential moving average).
ema_model = build_model(args, ema=True)
通过百度搜索发现,这似乎是一种叫作 Mean Teacher 的半监督学习方法(参考自:原链接缺失):
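第一个 model 被称为 Student Model(学生模型);
第二个 ema_model 被称为 Teacher Model(教师模型)。EMA 即 Exponential Moving Average(指数移动平均):Teacher Model 不通过反向传播更新,其参数是 Student Model 参数的指数移动平均,即 $\theta'_t = \alpha\,\theta'_{t-1} + (1-\alpha)\,\theta_t$,其中 $\theta$ 为 Student 参数,$\theta'$ 为 Teacher 参数,$\alpha$ 为平滑系数。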
在半监督训练中,每个输入 Batch 包含一半已标注图像和一半未标注图像。首先,整个 Batch 被送入 Student Model,得到预测结果。对于 Batch 中的已标注部分,用预测结果与真值计算 loss 并进行梯度反向传播,从而更新 Student Model 的参数。对于 Batch 中的未标注部分,Student Model 同样会输出一个结果(记为 A);同时,这些未标注图像在加入随机噪声后也会被送入 Teacher Model,得到另一个预测结果(记为 B):
我们希望 A 与 B 尽可能一致(通常在两者之间施加一致性损失,例如均方误差)。若两者一致,说明模型对输入扰动具有鲁棒性,泛化能力较好。
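在 TRSSL 的代码中,我想加入一个磁瓦数据集,但把路径加进去之后 bug 频出。排查后发现:args 只是 __init__ 的形参,但 get_dataset 方法里也直接使用了 args……args 在 get_dataset 的局部作用域内并未定义(只能依赖模块级的全局变量 args)。因此应改用 __init__ 中已经保存的 self.data_root、self.no_seen、self.no_class。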
class tinyimagenet_dataset():
    """Dataset factory for a Tiny-ImageNet-style seen/novel semi-supervised split.

    ``__init__`` builds the train/val augmentation pipelines and splits the
    images under ``<data_root>/train`` into labeled and unlabeled index lists
    via ``x_u_split_seen_novel``. ``get_dataset`` then wraps those indices in
    the ``Generic*`` dataset classes.

    Every ``args`` field that ``get_dataset`` needs is copied onto ``self``.
    ``args`` is only a parameter of ``__init__`` and is not in scope inside
    other methods.
    """

    def __init__(self, args):
        """Build transforms and the labeled/unlabeled index split.

        Args:
            args: configuration namespace. This method reads ``data_root``,
                ``lbl_percent``, ``no_class``, ``no_seen``, ``imb_factor`` and
                ``temperature``.
        """
        # Settings reused later by get_dataset(); stored on the instance so no
        # other method has to reach for ``args``.
        self.temperature = args.temperature
        self.data_root = args.data_root
        self.no_seen = args.no_seen
        self.no_class = args.no_class

        # augmentations
        self.transform_train = transforms.Compose([
            transforms.RandomChoice([
                transforms.RandomCrop(64, padding=8),
                transforms.RandomResizedCrop(64, (0.5, 1.0)),
            ]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.5),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(tinyimagenet_mean, tinyimagenet_std),
        ])
        self.transform_val = transforms.Compose([
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(mean=tinyimagenet_mean, std=tinyimagenet_std),
        ])

        base_dataset = datasets.ImageFolder(os.path.join(self.data_root, 'train'))
        # ``imgs`` holds (path, class_index) pairs; read the integer labels
        # directly instead of round-tripping them through a NumPy string array.
        base_dataset_targets = [int(label) for _, label in base_dataset.imgs]

        train_labeled_idxs, train_unlabeled_idxs = x_u_split_seen_novel(
            base_dataset_targets,
            args.lbl_percent,
            self.no_class,
            list(range(0, self.no_seen)),               # seen class indices
            list(range(self.no_seen, self.no_class)),   # novel class indices
            args.imb_factor,
        )
        self.train_labeled_idxs = train_labeled_idxs
        self.train_unlabeled_idxs = train_unlabeled_idxs

    def get_dataset(self, temp_uncr=None):
        """Instantiate the dataset objects for one training round.

        Args:
            temp_uncr: optional value forwarded to the unlabeled ``GenericSSL``
                dataset (presumably per-sample uncertainty/temperature; confirm
                against ``GenericSSL``).

        Returns:
            ``(train_labeled, train_unlabeled)`` when ``temp_uncr`` is not None;
            otherwise ``(train_labeled, train_unlabeled, train_uncr, test_all,
            test_seen, test_novel)``.
        """
        # Only the values stored on ``self`` are available here; ``args`` is
        # not defined in this method's scope.
        train_dir = os.path.join(self.data_root, 'train')
        test_dir = os.path.join(self.data_root, 'test')

        # Copies, presumably so downstream datasets cannot mutate the split
        # kept on the instance across rounds.
        train_labeled_idxs = self.train_labeled_idxs.copy()
        train_unlabeled_idxs = self.train_unlabeled_idxs.copy()

        train_labeled_dataset = GenericSSL(
            train_dir, train_labeled_idxs,
            transform=self.transform_train,
            temperature=self.temperature,
        )
        train_unlabeled_dataset = GenericSSL(
            train_dir, train_unlabeled_idxs,
            transform=TransformTwice(self.transform_train),
            temperature=self.temperature,
            temp_uncr=temp_uncr,
        )
        if temp_uncr is not None:
            return train_labeled_dataset, train_unlabeled_dataset

        train_uncr_dataset = GenericUNCR(
            train_dir, train_unlabeled_idxs, transform=self.transform_train,
        )
        test_dataset_seen = GenericTEST(
            test_dir, no_class=self.no_class, transform=self.transform_val,
            labeled_set=list(range(0, self.no_seen)),
        )
        test_dataset_novel = GenericTEST(
            test_dir, no_class=self.no_class, transform=self.transform_val,
            labeled_set=list(range(self.no_seen, self.no_class)),
        )
        test_dataset_all = GenericTEST(
            test_dir, no_class=self.no_class, transform=self.transform_val,
        )
        return (train_labeled_dataset, train_unlabeled_dataset, train_uncr_dataset,
                test_dataset_all, test_dataset_seen, test_dataset_novel)