DataLoader AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn' (without changing num_workers)
I ran into the error above while running the TransUNet code (https://github.com/Beckschen/TransUNet, paper: https://arxiv.org/pdf/2102.04306.pdf). Almost every solution online says to set the DataLoader's num_workers to 0, but doing that noticeably slows down training.
My understanding is fairly shallow, but roughly: when the DataLoader starts its new worker processes, it has to serialize (pickle) worker_init_fn from trainer_synapse.py to send it to each worker, and that fails here.
The original trainer function looks like this:
def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    ...
The code defines worker_init_fn inside trainer_synapse, so it is a local function. When the DataLoader spawns its worker processes (with the spawn start method, e.g. on Windows), it pickles worker_init_fn to pass it to each worker, and Python's pickle cannot serialize a function defined inside another function, hence the error.
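The root cause is easy to reproduce in isolation, without TransUNet or PyTorch at all; this toy sketch just mimics the structure of trainer_synapse:

```python
import pickle

def trainer():
    # a local function, analogous to worker_init_fn inside trainer_synapse
    def worker_init_fn(worker_id):
        return worker_id
    return worker_init_fn

try:
    pickle.dumps(trainer())
except AttributeError as e:
    # AttributeError: Can't pickle local object 'trainer.<locals>.worker_init_fn'
    print(e)
```

This is the same "Can't pickle local object" message as in the title, which is why moving the function to module level fixes it.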
So I simply changed the code to:
def worker_init_fn(worker_id):
    random.seed(1234 + worker_id)

def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    ...
(That is, I moved def worker_init_fn outside of def trainer_synapse, and replaced args.seed with the author's default of 1234. I did this purely out of laziness; if you need to keep the configurable random seed, you can write your own way of passing the parameter, e.g. a seedconfig.cfg file.)
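If you do want to keep args.seed configurable, a lighter alternative to a config file is functools.partial: bind the seed to a module-level function. Unlike a closure, a partial over a top-level function is picklable, so the DataLoader workers can receive it. A sketch assuming the same args.seed / DataLoader setup as above:

```python
import pickle
import random
from functools import partial

def worker_init_fn(seed, worker_id):
    # top-level function: picklable, unlike a function defined inside trainer_synapse
    random.seed(seed + worker_id)

# inside trainer_synapse you would bind the configurable seed like:
# trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
#                          num_workers=8, pin_memory=True,
#                          worker_init_fn=partial(worker_init_fn, args.seed))

# demonstrate that the bound function survives pickling (what the workers do):
fn = pickle.loads(pickle.dumps(partial(worker_init_fn, 1234)))
fn(0)  # seeds random with 1234 + 0
```

The DataLoader calls worker_init_fn with the worker id as its only argument, so binding the seed as the first positional argument keeps the signature compatible.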
Anyway, it now runs successfully, and I've even applied TransUNet to my own dataset, so my course project is sorted. Yay!
posted on 2021-12-09 22:34 by crazyplayer