读python代码-学到的python函数-2-collate_fn
一段代码from 'datasets.py':
def collate_fn(batch):
img, label = zip(*batch)
for i, l in enumerate(label):
if l.shape[0] > 0:
l[:, 0] = i
return torch.stack(img), torch.cat(label, 0)
-
batch
:DataLoader调用batch_size
次TensorDataset类的__getitem__函数,获得的返回值放在一起就是batch。
-
img,label
:zip(*xxx):利用 * 号操作符,可以将元组解压为列表。__getitem__函数会返回一张图片跟它的标签,图片形状是(通道数,高,宽),简写成(c, h, w),标签是个矩阵,形状是(目标个数,6)。每一个目标本来只有5个属性x, y, w, h, class_id,但是多预留一个属性,用来保存这个目标所在的图片在这一个批次中的索引。 -
返回值:有俩。第一个是把图片拼起来,形状是(batch_size, c, h, w)。第二个是批次中所有的目标,是一个矩阵,形状是(所有目标数量,6)。
注:“(所有目标数量,6)”中的6,指的就是每个目标的属性:"x,y,w,h,class_id和该图片在这个batch中的索引"
1.zip函数
zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。
zip 语法:
zip([iterable, ...])
参数说明:
iterable -- 一个或多个迭代器;
返回值:
返回元组列表。
实例(Python 2.0+)
>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b) # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
>>> zip(a,c) # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> zip(*zipped) # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
[(1, 2, 3), (4, 5, 6)]
实例(Python 3.0+)
>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b) # 返回一个对象
>>> zipped
<zip object at 0x103abc288>
>>> list(zipped) # list() 转换为列表
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a,c)) # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> a1, a2 = zip(*zip(a,b)) # 与 zip 相反,zip(*) 可理解为解压,返回二维矩阵式
>>> list(a1)
[1, 2, 3]
>>> list(a2)
[4, 5, 6]
>>>
2.enumerate() 函数
- enumerate() 函数:用于将一个可迭代的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
再来看看enumerate() 函数的语法结构:
- enumerate(sequence, [start=0]),其中sequence是一个可迭代序列,start是一个可选参数,表示序列下标的起始位置;
- enumerate()函数得到的是多个值,我们需要采用“序列解包”的方式,获取到每一个元素。
判断是否是可迭代对象:
from collections.abc import Iterable
print(isinstance("小甜甜", Iterable))
# True
print(isinstance([1,True,2.3],Iterable))
# True
print(isinstance(1, Iterable))
# False
print(isinstance(range(10), Iterable))
# True
从中可以看出,列表、元组、字符串都是可迭代对象。
序列解包
通俗的说:就是一次将多个变量赋值给多个值。很简单,不要想的太高深,我们简单举个例子你就知道了。
x, y = (12, 54)
print(x) # 12
print(y) # 54
enumerate() 函数的简单使用
该函数最常就是配合for循环使用,我们就以此为例,为大家演示enumerate() 函数的用法。
需求:打印出班级中大于18岁的同学名字;
如果使用普通的for循环:
i = 0
name = ["张三", "李四", "王五"]
lis = [13, 22, 43]
for element in lis:
if element >= 18:
print(i, name[i], lis[i])
i += 1
如果for循环,配合enumerate()函数使用:
name = ["张三", "李四", "王五"]
lis = [13, 22, 43]
for index, value in enumerate(lis):
if value >= 18:
print(index, name[index], value)
3.torch.stack()函数
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=0) → Tensor
1.参数
- inputs : 待连接的张量序列。 注:python的序列数据只有
list
和tuple
。 - dim : 新的维度, 必须在
0
到len(outputs)
之间。 注:len(outputs)
是生成数据的维度大小,也就是outputs
的维度值。
2. 重点
- 函数中的输入
inputs
只允许是序列;且序列内部的张量元素,必须shape
相等
举例:[tensor_1, tensor_2,..]
或者(tensor_1, tensor_2,..)
,且必须tensor_1.shape == tensor_2.shape
2.dim
是选择生成的维度,必须满足0<=dim<len(outputs)
;
len(outputs)
是输出后的tensor
的维度大小
3. 例子
1.准备2个tensor
数据,每个的shape
都是[3,3]
# 假设是时间步T1
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
2.测试stack函数
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'选择的dim > len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
4.torch.cat()函数
torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。
1.cat()函数
函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
2.参数
- inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列。
- dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。
3.重点
- 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
- 维度不可以超过输入数据的任一个张量的维度。
准备数据x1,x2
x1:
tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32)
x1.shape:
torch.Size([2, 3])
x2:
tensor([[12, 22, 32],
[22, 32, 42]])
x2.shape:
torch.Size([2, 3])
cat()处理
inputs = [x1, x2]
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
tensor([[11, 21, 31],
[21, 31, 41],
[12, 22, 32],
[22, 32, 42]])
R0.shape:
torch.Size([4, 3])
'''
R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
tensor([[11, 21, 31, 12, 22, 32],
[21, 31, 41, 22, 32, 42]])
R1.shape:
torch.Size([2, 6])
'''
R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''