Python对列表或者字典进行复制

  这是在阅读nn.DataParallel源码的时候看到的一个模块,可以复制你所传入的列表,或者字典,或者元组,代码逻辑涉及到递归和zip的使用,我着实没看懂,但是很是优雅,记录一下。

def copy_obj(input,number):
    targets_list = range(number)
    def scatter_map(obj):
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in targets_list]
    return scatter_map(input)

  其中最重要的函数是scatter_map(不要纠结名称),通过递归的方式生成和原来一摸一样的对象。最后copy_obj返回一个列表,其中每个元素就是被复制的对象,下面看看例子。

例子一:

input = {'键1':112,'键2':'12'}
print(copy_obj(input, 4)) # 把input复制4次

输出:

例子二:

input = ([1,2,3],"元组")
print(copy_obj(input, 4)) # 把input复制4次

输出:

 

 

  

posted @ 2021-12-25 20:05  Circle_Wang  阅读(169)  评论(0编辑  收藏  举报