Python:变长和定长序列拆分
1 导引
Python中的任何序列(可迭代的对象)都可以通过赋值操作进行拆分,包括但不限于元组、列表、字符串、文件、迭代器、生成器等。
2 元组拆分
元组拆分是最为常见的一种拆分,示例如下:
p = (4, 5)
x, y = p
print(x, y) # 4 5
如果写成
x, y, z = p
那么就会抛出ValueError
异常:"not enough values to unpack (expected 3, got 2)"
。
如果写成
p = (4, 5, 6)
x, y = p
那么就会抛出ValueError
异常:"too many values to unpack (expected 2)"
。
元组拆包无处不在,比如我们知道Python的zip()
函数相当于返回一个元组列表的迭代器,我们可以对该元组列表进行迭代拆分:
list_a = [1, 3, 5]
list_b = [2, 4, 6]
print(list(zip(list_a, list_b)))
#[(1, 2), (3, 4), (5, 6)]
for a, b in zip(list_a, list_b):
print("%d - %d" % (a, b))
# 1 - 2
# 3 - 4
# 5 - 6
上面的迭代语句其实隐式等价于for (a, b) in zip(list_a, list_b)
。
接下来容易出错的地方来了,很多时候我们会将附加索引下标的上述迭代错误地这样写:
for idx, a, b in enumerate(zip(list_a, list_b)):
print("idx%d: %d - %d" % (idx, a, b))
此时就会抛出ValueError
异常:"not enough values to unpack (expected 3, got 2)"
。我们经过上面的讨论知道,这是不正确地元组拆包所致。
原来,迭代enumerate(zip(list_a, list_b))
实际等价于迭代[(0, (1, 2)), (1, (3, 4)), (2, (5, 6))]
:
print(list(enumerate(zip(list_a, list_b))))
# [(0, (1, 2)), (1, (3, 4)), (2, (5, 6))]
对其迭代需要进行两次复合的元组拆包,即:
for idx, (a, b) in enumerate(zip(list_a, list_b)):
print("idx%d: %d - %d" % (idx, a, b))
# idx0: 1 - 2
# idx1: 3 - 4
# idx2: 5 - 6
还是同样地,两次拆包有一次隐式省略,上述迭代语句隐式等价于for (idx, (a, b)) in enumerate(zip(list_a, list_b)):
。
这里值得一提的是,上面说了用zip()
函数+list()
函数可以让我们获得一个元组列表,该操作的在机器学习项目的场景下非常实用,因为我们已知一堆点的\(x\)坐标列表和\(y\)坐标列表,我们可以通过zip()
函数+list()
函数的形式获得\((x,y)\)坐标列表。然而,如果我们已知\((x,y)\)坐标列表,如何快速恢复出\(x\)坐标列表和\(y\)坐标列表呢?我们可以这样写:
points = [(1, 2), (3, 4), (5, 6)]
x, y = zip(*points)
print(x) # (1, 3, 5)
print(y) # (2, 4, 6)
这里有个*
运算符读者可能感到陌生,这表示将points
列表中的所有元素以位置参数的形式传入zip()
函数(读者可以参见我的博客《Python:位置参数、关键字参数和接受任意数量的参数》),而zip(*points)
实际上等价于
zip((1, 2), (3, 4), (5, 6))
而我们前面说过,迭代上述zip()
函数返回的迭代器实质上等于迭代元组列表[(1, 3, 5), (2, 4, 6)]
。因为该元组列表只有两个元素,故我们可以直接对该列表进行拆包,于是得到了拆包结果(1, 3, 5)
和(2, 4, 6)
。从这个视角看,zip(*)
操作可以理解将二维数据沿纵向拆分成列向量。
这种写法有个巨大的应用场景就是处理机器学习的训练数据。比如,假设我们在做一个机器学习项目,有下列训练数据X
和训练数据Y
:
import numpy as np
X = np.random.rand(5, 3)
Y = np.random.randint(0, 2, size=(5, 1))
print(X)
# [[0.20447277 0.85066912 0.3331559 ]
# [0.78313617 0.78667579 0.17555529]
# [0.67388656 0.75179676 0.58292836]
# [0.12512522 0.5669724 0.45970325]
# [0.61955282 0.64029496 0.93385069]]
print(Y)
# [[0]
# [1]
# [1]
# [1]
# [0]]
我们接下来想不借助scikit-learn
库中的sklearn.utils.shuffle
函数,仅仅使用numpy
包和Python内置函数来优雅地完成对数据集的shuffle操作,那么该怎么做呢?首先,直接写
np.random.shuffle(X)
np.random.shuffle(Y)
是不行的,因为这样会丢失样本数据x
和y
的一一对应关系。事实上,我们可以先试用zip
函数将原始的X
和Y
数据转换成(x, y)
二元组组成的坐标列表:
x_y_pair = list(zip(X, Y))
print(x_y_pair)
# [(array([0.36742827, 0.02156507, 0.07500242]), array([1])),
# (array([0.6562936 , 0.7262091 , 0.50394983]), array([0])),
# (array([0.02043896, 0.08081809, 0.5199801 ]), array([0])),
# (array([0.87178023, 0.06728234, 0.54260044]), array([1])),
# (array([0.81271828, 0.50946797, 0.02489041]), array([1]))]
然后在此基础上进行shuffle:
np.random.shuffle(x_y_pair)
print(x_y_pair)
# [(array([0.6562936 , 0.7262091 , 0.50394983]), array([0])),
# (array([0.81271828, 0.50946797, 0.02489041]), array([1])),
# (array([0.02043896, 0.08081809, 0.5199801 ]), array([0])),
# (array([0.87178023, 0.06728234, 0.54260044]), array([1])),
# (array([0.36742827, 0.02156507, 0.07500242]), array([1]))]
然后,我们再借用zip(*)
和np.stack()
组合操作得到拼接完成的数据集:
X = np.stack(list(zip(*x_y_pair))[0])
Y = np.stack(list(zip(*x_y_pair))[1])
print(X)
# [[0.6562936 0.7262091 0.50394983]
# [0.81271828 0.50946797 0.02489041]
# [0.02043896 0.08081809 0.5199801 ]
# [0.87178023 0.06728234 0.54260044]
# [0.36742827 0.02156507 0.07500242]]
print(Y)
# [[0]
# [1]
# [0]
# [1]
# [1]]
正如我们前面所说的,这里list(zip(*x_y_pair))
沿纵向将x_y_pair
拆分成x
和y
这两部分,得到了一个由x
向量组成的元组(array([0.6562936,...]), array([0.81271828,...]), ..., array([0.36742827, ...))
和一个由y
构成的元组(array([0]), array([1]), ..., array([1]))
,然后我们再将由x
向量构成的元组和y
构成的元组进行stack操作,就还原了我们的X
和Y
数据
PS:这里
np.stack()
是对数据的堆叠(会增加一个额外维度),比如对一维的数据(shape为(n, )
)就是堆叠得到一个新的二维数据;而np.concatenate()
则是需要指定一个维度进行拼接(不会增加额外维度),对一维数据就是拼接得到一个新的一维数据:
import numpy as np
a = np.array([1, 2, 3])
b = np.array([3, 4, 5])
res1 = np.stack([a, b])
res2 = np.concatenate([a, b])
print(res1)
# [[1 2 3]
# [3 4 5]]
print(res2)
# [1 2 3 3 4 5]
下面是在多维数据(shape为
(n, 1)
或(n, m
)下的情况:
c = np.array([[1, 2, 3]])
d = np.array([[4, 5, 6]])
res3 = np.stack([c, d])
res4 = np.concatenate([c, d])
print(res3)
# [[[1 2 3]]
# [[4 5 6]]]
print(res4)
# [[1 2 3]
# [4 5 6]]
e = np.array([[1, 2], [3, 4]])
f = np.array([[5, 6], [7, 8]])
res5 = np.stack([e, f])
res6 = np.concatenate([e, f])
print(res5)
# [[[1 2]
# [3 4]]
# [[5 6]
# [7 8]]]
print(res6)
# [[1 2]
# [3 4]
# [5 6]
# [7 8]]
可见和一维的情况一样,
np.stack
会增加额外的维度,np.concatenate()
则不会。此二者都是默认axis=0
,即沿着维度0
的方向堆叠/拼接。
好了,现在言归正传,回到我们关于Python元组拆分的讨论。其实,Python中所谓函数能返回多个值,其实是返回的元组,如下面这种所示:
def func():
return 1, 2, 3
实际上等同于返回(1, 2, 3)
元组。我们可以选择直接接收该元组对象:
my_tuple = func()
print(my_tuple) # (1, 2, 3)
注意,上面这个代码中my_tuple
为一个引用,引用在函数体内部创建的元组对象(如对此有疑问,可参见我的博客《Python对象模型与序列迭代陷阱 》)。
当然,也可以将元组拆包接收:
a, b, c = func()
print(a, b, c) # 1 2 3
但是注意,如果要拆包必须要保证拆包正确,像下面这种写法:
a, b = func()
无疑就会抛出ValueError
异常:"too many values to unpack (expected 2)"
了。
3 字符串拆分
字符串的拆分示意如下:
s = 'Hello'
a, b, c, d, e = s
print(a) # H
4 拆分时丢弃值
如果在拆分时想丢弃某些特定的值,可以用一个用不到的变量名来作为丢弃值的名称(常选_
做为变量名),如下所示:
s = 'Hello'
a, b, _, d, _ = s
print(a) # H
5 嵌套序列拆分
Python也提供简洁的对嵌套序列进行拆分的语法。如下所示我们对一个比较复杂的异质列表进行拆分:
data = ['zhy', 50, 123.0, (2000, 12, 21)]
name, shares, price, (year, month, day) = data
print(year) # 2000
如果你想完整地得到(2000, 12, 21)
这个表示时间戳的元组,那么你就得这样写:
data = ['zhy', 50, 123.0, (2000, 12, 21)]
name, shares, price, date = data
print(date) # (2000, 12, 21)
6 从任意长度的可迭代对象中拆分
之前我们说过,如果我们想从可迭代对象中分解出\(N\)个元素,但如果这个可迭代对象长度超过\(N\),则会抛出异常"too many values to unpack"
。针对这个问题的解决方案是采用*
表达式。
比如我们给定学生的分数,想去掉一个最高分和一个最低分,然后对剩下的学生求平均分,我们可以这样写:
def avg(data: list):
return sum(data)/len(data)
# 去掉最高分,最低分然后做均分统计
def drop_first_last(grades):
first, *middle, last = grades
return avg(middle)
print(drop_first_last([1,2,3,4])) # 2.5
还有一种情况是有一些用户记录,记录由姓名+电子邮件+任意数量的电话号码组成,则我们可以这样分解用户记录:
record = ['zhy', 'zhy1056692290@qq.com', '773-556234', '774-223333']
name, email, *phone_numbers = record
print(phone_numbers) # ['773-556234', '774-223333']
事实上,如果电话号码为空也是合法的,此时phone_numbers为空列表。
record = ['zhy', 'zhy1056692290@qq.com']
name, email, *phone_numbers = record
print(phone_numbers) # []
还有一种使用情况则更为巧妙。如果我们需要遍历变长元组组成的列表,这些元组长度不一。那么此时*
表达式可大大简化我们的代码。
records = [('foo', 1, 2), ('bar', 'hello'), ('foo', 3, 4)]
for tag, *args in records:
if tag == 'bar':
print(args)
# ['hello']
在对一些复杂的字符串进行拆分时,*
表达式也显得特别有用。
line = "nobody:*:-2:-2:-2:Unprivileged User:/var/empty:/usr/bin/false"
uname, *fields, home_dir, sh = line.split(':')
print(home_dir) # /var/empty
*
表达式也可以和我们前面说的嵌套拆分和变量丢弃一起结合使用。
record = ['ACME', 50, 123.45, (128, 18, 2012)]
name, *_, (*_, year) = record
print(year) # 2012
最后再介绍*
表达式用于递归函数的一种黑魔法,比如与递归求和结合可以这样写:
items = [1, 10, 7, 4, 5, 9]
def sum(items):
head, *tail = items
return head + sum(tail) if tail else head
print(sum(items)) # 36
不过,Python由于自身递归栈的限制,并不擅长递归。我们最后一个递归的例子可以做为一种学术上的尝试,但不建议在实践中使用它。
参考
- [1] Beazley D, Jones B K. Python cookbook: Recipes for mastering Python 3[M]. " O'Reilly Media, Inc.", 2013.
- [2] https://stackoverflow.com/questions/16326853/enumerate-two-python-lists-simultaneously
- [3] https://www.python.org/