带你从零掌握迭代器及构建最简DataLoader
0 摘要
本文本意是写 pytorch 中 DataLoader 源码学习心得,但是发现自己对迭代器和生成器的掌握比较水,不够牢固,而我也没有搜到能够解决我所有疑问的解答文章,因此诞生了这篇文章。通过本文你将能够零基础深入掌握 python 迭代器相关知识、并且能够一步步理解 DataLoader 的实现原理以及背后涉及的设计模式。
本文最终目的是通过源码学习自己实现一个功能比较完善的 DataLoader 类,为了达到这个目的,本文写作流程是:
-
先深入浅出分析 python 中迭代器、生成器等实现原理,包括 Iterable、Iterator、for .. in ..、__getitem__、yield 生成器 5个部分
-
再实现了一个最简单版本的 DataLoader,目的是理解 DataLoader 与 Dataset、Sampler、BatchSampler和 collate_fn 之间的调用关系
-
最后对该实现进行深入全面分析,读者可以清晰的理解每个类的作用
但是 DataLoader 功能其实非常复杂,故本文属于系列文章的第一篇,后面文章会不断完善、调整,最终实现 DataLoader 所有功能。或者说本文是后续文章的基础,如果基础内容没有理解非常透彻,后面的多进程、分布式版本就更难以理解了。
虽然本文比较简单,但是由于涉及到代码,故为了方便,有必要的读者可以 clone rep 进行学习(需要特意说明的是:rep 里面代码是学习目的的,质量不高,不要要求那么多)
github: https://github.com/hhaAndroid/miniloader
由于本人水平有限,某些环节理解可能有偏颇,欢迎指正。手机对于代码显示效果不太好,建议电脑端阅读。
1 python 迭代器深入浅出理解
1.1 可迭代对象 Iterable
可迭代对象 Iterable:表示该对象可迭代,其并不是指某种具体数据类型。简单来说只要是实现了 `__iter__` 方法的类就是可迭代对象。
from collections.abc import Iterable, Iterator
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __iter__(self):
# 此处返回啥无所谓
return self.a
cls_a = A()
# True
print(isinstance(cls_a, Iterable))
但是对象如果是 Iterable 的,看起来好像也没有特别大的用途,因为你依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口:
for a in cls_a:
print(a)
# 程序报错,要理解这个错误的含义
TypeError: iter() returned non-iterator of type 'list'
我们可以检查下 Iterable 接口:
class Iterable(metaclass=ABCMeta):
# 如果实现了这个方法,那么就是 Iterable
def __iter__(self):
while False:
yield None
def __subclasshook__(cls, C):
if cls is Iterable:
return _check_methods(C, "__iter__")
return NotImplemented
看起来实现 Iterable 接口用途不大,其实不是的,其有很多用途的,例如简化代码等,在后面的高级语法糖中会频繁用到,后面会分析。
1.2 迭代器 Iterator
迭代器 Iterator:其和 Iterable 之间是一个包含与被包含的关系,如果一个对象是迭代器 Iterator,那么这个对象肯定是可迭代 Iterable;但是反过来,如果一个对象是可迭代 Iterable,那么这个对象不一定是迭代器 Iterator,可以通过接口协议看出:
class Iterator(Iterable):
# 迭代具体实现
def __next__(self):
'Return the next item from the iterator. When exhausted, raise StopIteration'
raise StopIteration
# 返回自身,因为自身有 __next__ 方法(如果自身没有 __next__,那么返回自身没有意义)
def __iter__(self):
return self
def __subclasshook__(cls, C):
if cls is Iterator:
return _check_methods(C, '__iter__', '__next__')
return NotImplemented
可以发现:实现了 `__next__` 和 `__iter__` 方法的类才能称为迭代器,就可以被 for 遍历了。
class A(object):
def __init__(self):
self.index = -1
self.a = [1, 2, 3]
#必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
#因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
def __iter__(self):
return self
def __next__(self):
self.index += 1
if self.index < len(self.a):
return self.a[self.index]
else:
#抛异常,for 内部会自动捕获,表示迭代完成
raise StopIteration("遍历完了")
cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
print(isinstance(iter(cls_a), Iterator)) # True
for a in cls_a:
print(a)
# 打印 1 2 3
再次明确,一个对象如果要是 Iterator ,那么必须要实现 `__next__` 和 `__iter__` 方法,但是要理解其内部迭代流程,还需要理解 for .. in .. 流程。
1.3 for .. in .. 本质流程
for .. in .. 也就是常见的迭代操作了,其被 python 编译器编译后,实际上代码是:
# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
cls_a = iter(cls_a)
while True:
try:
# 然后调用对象的 __next__ 方法,不断返回元素
value = next(cls_a)
print(value)
# 如果迭代完成,则捕获异常即可
except StopIteration:
break
可以看出,任何一个对象如果要能够被 for 遍历,必须要实现 `__iter__` 和 `__next__` 方法,缺一不可。
明白了上述流程,那么迭代器对象 A,我们可以采用如下方式进行遍历:
myiter = iter(cls_a)
print(next(myiter))
print(next(myiter))
print(next(myiter))
# 因为遍历完了,故此时会出现错误:StopIteration: 遍历完了
print(next(myiter))
我们再来思考 python 内置对象 list 为啥可以被迭代?
b=list([1,2,3])
print(isinstance(b, Iterable)) # True
print(isinstance(b, Iterator)) # False
可以发现 list 类型是可迭代对象,但是其不是迭代器(即 list 没有 `__next__` 方法),那为啥 for .. in .. 可以迭代呢?
原因是 list 内部的 `__iter__` 方法内部返回了具备 `__next__` 方法的类,或者说调用 iter() 后返回的对象本身就是一个迭代器,当然可以 for 循环了。
b=list([1,2,3])
print(dir(b)) # 可以发现其存在 __iter__ 方法,不存在 __next__
b=iter(b) # 调用 list 内部的 __iter__,返回了具备 __next__ 的对象
print(isinstance(b, Iterable)) # True
print(isinstance(b, Iterator)) # True
print(dir(b)) # 同时具备 __iter__ 和 __next__ 方法
基于上述理解我们可以对 A 类代码进行改造,使其更加简单:
class A(object):
def __init__(self):
self.a = [1, 2, 3]
# 我们内部又调用了 list 对象的 __iter__ 方法,故此时返回的对象是迭代器对象
def __iter__(self):
return iter(self.a)
cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # False
for a in cls_a:
print(a)
# 输出:1 2 3
此时我们就实现了仅仅实现 Iterable 规范接口,但是又具备了 for .. in .. 功能,代码是不是比最开始的实现简单很多?这种写法应用也非常广泛,因为其不需要自己再次实现 `__next__` 方法。
如果你想理解的更加透彻,那么可以看下面例子:
# 仅仅实现 __iter__
class A(object):
def __init__(self):
self.b = B()
def __iter__(self):
return self.b
# 仅仅实现 __next__
class B(object):
def __init__(self):
self.index = -1
self.a = [1, 2, 3]
def __next__(self):
self.index += 1
if self.index < len(self.a):
return self.a[self.index]
else:
# 内部会自动捕获,表示迭代完成
raise StopIteration("遍历完了")
cls_a = A()
cls_b = B()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # False
print(isinstance(cls_b, Iterable)) # False
print(isinstance(cls_b, Iterator)) # False
print(type(iter(cls_a))) # B 对象
print(isinstance(iter(cls_a), Iterator)) # False
for a in cls_a:
print(a)
# 输出:1 2 3
自此我们知道了:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 `__iter__` 和 `__next__` 方法(对象必然是 Iterator),还是只实现 `__iter__`(不是 Iterator),但是内部间接返回了具备 `__next__` 对象的类,都是可行的。
但是除了这两种实现,还有其他高级语法糖,可以进一步精简代码。
1.4 __ getitem__ 理解
上面说过 for .. in .. 的本质就是调用对象的 `__iter__` 和 `__next__` 方法,但是有一种更加简单的写法,你通过仅仅实现 `__getitem__` 方法就可以让对象实现迭代功能。实际上任何一个类,如果实现了`__getitem__` 方法,那么当调用 iter(类实例) 时候会自动具备`__iter__` 和 `__next__`方法,从而可迭代了。
通过下面例子可以看出,`__getitem__` 实际上是属于 __iter__` 和 `__next__` 方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __getitem__(self, item):
return self.a[item]
cls_a = A()
print(isinstance(cls_a, Iterable)) # False
print(isinstance(cls_a, Iterator)) # False
print(dir(cls_a)) # 仅仅具备 __getitem__ 方法
cls_a = iter(cls_a)
print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
# 等价于 for .. in ..
while True:
try:
# 然后调用对象的 __next__ 方法,不断返回元素
value = next(cls_a)
print(value)
# 如果迭代完成,则捕获异常即可
except StopIteration:
break
# 输出:1 2 3
而且 `__getitem__` 还可以通过索引直接访问元素,非常方便
a[0] # 1
a[4] # 错误,索引越界
如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 `__len__` 方法即可
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __getitem__(self, item):
return self.a[item]
def __len__(self):
return len(self.a)
cls_a = A()
print(len(cls_a)) # 3
到目前为止,我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。
1.5 yield 生成器
生成器是一个在行为上和迭代器非常类似的对象,二者功能上差不多,但是生成器更优雅,只需要用关键字 yield 来返回,作用于函数上叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质就是迭代器,其最大特点是代码简洁。
def func():
for a in [1, 2, 3]:
yield a
cls_g = func()
print(isinstance(cls_g, Iterator)) # True
print(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
for a in cls_g:
print(a)
# 输出: 1 2 3
# 一种更简单的写法是用 ()
cls_g = (i for i in [1,2,3])
直观感觉和 `__getitem__` 一样,也是高级语法糖,但是比 `__getitem__` 更加简单,更加好用。
使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。假设创建一个包含10万个元素的列表,如果用 list 返回不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了,这种场景就适合采用生成器,在迭代过程中推算出后续元素,而不需要一次性全部算出。
1.6 小结
-
list set dict等内置对象都是容器 container 对象,容器是一种把多个元素组织在一起的数据结构,可以逐个迭代获取其中的元素。容器可以用 in 来判断容器中是否包含某个元素。大多数容器都是可迭代对象,可以使用某种方式访问容器中的每一个元素。
-
在迭代对象基础上,如果实现了 `__next__` 方法则是迭代器对象,该对象在调用 next() 的时候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。
-
对于采用语法糖 `__getitem__` 实现的迭代器对象,其本身实例既不是可迭代对象,更不是迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。
-
生成器是一种特殊迭代器,但是不需要像迭代器一样实现`__iter__`和`__next__`方法,只需要使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。
-
对于在类的 `__iter__` 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。
2 DataLoader 最简版本 V1
这里说的最简版本是指:没有任何花哨、高级实现技巧,仅仅以实现最基础功能为目的。具体来说是包括必备的5个对象:Dataset、Sampler、BatchSampler、DataLoader 和 collate_fn。其作用可以简要描述为如下:
-
Dataset 提供整个数据集的随机访问功能,每次调用都返回单个对象,例如一张图片和对应 target 等等
-
Sampler 提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引,常用子类是 SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引
-
BatchSampler 内部调用 Sampler 实例,输出指定 `batch_size` 个索引,然后将索引作用于 Dataset 上从而输出 `batch_size` 个数据对象,例如 batch 张图片和 batch 个 target
-
collate_fn 用于将 batch 个数据对象在 batch 维度进行聚合,生成 (b,...) 格式的数据输出,如果待聚合对象是 numpy,则会自动转化为 tensor,此时就可以输入到网络中了
迭代一次伪代码如下(非迭代器版本):
class DataLoader(object):
def __init__(self):
# 假设数据长度是100,batch_size 是4
self.dataset = [[img0, target0], [img1, target1], ..., [img99, target99]]
# 假设 sampler 是 SequentialSampler,那么实际上就是 [0,1,...,99] 列表而已
# 如果 sampler 是 RandomSampler,那么可能是 [30,1,34,2,6,...,0] 列表
self.sampler = [0, 1, 2, 3, 4, ..., 99]
self.batch_size = 4
self.index = 0
def collate_fn(self, data):
# batch 维度聚合数据
batch_img = torch.stack(data[0], 0)
batch_target = torch.stack(data[1], 0)
return batch_img, batch_target
def __next__(self):
# 0.batch_index 输出,实际上就是 BatchSampler 做的事情
i = 0
batch_index = []
while i < self.batch_size:
# 内部会调用 sampler 对象取单个索引
batch_index.append(self.sampler[self.index])
self.index += 1
i += 1
# 1.得到 batch 个数据了,调用 dataset 对象
data = [self.dataset[idx] for idx in batch_index]
# 2. 调用 collate_fn 在 batch 维度拼接输出
batch_data = self.collate_fn(data)
return batch_data
def __iter__(self):
return self
以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。
2.1 整体对象理解
首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器,你必须要先理解第一小节内容,否则本节内容会比较难理解,具体为:
-
Dataset 通过实现 `__getitem__` 方法使其可迭代
-
Sampler 对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 `__iter__` 内部返回迭代器,RandomSampler 在 `__iter__` 内部通过 yield 关键字返回迭代器
-
BatchSampler 也是在 `__iter__` 内部通过 yield 关键字返回迭代器
-
DataLoader 通过直接实现 `__next__` 和 `__iter__` 变成迭代器
注意除了 DataLoader 本身是迭代器外,其余对象本身不是迭代器,但是都能被 for .. in .. 迭代。下面一个简单例子证明:
from simplev1_datatset import SimpleV1Dataset
from libv1 import SequentialSampler, RandomSampler
from collections import Iterator, Iterable
simple_dataset = SimpleV1Dataset()
dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)
print(isinstance(simple_dataset, Iterable)) # False
print(isinstance(simple_dataset, Iterator)) # False
print(isinstance(iter(simple_dataset), Iterator)) # True
print(isinstance(SequentialSampler(simple_dataset), Iterable)) # True
print(isinstance(SequentialSampler(simple_dataset), Iterator)) # False
print(isinstance(iter(SequentialSampler(simple_dataset)), Iterator)) # True
# BatchSampler 和 RandomSampler 内部实现结构一样,结果也是一样
print(isinstance(RandomSampler(simple_dataset), Iterable)) # True
print(isinstance(RandomSampler(simple_dataset), Iterator)) # False
print(isinstance(iter(RandomSampler(simple_dataset)), Iterator)) # True
print(isinstance(dataloader, Iterator)) # True
在 DataLoader 中主要涉及3个类,其内部实例传递关系如下:
由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。
需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。
2.2 DataLoader 运行流程
最简单版本 DataLoader,具备如下功能:
-
Dataset 内部返回需要是 numpy 或者 tensor 对象
-
Sampler 直接 SequentialSampler 和 RandomSampler
-
BatchSampler 已经实现
-
collate_fn 仅仅考虑了 numpy 或者 tensor 对象
-
仅仅支持 num_works=0 即单进程
看起来功能非常单一,但是其实已经搭建起了整个框架,理解了这个最简框架才能去理解高级实现,其核心运行逻辑为:
def __next__(self):
# 返回 batch 个索引
index = next(self.batch_sampler)
# 利用索引去取数据
data = [self.dataset[idx] for idx in index
# batch 维度聚合
data = self.collate_fn(data)
return data
然后为了方便大家理解,特意绘制了如下代码运行流程图:
还是那句话:一定要对第1小节内容非常熟悉,否则里面这么多迭代器、生成器的调用,可能会把你绕晕。详细代码描述如下:
-
`self.batch_sampler = iter(batch_sampler)`。在 DataLoader 的类初始化,需要得到 BatchSampler 的迭代器对象
-
`index = next(self.batch_sampler)`。对于每次迭代,DataLoader 对象首先会调用 BatchSampler 的迭代器进行下一次迭代,具体是调用 BatchSampler 对象的 `__iter__` 方法
-
而 BatchSampler 对象的 `__iter__` 方法实际上是需要依靠 Sampler 对象进行迭代输出索引,Sampler 对象也是一个迭代器,当迭代 `batch_size` 次后就可以得到 `batch_size` 个数据索引
-
`data = [self.dataset[idx] for idx in index]`。有了 batch 个索引就可以通过不断调用 dataset 的 `__getitem__` 方法返回数据对象,此时 data 就包含了 batch 个对象
-
`data = self.collate_fn(data)`。将 batch 个对象输入给聚合函数,在第0个维度也就是 batch 维度进行聚合,得到类似 (b,...) 的对象
-
不断重复1-5步,就可以不断的输出一个一个 batch 的数据了
以上就是完整流程,如果理解有困难,你可以先看下一小结的代码实现,然后再返回去理解。
2.3 最简V1版本源代码
(1) Dataset
class Dataset(object):
# 只要实现了 __getitem__ 方法就可以变成迭代器
def __getitem__(self, index):
raise NotImplementedError
# 用于获取数据集长度
def __len__(self):
raise NotImplementedError
(2) Sampler
class Sampler(object):
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class SequentialSampler(Sampler):
def __init__(self, data_source):
super(SequentialSampler, self).__init__(data_source)
self.data_source = data_source
def __iter__(self):
# 返回迭代器,不然无法 for .. in ..
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
class BatchSampler(Sampler):
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
(3) DataLoader
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, collate_fn=None, drop_last=False)
self.dataset = dataset
# 因为这两个功能是冲突的,假设 shuffle=True,
# 但是 sampler 里面是 SequentialSampler,那么就违背设计思想了
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
'shuffle')
if batch_sampler is not None:
# 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
# 和 drop_last 四个参数就不能传入
# 因为这4个参数功能和 batch_sampler 功能冲突了
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
batch_size = None
drop_last = False
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
# 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
if batch_sampler is None:
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = iter(batch_sampler)
if collate_fn is None:
collate_fn = default_collate
self.collate_fn = collate_fn
# 核心代码
def __next__(self):
index = next(self.batch_sampler)
data = [self.dataset[idx] for idx in index]
data = self.collate_fn(data)
return data
# 返回自身,因为自身实现了 __next__
def __iter__(self):
return self
(4) collate_fn
def default_collate(batch):
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
return torch.stack(batch, 0)
elif elem_type.__module__ == 'numpy':
return default_collate([torch.as_tensor(b) for b in batch])
else:
raise NotImplementedError
(5) 调用完整例子
class SimpleV1Dataset(Dataset):
def __init__(self):
# 伪造数据
self.imgs = np.arange(0, 16).reshape(8, 2)
def __getitem__(self, index):
return self.imgs[index]
def __len__(self):
return self.imgs.shape[0]
from simplev1_datatset import SimpleV1Dataset
simple_dataset = SimpleV1Dataset()
dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)
for data in dataloader:
print(data)
3 总结
本文是最小 DataLoader 系列文章的第一篇,重点是分析了 python 中迭代器相关知识,然后构建一个最简单的 DataLoader 类,用于加深到 DataLoader 流程的理解,功能比较简单。
后面慢慢完善,希望最终能实现完整功能。
github: https://github.com/hhaAndroid/miniloader
推荐阅读