21 PyTorch 可复现设置
PyTorch 可复现设置
参考链接:
https://www.jianshu.com/p/b95ec7351603
影响模型可复现性的因素主要有以下几个:
1. 随机种子
2. DataLoader
3. 非确定性算法
具体细节请参考上面的链接;简单来说,加上下面这两段代码即可:
def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN/cuBLAS settings.

    The same ``seed`` is applied to torch (CPU, current CUDA device and all
    CUDA devices), to NumPy's global generator and to Python's ``random``
    module. Afterwards the cuDNN backend flags and two environment variables
    are set.

    Args:
        seed: Integer seed handed unchanged to every seeding function and
            written (as a string) into ``PYTHONHASHSEED``.

    Returns:
        None. The function works purely through global side effects.
    """
    # Random seeds: torch first (CPU, current GPU, every GPU for multi-GPU
    # runs), then NumPy, then the Python standard library.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Non-deterministic algorithms.
    # NOTE: torch.set_deterministic(True) is intentionally not called here;
    # the original author reported that it raised an error.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.enabled = False
    cudnn.benchmark = False

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up,
    # so setting it here presumably only affects child processes - confirm.
    os.environ.update({
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
        'PYTHONHASHSEED': str(seed),
    })
def worker_init(self, worked_id):
    """Re-seed NumPy and Python's ``random`` from torch's current initial seed.

    Written to be passed to a ``DataLoader`` as ``worker_init_fn``.

    Args:
        worked_id: Positional argument supplied by the caller; it is not
            used in the body. (The name looks like a typo for ``worker_id``
            but is kept so the signature stays unchanged.)

    Returns:
        None. Only the global NumPy and ``random`` generators are modified.
    """
    # Reduce torch's seed modulo 2**32 so it fits the 32-bit range that
    # NumPy's legacy seeding accepts.
    seed_limit = 2**32
    derived_seed = torch.initial_seed() % seed_limit
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
# Usage example: hand the worker seeding hook to the DataLoader through
# `worker_init_fn` (`xxx` is presumably a placeholder for the dataset).
# NOTE(review): per the PyTorch DataLoader documentation, worker_init_fn is
# invoked only inside worker subprocesses, so with num_workers=0 the hook is
# probably never called - confirm and raise num_workers if workers are wanted.
train_loader = DataLoader(xxx, num_workers=0, worker_init_fn=self.worker_init)