PySpark UDF

17 Mar 2025 | 5 分钟阅读

Spark SQL 提供了 PySpark UDF (用户定义函数),用于定义新的基于列的函数。 它扩展了 Spark SQL 的 DSL 词汇表,用于转换数据集。

将函数注册为 UDF

我们可以选择设置 UDF 的返回类型。 默认返回类型是 StringType。 考虑以下示例

调用 UDF 函数

PySpark UDF 的功能与 pandas 的 map() 函数和 apply() 函数相同。 这些函数用于 panda 的 series 和 dataframe。 在下面的示例中,我们将创建一个 PySpark 数据帧。

代码将打印数据帧的模式和数据帧。

输出

root
 |-- integers: long (nullable = true)
 |-- floats: double (nullable = true)
 |-- integer_arrays: array (nullable = true)
 |    |-- element: long (containsNull = true)

+--------+------+--------------+
|integers|floats|integer_arrays|
+--------+------+--------------+
|       1|  -1.0|        [1, 2]|
|       2|   0.5|     [3, 4, 5]|
|       3|   2.7|  [6, 7, 8, 9]|
+--------+------+--------------+

评估顺序和空值检查

PySpark SQL 无法保证子表达式的评估顺序保持不变。 并非必须从左到右或以任何其他固定顺序评估运算符或函数的 Python 输入。 例如,逻辑 ANDOR 表达式没有从左到右的“短路”语义。

因此,依赖于布尔表达式的评估顺序非常不安全。 例如,WHEREHAVING 子句的顺序,因为此类表达式和子句可以在查询优化和规划期间重新排序。 如果 UDF 依赖于 SQL 中用于空值检查的短路语义(评估顺序),则不能保证空值检查会在调用 UDF 之前发生。

基本类型输出

让我们考虑一个对数字进行平方的函数 square(),并将此函数注册为 Spark UDF。

现在我们将其转换为 UDF。 在注册时,我们必须使用 pyspark.sql.types. 指定数据类型。 spark UDF 的问题在于它不将整数转换为浮点数,而 Python 函数适用于整数和浮点值。 如果输入数据类型与输出数据类型不匹配,则 PySpark UDF 将返回一列 NULL。 让我们考虑以下程序

输出

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|         null|
|       2|   0.5|          4|         null|
|       3|   2.7|          9|         null|
+--------+------+-----------+-------------+

从上面的输出可以看出,它为浮点数输入返回 null。 现在看看另一个例子。

使用浮点型输出注册 UDF

输出

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|       null|          1.0|
|       2|   0.5|       null|         0.25|
|       3|   2.7|       null|         7.29|
+--------+------+-----------+-------------+

使用 Python 函数指定浮点型输出

在这里,我们强制输出也为整数输入提供浮点数。

输出

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|        1.0|          1.0|
|       2|   0.5|        4.0|         0.25|
|       3|   2.7|        9.0|         7.29|
+--------+------+-----------+-------------+

复合类型输出

如果 Python 函数的输出是列表的形式,则输入值必须是列表,在注册 UDF 时使用 ArrayType() 指定。 考虑以下代码

输出

+--------------+------------------------+
|integer_arrays|(integer_arrays)|
+--------------+------------------------+
|        [1, 2]|              [1.0, 4.0]|
|     [3, 4, 5]|       [9.0, 16.0, 25.0]|
|  [6, 7, 8, 9]|    [36.0, 49.0, 64.0...|
+--------------+------------------------+

一些常见的 UDF 问题

  • Py4JJavaError

这是使用 UDF 时最常见的异常。 它来自 Python 和 Spark 之间的数据类型不匹配。 如果 Python 函数使用来自 Python 模块(如 numpy.ndarray)的数据类型,则 UDF 会抛出异常。

输出

+----------+
|int_arrays|
+----------+
| [1, 2, 3]|
| [4, 5, 6]|
+----------+

在下面的示例中,我们正在创建一个返回 nd.ndarray 的函数。 它们的值也是 Numpy 对象 Numpy.int32 而不是 Python 原语。

输出

array([1, 4, 9], dtype=int32)

如果我们执行以下代码,它将抛出异常 Py4JavaError。

输出

PySpark UDF

此类异常的解决方案是将其转换回其值为 Python 原语的列表。

输出

+----------+------------+
|int_arrays|     squared|
+----------+------------+
| [1, 2, 3]|   [1, 4, 9]|
| [4, 5, 6]|[16, 25, 36]|
+----------+------------+

在上面的代码中,我们描述了异常的解决方案。 现在自己做,观察两个程序之间的差异。

  • 缓慢

PySpark 还有另一个缺点; 与 Python 对应部分相比,它需要很长时间才能运行。 文件大小方面的小数据大小是导致速度慢的原因之一。 Spark 将整个数据帧发送给一个且仅一个执行器,并让其他执行器等待。 解决方案是重新划分数据帧。 例如

当我们重新划分数据时,每个执行器一次处理一个分区,从而减少了执行时间。


下一个主题PySpark RDD