Python对列表或者字典进行复制
这是在阅读nn.DataParallel源码的时候看到的一个模块,可以复制你所传入的列表,或者字典,或者元组,代码逻辑涉及到递归和zip的使用,我着实没看懂,但是很是优雅,记录一下。
def copy_obj(input,number):
targets_list = range(number)
def scatter_map(obj):
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in targets_list]
return scatter_map(input)
其中最重要的函数是scatter_map(不要纠结名称),通过递归的方式生成和原来一摸一样的对象。最后copy_obj返回一个列表,其中每个元素就是被复制的对象,下面看看例子。
例子一:
input = {'键1':112,'键2':'12'} print(copy_obj(input, 4)) # 把input复制4次
输出:
例子二:
input = ([1,2,3],"元组") print(copy_obj(input, 4)) # 把input复制4次
输出:
以上内容如有错误,恳请指正