Suppose we have the following data:
from pyspark.sql import functions as F
df = sc.parallelize([(1, ["what", "sad", "to", "me"]),
                     (1, ["what", "what", "does", "the"]),
                     (2, ["at", "the", "oecd", "with"])]).toDF(["id", "message"])
df.show(truncate = False)
+---+-----------------------+
|id |message |
+---+-----------------------+
|1 |[what, sad, to, me] |
|1 |[what, what, does, the]|
|2 |[at, the, oecd, with] |
+---+-----------------------+
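This uses the older RDD entry point sc. If you are on Spark 2.x or later with a SparkSession in scope (assumed here to be named spark), an equivalent construction would be:
# equivalent construction via SparkSession (assumes a session object named `spark`)
df = spark.createDataFrame(
    [(1, ["what", "sad", "to", "me"]),
     (1, ["what", "what", "does", "the"]),
     (2, ["at", "the", "oecd", "with"])],
    ["id", "message"])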
We want to count, for each id, how many times each distinct word appears. There are two general approaches: aggregate first, then count, i.e. merge all the word arrays belonging to an id into a single array and count occurrences within it; or count first, then aggregate, i.e. explode the arrays into individual words, count each word per id, and then collect the per-word counts back into one array.
1 Aggregate first, then count
Following the first approach, start by gathering each id's word arrays together with the collect_list function:
df.groupBy("id").agg(F.collect_list("message").alias("message")).show(truncate = False)
The result:
+---+----------------------------------------------+
|id |message |
+---+----------------------------------------------+
|1 |[[what, sad, to, me], [what, what, does, the]]|
|2 |[[at, the, oecd, with]] |
+---+----------------------------------------------+
Next we need to flatten these arrays of arrays. A Python list comprehension over the sublists does the job. Note the explicit ArrayType(StringType()) return type on the udf; without it, the udf's default StringType return would stringify the list instead of keeping it as an array:
from pyspark.sql.types import ArrayType, StringType

unpack_udf = F.udf(lambda l: [item for sublist in l for item in sublist],
                   ArrayType(StringType()))
Given a flat list list_1, the usual comprehension is [item for item in list_1]. If list_2 is a nested list, [item for list_it in list_2 for item in list_it] flattens it, which reads as a double loop:
for list_it in list_2:
    for item in list_it:
        item
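A quick standalone check in plain Python (the sample lists are just for illustration):
list_2 = [["what", "sad", "to", "me"], ["what", "what", "does", "the"]]
flat = [item for list_it in list_2 for item in list_it]
print(flat)  # ['what', 'sad', 'to', 'me', 'what', 'what', 'does', 'the']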
This flattens the nested arrays into a single flat array:
df.groupBy("id").agg(F.collect_list("message").alias("message"))\
.withColumn("message", unpack_udf("message"))\
.show(truncate = False)
+---+------------------------------------------+
|id |message |
+---+------------------------------------------+
|1 |[what, sad, to, me, what, what, does, the]|
|2 |[at, the, oecd, with] |
+---+------------------------------------------+
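As an aside, on Spark 2.4 or later the built-in flatten function can replace the Python udf entirely, avoiding udf serialization overhead (a sketch, assuming Spark 2.4+):
# flatten() is a built-in function since Spark 2.4; no Python udf needed
df.groupBy("id") \
    .agg(F.flatten(F.collect_list("message")).alias("message")) \
    .show(truncate = False)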
Next, count how many times each word occurs in each id's array. The Counter class in Python's collections module does exactly this kind of counting; to use it in PySpark, wrap it in a udf whose schema describes the resulting (word, count) pairs:
from collections import Counter
from pyspark.sql.types import *
schema_count = ArrayType(StructType([
StructField("word", StringType(), False),
StructField("count", IntegerType(), False)]))
count_udf = F.udf(
lambda s: Counter(s).most_common(),
schema_count)
df.groupBy("id").agg(F.collect_list("message").alias("message"))\
.withColumn("message", unpack_udf("message"))\
.withColumn("message", count_udf("message"))\
.show(truncate = False)
The result:
+---+------------------------------------------------------------+
|id |message |
+---+------------------------------------------------------------+
|1 |[[what, 3], [sad, 1], [to, 1], [me, 1], [does, 1], [the, 1]]|
|2 |[[at, 1], [the, 1], [oecd, 1], [with, 1]] |
+---+------------------------------------------------------------+
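For reference, this is how Counter.most_common behaves on a plain Python list (standalone, outside Spark):
from collections import Counter

words = ["what", "sad", "to", "me", "what", "what", "does", "the"]
print(Counter(words).most_common())
# [('what', 3), ('sad', 1), ('to', 1), ('me', 1), ('does', 1), ('the', 1)]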
See the links at the end for more on Counter and udf.
2 Count first, then aggregate
First, explode the arrays into individual words:
df.withColumn("word", explode("message")).show()
+---+-----------------------+----+
|id |message |word|
+---+-----------------------+----+
|1 |[what, sad, to, me] |what|
|1 |[what, sad, to, me] |sad |
|1 |[what, sad, to, me] |to |
|1 |[what, sad, to, me] |me |
|1 |[what, what, does, the]|what|
|1 |[what, what, does, the]|what|
|1 |[what, what, does, the]|does|
|1 |[what, what, does, the]|the |
|2 |[at, the, oecd, with] |at |
|2 |[at, the, oecd, with] |the |
|2 |[at, the, oecd, with] |oecd|
|2 |[at, the, oecd, with] |with|
+---+-----------------------+----+
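Note that explode drops rows whose array is empty or null; explode_outer (in the Python API since Spark 2.3) keeps such rows with a null word. A sketch with a made-up empty-array row, assuming a SparkSession named spark:
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

# explicit schema, since the element type of an empty array cannot be inferred
schema = StructType([
    StructField("id", IntegerType()),
    StructField("message", ArrayType(StringType()))])
df_empty = spark.createDataFrame([(3, [])], schema)
df_empty.withColumn("word", F.explode("message")).show()        # no rows
df_empty.withColumn("word", F.explode_outer("message")).show()  # one row, word is null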
After exploding, count each word's occurrences per id:
df.withColumn("word", F.explode("message")).groupBy("id", "word").count().show(truncate = False)
+---+----+-----+
|id |word|count|
+---+----+-----+
|1 |me |1 |
|1 |what|3 |
|2 |at |1 |
|2 |oecd|1 |
|2 |with|1 |
|1 |sad |1 |
|1 |the |1 |
|1 |to |1 |
|1 |does|1 |
|2 |the |1 |
+---+----+-----+
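The .count() here is shorthand for an explicit aggregation; the following is equivalent (a sketch):
df.withColumn("word", F.explode("message")) \
    .groupBy("id", "word") \
    .agg(F.count("*").alias("count")) \
    .show(truncate = False)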
Then pack each word and its count into a single struct column:
df.withColumn("word", F.explode("message")).groupBy("id", "word").count().withColumn("message",F.struct("word","count")).show(truncate = False)
+---+----+-----+---------+
|id |word|count|message |
+---+----+-----+---------+
|1 |me |1 |[me, 1] |
|1 |what|3 |[what, 3]|
|2 |at |1 |[at, 1] |
|2 |oecd|1 |[oecd, 1]|
|2 |with|1 |[with, 1]|
|1 |sad |1 |[sad, 1] |
|1 |the |1 |[the, 1] |
|1 |to |1 |[to, 1] |
|1 |does|1 |[does, 1]|
|2 |the |1 |[the, 1] |
+---+----+-----+---------+
Finally, collect each id's (word, count) structs into one array:
df.withColumn("word", F.explode("message")) \
.groupBy("id", "word").count() \
.groupBy("id") \
.agg(F.collect_list(F.struct("word", "count")).alias("message"))\
.show(truncate = False)
+---+------------------------------------------------------------+
|id |message |
+---+------------------------------------------------------------+
|1 |[[me, 1], [what, 3], [sad, 1], [the, 1], [to, 1], [does, 1]]|
|2 |[[at, 1], [oecd, 1], [with, 1], [the, 1]] |
+---+------------------------------------------------------------+
The result is the same as the first method's (the ordering inside each array may differ). The complete code:
from pyspark.sql import functions as F
from collections import Counter
from pyspark.sql.types import *
df = sc.parallelize([(1, ["what", "sad", "to", "me"]),
                     (1, ["what", "what", "does", "the"]),
                     (2, ["at", "the", "oecd", "with"])]).toDF(["id", "message"])
df.show(truncate = False)
# Method 1: aggregate first, then count
unpack_udf = F.udf(
    lambda l: [item for sublist in l for item in sublist],
    ArrayType(StringType()))
schema_count = ArrayType(StructType([
    StructField("word", StringType(), False),
    StructField("count", IntegerType(), False)]))
count_udf = F.udf(
    lambda s: Counter(s).most_common(), schema_count)
df.groupBy("id").agg(F.collect_list("message").alias("message"))\
.withColumn("message", unpack_udf("message"))\
.withColumn("message", count_udf("message"))\
.show(truncate = False)
# Method 2: count first, then aggregate
df.withColumn("word", F.explode("message")) \
.groupBy("id", "word").count() \
.groupBy("id") \
.agg(F.collect_list(F.struct("word", "count")).alias("message"))\
.show(truncate = False)
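One caveat: collect_list makes no ordering guarantee, so the structs inside message can come back in any order, unlike Counter.most_common, which sorts by count. If a deterministic order matters, Spark 3.0+ accepts a comparator lambda in array_sort (a sketch, assuming Spark 3.0+):
# sort each id's (word, count) structs by count, descending
df.withColumn("word", F.explode("message")) \
    .groupBy("id", "word").count() \
    .groupBy("id") \
    .agg(F.collect_list(F.struct("word", "count")).alias("message")) \
    .withColumn("message", F.expr(
        "array_sort(message, (l, r) -> "
        "case when l.count > r.count then -1 "
        "when l.count < r.count then 1 else 0 end)")) \
    .show(truncate = False)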
References: