读python代码-学到的python函数-2-collate_fn


一段代码from 'datasets.py':

def collate_fn(batch):
    img, label = zip(*batch)
    for i, l in enumerate(label):
        if l.shape[0] > 0:
            l[:, 0] = i
    return torch.stack(img), torch.cat(label, 0)
各个变量的含义如下:
  • batch:DataLoader调用batch_size次TensorDataset类的__getitem__函数,获得的返回值放在一起就是batch。
  • img,label:zip(*xxx):利用 * 号操作符,可以将元组解压为列表。__getitem__函数会返回一张图片跟它的标签,图片形状是(通道数,高,宽),简写成(c, h, w),标签是个矩阵,形状是(目标个数,6)。每一个目标本来只有5个属性x, y, w, h, class_id,但是多预留一个属性,用来保存这个目标所在的图片在这一个批次中的索引。
  • 返回值:有俩。第一个是把图片拼起来,形状是(batch_size, c, h, w)。第二个是批次中所有的目标,是一个矩阵,形状是(所有目标数量,6)。

注:“(所有目标数量,6)”中的6,指的就是每个目标的属性:"x,y,w,h,class_id和该图片在这个batch中的索引"

这个函数的返回值就包含了这个批次图片中所有的目标。

1.zip函数

zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。

如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表

zip 语法:

zip([iterable, ...])

参数说明:

iterable -- 一个或多个迭代器;

返回值:

返回元组列表。

实例(Python 2.0+)

>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b)     # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
>>> zip(a,c)              # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> zip(*zipped)          # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
[(1, 2, 3), (4, 5, 6)]

实例(Python 3.0+)

>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b)     # 返回一个对象
>>> zipped
<zip object at 0x103abc288>
>>> list(zipped)  # list() 转换为列表
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a,c))              # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]

>>> a1, a2 = zip(*zip(a,b))          # 与 zip 相反,zip(*) 可理解为解压,返回二维矩阵式
>>> list(a1)
[1, 2, 3]
>>> list(a2)
[4, 5, 6]
>>>

2.enumerate() 函数

  • enumerate() 函数:用于将一个可迭代的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

再来看看enumerate() 函数的语法结构:

  • enumerate(sequence, [start=0]),其中sequence是一个可迭代序列,start是一个可选参数,表示序列下标的起始位置;
  •  enumerate()函数得到的是多个值,我们需要采用“序列解包”的方式,获取到每一个元素。

判断是否是可迭代对象:

from collections.abc import Iterable

print(isinstance("小甜甜", Iterable))
# True
print(isinstance([1,True,2.3],Iterable))
# True
print(isinstance(1, Iterable))
# False
print(isinstance(range(10), Iterable))
# True

 从中可以看出,列表、元组、字符串都是可迭代对象。

序列解包

通俗的说:就是一次将多个变量赋值给多个值。很简单,不要想的太高深,我们简单举个例子你就知道了。

x, y = (12, 54)
print(x) # 12
print(y) # 54

enumerate() 函数的简单使用

该函数最常就是配合for循环使用,我们就以此为例,为大家演示enumerate() 函数的用法。

需求:打印出班级中大于18岁的同学名字;

如果使用普通的for循环:

i = 0
name = ["张三", "李四", "王五"]
lis = [13, 22, 43]
for element in lis:
    if element >= 18:
        print(i, name[i], lis[i])
        i += 1

如果for循环,配合enumerate()函数使用:

name = ["张三", "李四", "王五"]
lis = [13, 22, 43]
for index, value in enumerate(lis):
    if value >= 18:
        print(index, name[index], value)

3.torch.stack()函数

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠

outputs = torch.stack(inputs, dim=0) → Tensor

1.参数

  • inputs : 待连接的张量序列。 注:python的序列数据只有listtuple
  • dim : 新的维度, 必须在0len(outputs)之间。 注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

2. 重点

  1. 函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

  举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

  2.dim是选择生成的维度,必须满足0<=dim<len(outputs)

  len(outputs)是输出后的tensor的维度大小

3. 例子

1.准备2个tensor数据,每个的shape都是[3,3]

 # 假设是时间步T1
 T1 = torch.tensor([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
 # 假设是时间步T2
 T2 = torch.tensor([[10, 20, 30],
                 [40, 50, 60],
                 [70, 80, 90]])

2.测试stack函数

print(torch.stack((T1,T2),dim=0).shape)
 print(torch.stack((T1,T2),dim=1).shape)
 print(torch.stack((T1,T2),dim=2).shape)
 print(torch.stack((T1,T2),dim=3).shape)
 # outputs:
 torch.Size([2, 3, 3])
 torch.Size([3, 2, 3])
 torch.Size([3, 3, 2])
 '选择的dim > len(outputs),所以报错'
 IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

4.torch.cat()函数

torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。

1.cat()函数

函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=0) → Tensor

2.参数

  • inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列。
  • dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列

3.重点

  • 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  • 维度不可以超过输入数据的任一个张量的维度。

准备数据x1,x2

x1:
 tensor([[11, 21, 31],
        [21, 31, 41]], dtype=torch.int32)
x1.shape:
 torch.Size([2, 3])

x2:
 tensor([[12, 22, 32],
        [22, 32, 42]])
x2.shape:
 torch.Size([2, 3])

 cat()处理

inputs = [x1, x2]

R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
 tensor([[11, 21, 31],
        [21, 31, 41],
        [12, 22, 32],
        [22, 32, 42]])
R0.shape:
 torch.Size([4, 3])
'''

R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
 tensor([[11, 21, 31, 12, 22, 32],
        [21, 31, 41, 22, 32, 42]])
R1.shape:
 torch.Size([2, 6])
'''

R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''

 

posted @ 2022-12-30 09:58  Yuxi001  阅读(120)  评论(0编辑  收藏  举报