pandas_udf使用说明
摘要
Spark2.0 推出了一个新功能pandas_udf
,本文结合spark 官方文档和自己的使用情况,讲解pandas udf
的基本知识,并添加实例,方便初学的同学快速上手和理解。
Apche Arrow
ApacheArrow 是一种内存中的列式数据格式,用于在 Spark 中 JVM 和 Python 进程之间数据的高效传输。这对于使用 pandas/numpy 数据的 python 用户来说是最有利的。它的使用不是自动的,可能需要对配置或代码进行一些细微的更改,以充分利用并确保兼容性。
Apche Arrow 的安装
在pyspark
安装的时候,Apche Arrow
就已经安装了,可能安装的版本比较低,在你使用pandas udf
的时候会报如下的错误,
1
|
"it was not found." % minimum_pyarrow_version)
|
可以从报错信息中发现是Arrow
的版本过低了,可以通过pip install pyspark
进行安装或更新。
使用 Arrow 对 spark df 与 pandas df 的转换
Arrow
能够优化spark df
与pandas df
的相互转换,在调用Arrow
之前,需要将 spark 配置spark.sql.execution.arrow.enabled
设置为`true。这在默认情况下是禁用的。
此外,如果在 spark 实际计算之前发生错误,spark.sql.execution.arrow.enabled
启用的优化会自动回退到非 Arrow 优化实现。这可以有spark.sql.execution.arrow.fallback.enabled
来控制。
对 arrow 使用优化将产生与未启用 arrow 时相同的结果。但是,即使使用 arrow,toPandas()
也会将数据中的所有记录收集到驱动程序中,所以应该在数据的一小部分中使用。
目前,并非所有 spark 数据类型都受支持,如果列的类型不受支持,则可能会引发错误,请参阅受支持的 SQL 类型。如果在create dataframe()
期间发生错误,spark 将返回非 Arrow 优化实现的数据。
设置与转换
1
|
import numpy as np
|
Pandas UDFs (a.k.a Vectorized UDFs)
pandas udf
是用户定义的函数,是由 spark 用arrow
传输数据,pandas
去处理数据。我们可以使用pandas_udf
作为decorator
或者registor
来定义一个pandas udf
函数,不需要额外的配置。目前,pandas udf
有三种类型:标量映射(Scalar
)和分组映射(Grouped Map
)和分组聚合(Grouped Aggregate
)。
-
Scalar
其用于向量化标量操作。它们可以与
select
和withColumn
等函数一起使用。python 函数应该以pandas.series
作为输入,并返回一个长度相同的pandas.series
。在内部,spark 将通过将列拆分为batch
,并将每个batch
的函数作为数据的子集调用,然后将结果连接在一起,来执行 padas UDF。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType
# Declare the function and create the UDF
def (a, b):
return a * b
# or multiply = pandas_udf(multiply_func, returnType=LongType())
# Create a Spark DataFrame, 'spark' is an existing SparkSession
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))
# Execute function as a Spark vectorized UDF
df.select(multiply_func(col("x"), col("x"))).show()
# +-------------------+
# |multiply_func(x, x)|
# +-------------------+
# | 1|
# | 4|
# | 9|
# +-------------------+
-
Grouped Map
Grouped Map
pandas_udf
与groupBy().apply()
一起使用,后者实现了split-apply-combine
模式。拆分应用组合包括三个步骤:df.groupBy()
对数据分组apply()
对每个组进行操作,输入和输出都是 dataframe 格式- 汇总所有结果到一个 dataframe 中
使用
groupBy().apply()
,用户需要定义以下内容:- 一个函数,放在
apply()
里 - 一个输入输出的
schema
,两者必须相同
请注意,在应用函数之前,组的所有数据都将加载到内存中。这可能导致内存不足异常,尤其是当组的大小
skwed
的时候。maxRecordsPerBatch
不适用于这里。所以,用户需要来确保分组的数据适合可用内存。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23from pyspark.sql.functions import pandas_udf, PandasUDFType
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v"))
# or df.schema
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
# pdf is a pandas.DataFrame
v = pdf.v
return pdf.assign(v = v - v.mean())
df.groupby("id").apply(subtract_mean).show()
# +---+----+
# | id| v|
# +---+----+
# | 1|-0.5|
# | 1| 0.5|
# | 2|-3.0|
# | 2|-1.0|
# | 2| 4.0|
# +---+----+
-
Grouped Aggregate
其类似于 Spark 聚合函数。使用
groupby().agg()
和pyspark.sql.Window
一起使用。它定义从一个或多个pandas.series
到一个标量值的聚合,其中每个pandas.series
表示组中的一列或窗口。请注意,这种类型的 UDF 不支持部分聚合,组或窗口的所有数据都将加载到内存中。此外,这种类型只接受
unbounded window
,也就是说,我们不能定义window size
。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql import Window
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v"))
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
return v.mean()
df.groupby("id").agg(mean_udf(df['v'])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# | 1| 1.5|
# | 2| 6.0|
# +---+-----------+
w = Window
.partitionBy('id')
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
# +---+----+------+
# | id| v|mean_v|
# +---+----+------+
# | 1| 1.0| 1.5|
# | 1| 2.0| 1.5|
# | 2| 3.0| 6.0|
# | 2| 5.0| 6.0|
# | 2|10.0| 6.0|
# +---+----+------+ -
结合使用
如果想用
agg()
的思想,又想定义window size
,我们可以用 Group Map,并在pandas udf function
中使用 pandas 的rolling()
来实现。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23from pyspark.sql.functions import lit, pandas_udf, PandasUDFType
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v"))
df = df.withColumn("mv", f.lit(0.))
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def moving_mean(pdf):
v = pdf.v
pdf['mv'] = v.rolling(3,1).mean()
return pdf
df.groupby("id").apply(moving_mean).show()
# +---+----+---+
# | id| v| mv|
# +---+----+---+
# | 1| 1.0|1.0|
# | 1| 2.0|1.5|
# | 2| 3.0|3.0|
# | 2| 5.0|4.0|
# | 2|10.0|6.0|
# +---+----+---+
其他使用说明
-
支持的 SQL 类型
目前,Arrow-base 的转换支持所有的 spark sql 数据类型,除了
MapType
,ArrayTpye
中的TimestampType
和nested StructType
。BinaryType
仅在 Arrow 版本大于等于0.10.0
时被支持。 -
设置
Arrow Batch Size
spark 中的数据分区被转换成 arrow 记录批处理,这会暂时导致 JVM 中的高内存使用率。为了避免可能的内存不足异常,可以通过 conf 的
spark.sql.execution.arrow.maxRecordsPerBatch
设置为一个整数来调整 Arrow 记录 batch 的大小,该整数将确定每个 batch 的最大行数。默认值为 10000 条记录。如果列数较大,则应相应调整该值。使用这个限制,每个数据分区将被制成一个或多个记录 batch 处理。 -
Timestamp 的时区问题
Spark 内部将
Timestamp
存储为 UTC 值,在没有指定时区的情况下引入的Timestamp
数据将转换为具有微秒分辨率的本地时间到 UTC。在 spark 中导出或显示Timestamp
数据时,会话时区用于本地化Timestamp
值。会话时区是使用配置spark.sql.session.time zone
设置的,如果不设置,则默认为 JVM 系统本地时区。pandas 使用具有纳秒分辨率的datetime64
类型,datetime64[ns]
,每个列上都有可选的时区。当
Timestamp
数据从 spark 传输到 pandas 时,它将被转换为纳秒,并且每一列将被转换为 spark 会话时区,然后本地化到该时区,该时区将删除时区并将值显示为本地时间。当使用Timestamp
列调用toPandas()
或pandas_udf
时,会发生这种情况。当
Timestamp
数据从 pandas 传输到 spark 时,它将转换为 UTC 微秒。当使用 pandas dataframe 调用CreateDataFrame
或从 pandas dataframe 返回Timestamp
时,会发生这种情况。这些转换是自动完成的,以确保 Spark 具有预期格式的数据,因此不需要自己进行这些转换。任何纳秒值都将被截断。请注意,标准 UDF(非 PANDAS)将以 python 日期时间对象的形式加载
Timestamp
数据,这与 PANDAS 的Timestamp
不同。建议使用 pandas 的Timestamp
时使用 pandas 的时间序列功能,以获得最佳性能,有关详细信息,请参阅此处。 -
Pandas udf 其他使用案例