带你从零掌握迭代器及构建最简DataLoader

0 摘要

    本文本意是写 pytorch 中 DataLoader 源码学习心得,但是发现自己对迭代器和生成器的掌握比较水,不够牢固,而我也没有搜到能够解决我所有疑问的解答文章,因此诞生了这篇文章。通过本文你将能够零基础深入掌握 python 迭代器相关知识、并且能够一步步理解 DataLoader 的实现原理以及背后涉及的设计模式

    本文最终目的是通过源码学习自己实现一个功能比较完善的 DataLoader 类,为了达到这个目的,本文写作流程是:

 

  • 先深入浅出分析 python 中迭代器、生成器等实现原理,包括 Iterable、Iterator、for .. in ..、__getitem__、yield 生成器 5个部分

  • 再实现了一个最简单版本的 DataLoader,目的是理解 DataLoader 与 Dataset、Sampler、BatchSampler和 collate_fn 之间的调用关系

  • 最后对该实现进行深入全面分析,读者可以清晰的理解每个类的作用

     

    但是 DataLoader 功能其实非常复杂,故本文属于系列文章的第一篇,后面文章会不断完善、调整,最终实现 DataLoader 所有功能。或者说本文是后续文章的基础,如果基础内容没有理解非常透彻,后面的多进程、分布式版本就更难以理解了。

    虽然本文比较简单,但是由于涉及到代码,故为了方便,有必要的读者可以 clone rep 进行学习(需要特意说明的是:rep 里面代码是学习目的的,质量不高,不要要求那么多)

 

github:  https://github.com/hhaAndroid/miniloader

 

由于本人水平有限,某些环节理解可能有偏颇,欢迎指正。手机对于代码显示效果不太好,建议电脑端阅读。

 

1 python 迭代器深入浅出理解

 

1.1 可迭代对象 Iterable

    可迭代对象 Iterable:表示该对象可迭代,其并不是指某种具体数据类型。简单来说只要是实现了 `__iter__` 方法的类就是可迭代对象

from collections.abc import Iterable, Iterator
class A(object):
    def __init__(self):
        self.a = [1, 2, 3]
    def __iter__(self):
        # 此处返回啥无所谓
        return self.a
cls_a = A()
#  True
print(isinstance(cls_a, Iterable))

    但是对象如果是 Iterable 的,看起来好像也没有特别大的用途,因为你依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口:

for a in cls_a:
    print(a)

# 程序报错,要理解这个错误的含义
TypeError: iter() returned non-iterator of type 'list'

    我们可以检查下 Iterable 接口:

class Iterable(metaclass=ABCMeta):

    # 如果实现了这个方法,那么就是 Iterable
    @abstractmethod
    def __iter__(self):
        while False:
            yield None

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Iterable:
            return _check_methods(C, "__iter__")
        return NotImplemented

看起来实现 Iterable 接口用途不大,其实不是的,其有很多用途的,例如简化代码等,在后面的高级语法糖中会频繁用到,后面会分析。

 

1.2 迭代器 Iterator

    迭代器 Iterator:其和 Iterable 之间是一个包含与被包含的关系,如果一个对象是迭代器 Iterator,那么这个对象肯定是可迭代 Iterable;但是反过来,如果一个对象是可迭代 Iterable,那么这个对象不一定是迭代器 Iterator,可以通过接口协议看出:

class Iterator(Iterable):

    # 迭代具体实现
    @abstractmethod
    def __next__(self):
        'Return the next item from the iterator. When exhausted, raise StopIteration'
        raise StopIteration

    # 返回自身,因为自身有 __next__ 方法(如果自身没有 __next__,那么返回自身没有意义)
    def __iter__(self):
        return self
        
    @classmethod
    def __subclasshook__(cls, C):
        if cls is Iterator:
            return _check_methods(C, '__iter__', '__next__')
        return NotImplemented

可以发现:实现了 `__next__` 和 `__iter__` 方法的类才能称为迭代器,就可以被 for 遍历了

class A(object):
    def __init__(self):
        self.index = -1
        self.a = [1, 2, 3]

    #必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
    #因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
    def __iter__(self):
        return self

    def __next__(self):
        self.index += 1
        if self.index < len(self.a):
            return self.a[self.index]
        else:
            #抛异常,for 内部会自动捕获,表示迭代完成
            raise StopIteration("遍历完了")
cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
print(isinstance(iter(cls_a), Iterator)) # True

for a in cls_a:
    print(a)
# 打印 1 2 3

再次明确,一个对象如果要是 Iterator ,那么必须要实现 `__next__` 和 `__iter__` 方法,但是要理解其内部迭代流程,还需要理解 for .. in .. 流程。

 

1.3 for .. in .. 本质流程

    for .. in .. 也就是常见的迭代操作了,其被 python 编译器编译后,实际上代码是:

# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
cls_a = iter(cls_a)
while True:
    try:
        # 然后调用对象的 __next__ 方法,不断返回元素
        value = next(cls_a)
        print(value)
    # 如果迭代完成,则捕获异常即可
    except StopIteration:
        break

 

可以看出,任何一个对象如果要能够被 for 遍历,必须要实现  `__iter__` 和 `__next__` 方法,缺一不可

    明白了上述流程,那么迭代器对象 A,我们可以采用如下方式进行遍历:

myiter = iter(cls_a)
print(next(myiter))
print(next(myiter))
print(next(myiter))
# 因为遍历完了,故此时会出现错误:StopIteration: 遍历完了
print(next(myiter))

我们再来思考 python 内置对象 list 为啥可以被迭代

b=list([1,2,3])
print(isinstance(b, Iterable)) # True
print(isinstance(b, Iterator)) # False

    可以发现 list 类型是可迭代对象,但是其不是迭代器(即 list 没有 `__next__` 方法),那为啥 for .. in .. 可以迭代呢?

    原因是 list 内部的 `__iter__` 方法内部返回了具备 `__next__` 方法的类,或者说调用 iter() 后返回的对象本身就是一个迭代器,当然可以 for 循环了

b=list([1,2,3])
print(dir(b)) # 可以发现其存在 __iter__ 方法,不存在 __next__

b=iter(b) # 调用 list 内部的 __iter__,返回了具备 __next__ 的对象
print(isinstance(b, Iterable)) # True
print(isinstance(b, Iterator)) # True
print(dir(b)) # 同时具备 __iter__ 和 __next__ 方法

基于上述理解我们可以对 A 类代码进行改造,使其更加简单:

class A(object):
    def __init__(self):
        self.a = [1, 2, 3]
    # 我们内部又调用了 list 对象的 __iter__ 方法,故此时返回的对象是迭代器对象
    def __iter__(self):
        return iter(self.a)

cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # False

for a in cls_a:
    print(a)
# 输出:1 2 3

    此时我们就实现了仅仅实现 Iterable 规范接口,但是又具备了 for .. in .. 功能,代码是不是比最开始的实现简单很多?这种写法应用也非常广泛,因为其不需要自己再次实现 `__next__` 方法。

    如果你想理解的更加透彻,那么可以看下面例子:

# 仅仅实现 __iter__ 
class A(object):
    def __init__(self):
        self.b = B()

    def __iter__(self):
        return self.b

# 仅仅实现 __next__
class B(object):
    def __init__(self):
        self.index = -1
        self.a = [1, 2, 3]

    def __next__(self):
        self.index += 1
        if self.index < len(self.a):
            return self.a[self.index]
        else:
            # 内部会自动捕获,表示迭代完成
            raise StopIteration("遍历完了")


cls_a = A()
cls_b = B()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # False
print(isinstance(cls_b, Iterable)) # False
print(isinstance(cls_b, Iterator)) # False

print(type(iter(cls_a))) # B 对象
print(isinstance(iter(cls_a), Iterator)) # False

for a in cls_a:
    print(a)

# 输出:1 2 3

    自此我们知道了:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 `__iter__` 和 `__next__` 方法(对象必然是 Iterator),还是只实现 `__iter__`(不是 Iterator),但是内部间接返回了具备 `__next__` 对象的类,都是可行的

    但是除了这两种实现,还有其他高级语法糖,可以进一步精简代码。

 

1.4  __ getitem__ 理解

    上面说过 for .. in .. 的本质就是调用对象的 `__iter__` 和 `__next__` 方法,但是有一种更加简单的写法,你通过仅仅实现 `__getitem__` 方法就可以让对象实现迭代功能。实际上任何一个类,如果实现了`__getitem__` 方法,那么当调用 iter(类实例) 时候会自动具备`__iter__` 和 `__next__`方法,从而可迭代了。

    通过下面例子可以看出,`__getitem__` 实际上是属于 __iter__` 和 `__next__` 方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。

class A(object):
    def __init__(self):
        self.a = [1, 2, 3]

    def __getitem__(self, item):
        return self.a[item]
        
cls_a = A()
print(isinstance(cls_a, Iterable))  # False
print(isinstance(cls_a, Iterator))  # False
print(dir(cls_a))  # 仅仅具备 __getitem__ 方法

cls_a = iter(cls_a)
print(dir(cls_a))  # 具备 __iter__ 和 __next__ 方法

print(isinstance(cls_a, Iterable))  # True
print(isinstance(cls_a, Iterator))  # True

# 等价于 for .. in ..
while True:
    try:
        # 然后调用对象的 __next__ 方法,不断返回元素
        value = next(cls_a)
        print(value)
    # 如果迭代完成,则捕获异常即可
    except StopIteration:
        break

# 输出:1 2 3

而且 `__getitem__` 还可以通过索引直接访问元素,非常方便

a[0] # 1  
a[4] # 错误,索引越界

如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 `__len__` 方法即可

class A(object):
    def __init__(self):
        self.a = [1, 2, 3]

    def __getitem__(self, item):
        return self.a[item]

    def __len__(self):
        return len(self.a)

cls_a = A()  
print(len(cls_a)) # 3

    到目前为止,我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

 

1.5 yield 生成器

    生成器是一个在行为上和迭代器非常类似的对象,二者功能上差不多,但是生成器更优雅,只需要用关键字 yield 来返回,作用于函数上叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质就是迭代器,其最大特点是代码简洁。

def func():
    for a in [1, 2, 3]:
        yield a

cls_g = func()
print(isinstance(cls_g, Iterator))  # True
print(dir(cls_g))  # 自动具备 __iter__ 和 __next__ 方法

for a in cls_g:
    print(a)

# 输出: 1 2 3

# 一种更简单的写法是用 ()
cls_g = (i for i in [1,2,3])

    直观感觉和 `__getitem__` 一样,也是高级语法糖,但是比 `__getitem__` 更加简单,更加好用。

    使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。假设创建一个包含10万个元素的列表,如果用 list 返回不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了,这种场景就适合采用生成器,在迭代过程中推算出后续元素,而不需要一次性全部算出。

 

1.6 小结

图片

  •  list set dict等内置对象都是容器 container 对象,容器是一种把多个元素组织在一起的数据结构,可以逐个迭代获取其中的元素。容器可以用 in 来判断容器中是否包含某个元素。大多数容器都是可迭代对象,可以使用某种方式访问容器中的每一个元素。

  • 在迭代对象基础上,如果实现了 `__next__`  方法则是迭代器对象,该对象在调用 next()  的时候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

  • 对于采用语法糖 `__getitem__` 实现的迭代器对象,其本身实例既不是可迭代对象,更不是迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

  • 生成器是一种特殊迭代器,但是不需要像迭代器一样实现`__iter__`和`__next__`方法,只需要使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。

  • 对于在类的 `__iter__` 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成迭代器。

     

2 DataLoader 最简版本 V1

    这里说的最简版本是指:没有任何花哨、高级实现技巧,仅仅以实现最基础功能为目的。具体来说是包括必备的5个对象:Dataset、Sampler、BatchSampler、DataLoader 和 collate_fn。其作用可以简要描述为如下:

 

  • Dataset 提供整个数据集的随机访问功能,每次调用都返回单个对象,例如一张图片和对应 target 等等

  • Sampler 提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引,常用子类是 SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

  • BatchSampler 内部调用 Sampler 实例,输出指定 `batch_size` 个索引,然后将索引作用于 Dataset 上从而输出 `batch_size` 个数据对象,例如 batch 张图片和 batch 个 target

  • collate_fn 用于将 batch 个数据对象在 batch 维度进行聚合,生成 (b,...) 格式的数据输出,如果待聚合对象是 numpy,则会自动转化为 tensor,此时就可以输入到网络中了

 

迭代一次伪代码如下(非迭代器版本):

class DataLoader(object):
    def __init__(self):
        # 假设数据长度是100,batch_size 是4
        self.dataset = [[img0, target0], [img1, target1], ..., [img99, target99]]
        # 假设 sampler 是 SequentialSampler,那么实际上就是 [0,1,...,99] 列表而已
        # 如果 sampler 是 RandomSampler,那么可能是 [30,1,34,2,6,...,0] 列表
        self.sampler = [0, 1, 2, 3, 4, ..., 99]
        self.batch_size = 4
        self.index = 0

    def collate_fn(self, data):
        # batch 维度聚合数据
        batch_img = torch.stack(data[0], 0)
        batch_target = torch.stack(data[1], 0)
        return batch_img, batch_target

    def __next__(self):
        # 0.batch_index 输出,实际上就是 BatchSampler 做的事情
        i = 0
        batch_index = []
        while i < self.batch_size:
            # 内部会调用 sampler 对象取单个索引
            batch_index.append(self.sampler[self.index])
            self.index += 1
            i += 1

        # 1.得到 batch 个数据了,调用 dataset 对象 
        data = [self.dataset[idx] for idx in batch_index]

        # 2. 调用 collate_fn 在 batch 维度拼接输出
        batch_data = self.collate_fn(data)
        return batch_data

    def __iter__(self):
        return self

    以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

 

2.1 整体对象理解

    首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器,你必须要先理解第一小节内容,否则本节内容会比较难理解,具体为:

 

  •  Dataset 通过实现 `__getitem__` 方法使其可迭代

  •  Sampler 对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 `__iter__` 内部返回迭代器,RandomSampler 在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  BatchSampler 也是在 `__iter__` 内部通过 yield 关键字返回迭代器

  •  DataLoader 通过直接实现 `__next__` 和 `__iter__` 变成迭代器

 

    注意除了 DataLoader 本身是迭代器外,其余对象本身不是迭代器,但是都能被 for .. in .. 迭代。下面一个简单例子证明:

from simplev1_datatset import SimpleV1Dataset
from libv1 import SequentialSampler, RandomSampler 
from collections import Iterator, Iterable   

simple_dataset = SimpleV1Dataset()   
dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)

print(isinstance(simple_dataset, Iterable)) # False
print(isinstance(simple_dataset, Iterator)) # False
print(isinstance(iter(simple_dataset), Iterator)) # True

print(isinstance(SequentialSampler(simple_dataset), Iterable)) # True
print(isinstance(SequentialSampler(simple_dataset), Iterator)) # False
print(isinstance(iter(SequentialSampler(simple_dataset)), Iterator)) # True

# BatchSampler 和 RandomSampler 内部实现结构一样,结果也是一样
print(isinstance(RandomSampler(simple_dataset), Iterable)) # True
print(isinstance(RandomSampler(simple_dataset), Iterator)) # False
print(isinstance(iter(RandomSampler(simple_dataset)), Iterator)) # True

print(isinstance(dataloader, Iterator)) # True

    在 DataLoader 中主要涉及3个类,其内部实例传递关系如下:

图片

    由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

    需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

 

2.2 DataLoader 运行流程

    最简单版本 DataLoader,具备如下功能:

 

  • Dataset 内部返回需要是 numpy 或者 tensor 对象

  • Sampler 直接 SequentialSampler 和 RandomSampler

  • BatchSampler 已经实现

  • collate_fn 仅仅考虑了 numpy 或者 tensor 对象

  • 仅仅支持 num_works=0 即单进程

 

看起来功能非常单一,但是其实已经搭建起了整个框架,理解了这个最简框架才能去理解高级实现,其核心运行逻辑为:

def __next__(self):
    # 返回 batch 个索引
    index = next(self.batch_sampler)
    # 利用索引去取数据
    data = [self.dataset[idx] for idx in index
    # batch 维度聚合
    data = self.collate_fn(data)
    return data

然后为了方便大家理解,特意绘制了如下代码运行流程图:

图片

    还是那句话:一定要对第1小节内容非常熟悉,否则里面这么多迭代器、生成器的调用,可能会把你绕晕。详细代码描述如下:

 

  1. `self.batch_sampler = iter(batch_sampler)`。在 DataLoader 的类初始化,需要得到 BatchSampler 的迭代器对象

  2. `index = next(self.batch_sampler)`。对于每次迭代,DataLoader 对象首先会调用 BatchSampler 的迭代器进行下一次迭代,具体是调用 BatchSampler 对象的  `__iter__`  方法

  3. 而 BatchSampler 对象的 `__iter__` 方法实际上是需要依靠 Sampler 对象进行迭代输出索引,Sampler 对象也是一个迭代器,当迭代 `batch_size` 次后就可以得到 `batch_size` 个数据索引

  4. `data = [self.dataset[idx] for idx in index]`。有了 batch 个索引就可以通过不断调用  dataset 的 `__getitem__` 方法返回数据对象,此时 data 就包含了 batch 个对象

  5. `data = self.collate_fn(data)`。将 batch 个对象输入给聚合函数,在第0个维度也就是 batch 维度进行聚合,得到类似 (b,...) 的对象

  6. 不断重复1-5步,就可以不断的输出一个一个 batch 的数据了

 

以上就是完整流程,如果理解有困难,你可以先看下一小结的代码实现,然后再返回去理解

 

2.3 最简V1版本源代码

 

(1) Dataset

class Dataset(object):
    # 只要实现了 __getitem__ 方法就可以变成迭代器
    def __getitem__(self, index):  
        raise NotImplementedError 
    # 用于获取数据集长度 
    def __len__(self): 
        raise NotImplementedError

(2) Sampler

class Sampler(object):
    def __init__(self, data_source):
        pass  

    def __iter__(self):  
        raise NotImplementedError 

    def __len__(self):  
        raise NotImplementedError

 

class SequentialSampler(Sampler): 

    def __init__(self, data_source):  
        super(SequentialSampler, self).__init__(data_source) 
        self.data_source = data_source 

    def __iter__(self): 
        # 返回迭代器,不然无法 for .. in .. 
        return iter(range(len(self.data_source))) 

    def __len__(self): 
        return len(self.data_source)
class BatchSampler(Sampler):  
    def __init__(self, sampler, batch_size, drop_last): 
        self.sampler = sampler 
        self.batch_size = batch_size 
        self.drop_last = drop_last 

    def __iter__(self): 
        batch = [] 
        # 调用 sampler 内部的迭代器对象 
        for idx in self.sampler: 
            batch.append(idx) 
            # 如果已经得到了 batch 个 索引,则可以通过 yield 
            # 关键字生成生成器返回,得到迭代器对象 
            if len(batch) == self.batch_size: 
                yield batch     
                batch = [] 
        if len(batch) > 0 and not self.drop_last: 
            yield batch  

    def __len__(self): 
        if self.drop_last: 
            # 如果最后的索引数不够一个 batch,则抛弃 
            return len(self.sampler) // self.batch_size 
        else:  
            return  (len(self.sampler) + self.batch_size - 1// self.batch_size

(3) DataLoader

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, collate_fn=None, drop_last=False)
        self.dataset = dataset    

        # 因为这两个功能是冲突的,假设 shuffle=True,
        # 但是 sampler 里面是 SequentialSampler,那么就违背设计思想了
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
            # 和 drop_last 四个参数就不能传入
            # 因为这4个参数功能和 batch_sampler 功能冲突了
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False

        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)

        # 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
        if batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = iter(batch_sampler)

        if collate_fn is None:
            collate_fn = default_collate
        self.collate_fn = collate_fn

    # 核心代码
    def __next__(self):
        index = next(self.batch_sampler)
        data = [self.dataset[idx] for idx in index]
        data = self.collate_fn(data)
        return data

    # 返回自身,因为自身实现了 __next__ 
    def __iter__(self):
        return self

(4) collate_fn

def default_collate(batch): 
    elem = batch[0]  
    elem_type = type(elem)  
    if isinstance(elem, torch.Tensor):  
        return torch.stack(batch, 0)  
    elif elem_type.__module__ == 'numpy':  
        return default_collate([torch.as_tensor(b) for b in batch])  
    else:  
        raise NotImplementedError

(5) 调用完整例子

class SimpleV1Dataset(Dataset): 
    def __init__(self):  
        # 伪造数据  
        self.imgs = np.arange(0, 16).reshape(8, 2)  

    def __getitem__(self, index):  
        return self.imgs[index]  

    def __len__(self):  
        return self.imgs.shape[0]  


from simplev1_datatset import SimpleV1Dataset  
simple_dataset = SimpleV1Dataset()    
dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)
for data in dataloader:  
    print(data)

 

3 总结

    本文是最小 DataLoader 系列文章的第一篇,重点是分析了 python 中迭代器相关知识,然后构建一个最简单的 DataLoader 类,用于加深到 DataLoader 流程的理解,功能比较简单。

    后面慢慢完善,希望最终能实现完整功能。

 

github: https://github.com/hhaAndroid/miniloader

 

推荐阅读

PyTorch 源码解读之 torch.autograd

PyTorch 源码解读之 BN & SyncBN

 

posted @   水木清扬  阅读(570)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
历史上的今天:
2018-12-30 ubuntu 安装Pangolin 过程
2018-12-30 ubuntu16.04 + Kdevelop + ROS开发和创建catkin_ws工作空间
点击右上角即可分享
微信分享提示