dataloader AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn' (no need to change num_workers)

I ran into the error above while running the TransUNet code (https://github.com/Beckschen/TransUNet, paper: https://arxiv.org/pdf/2102.04306.pdf). The fix you find online is almost always to set the DataLoader's num_workers to 0, but doing so slows training down.

My own understanding is fairly shallow: the DataLoader's newly started workers can't get hold of the worker_init_fn function defined in the trainer_synapse.py file.

The original trainer function looks like this:

def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))

    print("The length of train set is: {}".format(len(db_train)))

    # note: worker_init_fn is a *local* function, nested inside trainer_synapse
    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn,)
    ...

The code defines worker_init_fn inside trainer_synapse, which makes it a local object (trainer_synapse.<locals>.worker_init_fn). With num_workers > 0, the DataLoader starts separate worker processes and has to pickle worker_init_fn to hand it to them, and Python's pickle cannot serialize a function nested inside another function, hence the AttributeError. (This typically bites on Windows, where worker processes are started with spawn rather than fork.)
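A minimal reproduction of the pickling limitation, independent of PyTorch (the names outer/inner here are illustrative):

import pickle

def outer():
    def inner():  # a local object: its qualified name is outer.<locals>.inner
        pass
    # pickle serializes functions by qualified name, and outer.<locals>.inner
    # is not importable from the module's top level, so this line raises:
    # AttributeError: Can't pickle local object 'outer.<locals>.inner'
    pickle.dumps(inner)

outer()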

So I simply changed the code to:

# moved to module level so that worker processes can pickle it;
# the seed is hard-coded to the author's default of 1234
def worker_init_fn(worker_id):
    random.seed(1234 + worker_id)
def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))

    print("The length of train set is: {}".format(len(db_train)))

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn,)
    ...

(In other words, I moved

def worker_init_fn

outside of

def trainer_synapse

and changed args.seed to the author's default of 1234. I only did that out of laziness; if you need to keep the configurable random seed, you can write your own way of passing the parameter through, e.g. a seed entry in a config file. One possible sketch follows below.)
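For instance, functools.partial over a module-level function keeps args.seed configurable while staying picklable (a sketch, not the repo's code; the 1234 below just stands in for args.seed):

import random
from functools import partial

def worker_init_fn(worker_id, seed):
    # module-level, so DataLoader worker processes can pickle it
    random.seed(seed + worker_id)

# inside trainer_synapse you would bind the configurable seed, e.g.:
# trainloader = DataLoader(db_train, ..., worker_init_fn=partial(worker_init_fn, seed=args.seed))
init_fn = partial(worker_init_fn, seed=1234)
init_fn(0)  # seeds random with 1234 + 0, same effect as the original nested version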

Anyway, I got it running successfully, and even applied TransUNet to my own dataset. My course project is sorted. Yay!

 

posted on 2021-12-09 22:34 by crazyplayer