21 PyTorch 可复现设置

PyTorch 可复现设置

参考链接:

https://www.jianshu.com/p/b95ec7351603

影响模型可复现性有以下几个因素:

1 随机种子

2 Dataloader

3 不确定性的算法

具体的看上面的链接,简单来说,加上下面这两段就ok了:

def set_seed(seed):
    # 随机种子
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.

    # 不确定性算法
    # torch.set_deterministic(True) # 会报错,所以注释掉
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False 
    torch.backends.cudnn.benchmark = False
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)
def worker_init(self, worked_id):
    # dataloader
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

train_loader = DataLoader(xxx, num_workers=0, worker_init_fn=self.worker_init)
posted @ 2022-01-04 23:25  SethDeng  阅读(180)  评论(0编辑  收藏  举报