目录
Shuffle 核心概念
什么是 Shuffle?
Shuffle = 数据重新洗牌,跨节点传输
节点1: [A,B,C,A,B] 节点1: [A,A,A,A]
节点2: [B,C,A,C,B] → 节点2: [B,B,B,B]
节点3: [C,A,B,C,A] 节点3: [C,C,C,C]
相同key聚到同一节点
代价: 磁盘IO + 网络IO + 序列化/反序列化
触发 Shuffle 的操作
✅ 必然触发
// 聚合类
reduceByKey(_ + _)
groupByKey()
aggregateByKey()
combineByKey()
// 排序类
sortByKey()
sortBy()
// 连接类
join()
leftOuterJoin()
rightOuterJoin()
cogroup()
// 去重类
distinct()
// 重分区
repartition()
partitionBy()
❌ 不触发
map()
filter()
flatMap()
mapPartitions()
union()
coalesce(n, shuffle=false) // 减少分区不Shuffle
Shuffle 执行流程
三阶段模型
Stage N (Map阶段)
↓
[Shuffle Write] 写磁盘
↓
[Shuffle Transfer] 网络传输
↓
[Shuffle Read] 读数据
↓
Stage N+1 (Reduce阶段)
详细执行流程
以 reduceByKey 为例:
rdd.map(x => (x.key, x.value))
.reduceByKey(_ + _)
1. Map端(Shuffle Write)
Task 1:
数据: [("a",1), ("b",2), ("a",3)]
↓
按key分桶 (Hash分区)
↓
桶0: [("a",1), ("a",3)] → 写文件 shuffle_0_0
桶1: [("b",2)] → 写文件 shuffle_0_1
关键优化: Map端预聚合
// reduceByKey 在Map端先聚合
("a",1), ("a",3) → ("a",4) // 减少网络传输量
2. Shuffle传输
Map端磁盘文件
↓ 网络拉取
Reduce端内存
元数据管理: MapOutputTracker 记录文件位置
Task 1 → shuffle_0_0 在 节点A:/path/file1
Task 2 → shuffle_1_0 在 节点B:/path/file2
3. Reduce端(Shuffle Read)
Reduce Task 0:
拉取所有 桶0 的数据
↓
[("a",4), ("a",2), ("a",5)]
↓
最终聚合 → ("a", 11)
Shuffle 优化参数
内存配置
// Shuffle内存占比(已废弃,使用统一内存管理)
spark.shuffle.memoryFraction = 0.2
// Map端聚合缓冲区
spark.shuffle.file.buffer = 32k
// Reduce端拉取缓冲区
spark.reducer.maxSizeInFlight = 48m
// Reduce端聚合内存
spark.shuffle.spill.compress = true
并行度配置
// 全局默认并行度
spark.default.parallelism = 200
// SQL Shuffle分区数(重要!)
spark.sql.shuffle.partitions = 200
// 代码中指定
rdd.reduceByKey(_ + _, 300) // 300个分区
推荐值: CPU核心总数 × 2~4
压缩配置
// 启用Shuffle压缩
spark.shuffle.compress = true
spark.shuffle.spill.compress = true
// 压缩算法选择
spark.io.compression.codec = snappy // 速度快(推荐)
// 可选: lz4, lzf, zstd
压缩算法对比:
snappy: 速度快,压缩率中等(默认推荐)lz4: 速度最快,压缩率低zstd: 压缩率高,速度较慢
其他重要参数
// Shuffle文件合并
spark.shuffle.consolidateFiles = true
// 网络超时
spark.network.timeout = 120s
// 重试次数
spark.shuffle.io.maxRetries = 3
// 重试等待时间
spark.shuffle.io.retryWait = 5s
数据倾斜问题
什么是数据倾斜?
Task 1: 处理 1万条 → 1秒完成
Task 2: 处理 100万条 → 10分钟完成 ⚠️
Task 3: 处理 2万条 → 2秒完成
整个Stage耗时 = 10分钟(木桶效应)
表现:
- 某几个Task执行时间特别长
- 大量Task已完成,少数Task卡住
- 内存溢出(OOM)错误
倾斜原因
// 某个key数据量巨大
val data = [
("正常key1", 100条),
("正常key2", 120条),
("热点key", 100万条), // ⚠️ 倾斜源头
("正常key3", 90条)
]
data.reduceByKey(_ + _)
// "热点key" 全部进入同一个分区
常见场景:
- 热门商品、热门用户
- NULL值过多
- 业务特性导致(如地域分布不均)
数据倾斜解决方案
方案1: 过滤热点key
// 如果热点key可以忽略
val filtered = rdd.filter(_._1 != "热点key")
.reduceByKey(_ + _)
// 热点key单独处理
val hotkey = rdd.filter(_._1 == "热点key")
.map(_._2).reduce(_ + _)
// 合并结果
val result = filtered.union(sc.parallelize(Seq(("热点key", hotkey))))
适用场景:
- 少量已知热点key
- 热点key可以单独计算
方案2: 加盐打散(最常用)⭐
// 原始数据
("apple", 1) → Hash分区 → 分区3
("apple", 2) → Hash分区 → 分区3 // 都在分区3
// 加盐拆分
val saltNum = 10
val salted = rdd.map { case (k, v) =>
val salt = Random.nextInt(saltNum) // 0-9随机
((k + "_" + salt), v)
}
// ("apple_0", 1) → 分区3
// ("apple_7", 2) → 分区8 // 分散了!
// 第一次聚合
val reduced1 = salted.reduceByKey(_ + _)
// ("apple_0", 100)
// ("apple_7", 200)
// 去盐,第二次聚合
val result = reduced1.map { case (k, v) =>
(k.split("_")(0), v)
}.reduceByKey(_ + _)
// ("apple", 300)
优点: 通用性强,效果好
缺点: 两次Shuffle,计算量增加
适用: 未知热点key,严重倾斜
方案3: 预聚合优化
// ❌ 差:直接groupByKey
rdd.groupByKey() // 传输所有数据
.mapValues(_.sum)
// ✅ 好:用reduceByKey
rdd.reduceByKey(_ + _) // Map端预聚合,减少传输
原理对比:
// groupByKey
Map端: ("a",1), ("a",2), ("a",3)
↓ 全部传输(3条数据)
Reduce端: ("a", [1,2,3]) → sum = 6
// reduceByKey
Map端: ("a",1), ("a",2), ("a",3)
↓ 预聚合 → ("a",6)
↓ 只传输(1条数据)
Reduce端: ("a", 6)
类似优化:
aggregateByKey替代groupByKey + 聚合combineByKey自定义预聚合逻辑
方案4: 提高并行度
// 方式1: 全局配置
spark.sql.shuffle.partitions = 1000
// 方式2: 代码指定
rdd.reduceByKey(_ + _, 1000)
// 方式3: 重分区
rdd.repartition(1000)
原理: 分区多了,单个分区数据少了
200分区: 每分区 5000条(某倾斜分区50万条)
1000分区: 每分区 1000条(倾斜分区10万条)
注意:
- 分区太多会增加调度开销
- 推荐: 核心数 × 2~4
方案5: 自定义分区器
class CustomPartitioner(partitions: Int) extends Partitioner {
override def numPartitions: Int = partitions
override def getPartition(key: Any): Int = key match {
case "热点key" =>
// 热点key随机分配到多个分区
50 + Random.nextInt(10) // 分区50-59
case k =>
Math.abs(k.hashCode() % 50) // 其他key用0-49
}
}
rdd.partitionBy(new CustomPartitioner(100))
.mapPartitions(iter => {
// 分区内处理逻辑
iter
})
适用: 已知热点key,需要精确控制分区策略
方案6: 广播Join(小表Join大表)⭐
// ❌ 问题:大表join小表
bigRDD.join(smallRDD) // Shuffle两个表
// ✅ 优化:广播小表
val smallMap = sc.broadcast(
smallRDD.collectAsMap() // 小表收集到Driver
)
val result = bigRDD.map { case (k, v) =>
val v2 = smallMap.value.getOrElse(k, null)
(k, (v, v2))
}.filter(_._2._2 != null) // 过滤未匹配的
条件: 小表 < 几百MB(可通过 spark.sql.autoBroadcastJoinThreshold 调整)
SQL自动优化:
-- Spark SQL 会自动广播小表
SELECT /*+ BROADCAST(small) */ *
FROM big JOIN small ON big.id = small.id
方案7: 两阶段聚合
// 适用:倾斜key无法过滤,且数据量巨大
val saltNum = 100
// 第一阶段:加盐局部聚合
val stage1 = rdd.map { case (k, v) =>
val salt = Random.nextInt(saltNum)
((k, salt), v)
}.reduceByKey(_ + _) // 局部聚合
// 第二阶段:去盐全局聚合
val result = stage1.map { case ((k, salt), v) =>
(k, v)
}.reduceByKey(_ + _) // 全局聚合
效果: 将单个热点key的计算分散到多个Task
方案8: 采样倾斜key单独处理
// 1. 采样找出倾斜key
val sample = rdd.sample(false, 0.1)
val skewedKeys = sample
.map(x => (x._1, 1))
.reduceByKey(_ + _)
.filter(_._2 > threshold) // 阈值
.map(_._1)
.collect()
.toSet
// 2. 分离倾斜数据
val skewedRDD = rdd.filter(x => skewedKeys.contains(x._1))
val normalRDD = rdd.filter(x => !skewedKeys.contains(x._1))
// 3. 倾斜数据加盐处理
val skewedResult = skewedRDD
.map { case (k, v) => ((k, Random.nextInt(10)), v) }
.reduceByKey(_ + _)
.map { case ((k, _), v) => (k, v) }
.reduceByKey(_ + _)
// 4. 正常数据正常处理
val normalResult = normalRDD.reduceByKey(_ + _)
// 5. 合并结果
val result = skewedResult.union(normalResult)
优点: 只对倾斜数据加盐,减少不必要开销
倾斜检测方法
方法1: 查看分区数据分布
val partitionSizes = rdd.mapPartitionsWithIndex { (idx, iter) =>
Iterator((idx, iter.size))
}.collect().sortBy(-_._2)
partitionSizes.foreach(println)
// 输出:
// (2, 500000) ⚠️ 倾斜分区
// (0, 1000)
// (1, 1200)
// (3, 900)
方法2: 查看key分布
val keyDistribution = rdd
.map(x => (x._1, 1))
.reduceByKey(_ + _)
.sortBy(_._2, ascending = false)
.take(20)
keyDistribution.foreach(println)
// 输出:
// ("热点key", 1000000) ⚠️
// ("正常key1", 1200)
// ("正常key2", 1100)
方法3: Spark UI监控
访问 http://driver:4040
关键指标:
-
Stage页面:
- 查看Task执行时间分布
- 倾斜表现: 某些Task时间远超平均值
-
Shuffle Read/Write:
- 查看每个Task的Shuffle数据量
- 倾斜表现: 某Task数据量特别大
-
Executor页面:
- 查看GC时间
- 倾斜表现: 某Executor频繁GC
方法4: 采样分析
// 采样10%数据分析
val sampleData = rdd.sample(false, 0.1)
.map(x => (x._1, 1))
.reduceByKey(_ + _)
.collect()
// 计算统计信息
val counts = sampleData.map(_._2)
val avg = counts.sum / counts.length
val max = counts.max
if (max > avg * 10) {
println(s"⚠️ 检测到数据倾斜: 最大值=${max}, 平均值=${avg}")
}
优化总结
算子选择优化
| 操作 | ❌ 差 | ✅ 好 | 原因 |
|---|---|---|---|
| 聚合 | groupByKey | reduceByKey | Map端预聚合 |
| 去重 | groupBy().keys | distinct() | 专用算子优化 |
| 计数 | groupByKey().count | countByKey() | 避免拉取所有数据 |
| Join | 大表join大表 | 广播join | 避免双向Shuffle |
参数调优 Checklist
// 1. 并行度设置
spark.sql.shuffle.partitions = 200~500 // 根据数据量调整
spark.default.parallelism = CPU核心数 × 2~4
// 2. 内存优化
spark.executor.memory = 4g~8g
spark.executor.memoryOverhead = executor内存 × 0.1
// 3. 压缩配置
spark.shuffle.compress = true
spark.io.compression.codec = snappy
// 4. 网络优化
spark.reducer.maxSizeInFlight = 48m
spark.network.timeout = 120s
// 5. Shuffle优化
spark.shuffle.file.buffer = 32k
spark.shuffle.sort.bypassMergeThreshold = 200
数据倾斜方案选择指南
| 场景 | 推荐方案 | 优先级 |
|---|---|---|
| 少量已知热点key | 过滤或单独处理 | ⭐⭐⭐ |
| 未知热点key | 加盐打散 | ⭐⭐⭐ |
| 可预聚合场景 | reduceByKey 替代 groupByKey | ⭐⭐⭐ |
| 小表join大表 | 广播join | ⭐⭐⭐ |
| 轻度倾斜 | 提高并行度 | ⭐⭐ |
| 严重倾斜 | 两阶段聚合 + 加盐 | ⭐⭐⭐ |
| 特定业务场景 | 自定义分区器 | ⭐⭐ |
优化步骤流程
1. 检测倾斜
↓ Spark UI + 采样分析
2. 定位原因
↓ 分析key分布
3. 选择方案
↓ 根据场景选择
4. 实施优化
↓ 代码/参数调整
5. 验证效果
↓ 对比性能指标
核心记忆口诀
Shuffle优化四原则:
- 能避免就避免 - 用窄依赖算子替代
- 能减少就减少 - Map端预聚合
- 能打散就打散 - 加盐处理热点
- 能广播就广播 - 小表不Shuffle
数据倾斜处理三步:
- 检测 - UI监控 + 采样分析
- 定位 - 找到热点key
- 优化 - 加盐/广播/自定义分区
实战案例
案例1: 用户行为分析
问题: 统计每个用户的行为次数,某明星用户数据量巨大
// ❌ 原始代码
userActions.groupByKey().mapValues(_.size)
// 倾斜 + 低效
// ✅ 优化后
// 方案1: 直接用countByKey
userActions.map(x => (x.userId, 1)).countByKey()
// 方案2: 加盐处理
val saltNum = 10
userActions
.map(x => ((x.userId, Random.nextInt(saltNum)), 1))
.reduceByKey(_ + _)
.map { case ((userId, _), count) => (userId, count) }
.reduceByKey(_ + _)
案例2: 订单金额汇总
问题: 按商品ID汇总销售额,爆款商品倾斜
// ❌ 原始代码
orders.map(x => (x.productId, x.amount))
.groupByKey()
.mapValues(_.sum)
// ✅ 优化方案1: 预聚合
orders.map(x => (x.productId, x.amount))
.reduceByKey(_ + _) // Map端预聚合
// ✅ 优化方案2: 采样 + 分离处理
val hotProducts = orders.sample(false, 0.01)
.map(x => (x.productId, 1))
.reduceByKey(_ + _)
.filter(_._2 > 1000)
.keys.collect().toSet
val hotOrders = orders.filter(x => hotProducts.contains(x.productId))
.map(x => ((x.productId, Random.nextInt(10)), x.amount))
.reduceByKey(_ + _)
.map { case ((id, _), amt) => (id, amt) }
.reduceByKey(_ + _)
val normalOrders = orders.filter(x => !hotProducts.contains(x.productId))
.map(x => (x.productId, x.amount))
.reduceByKey(_ + _)
val result = hotOrders.union(normalOrders).reduceByKey(_ + _)
案例3: 大表Join小表
问题: 订单表(10亿) join 用户表(100万)
// ❌ 原始代码
ordersRDD.join(usersRDD) // 双向Shuffle
// ✅ 优化: 广播小表
val userMap = sc.broadcast(
usersRDD.collectAsMap()
)
ordersRDD.map { order =>
val user = userMap.value.get(order.userId)
(order, user)
}.filter(_._2.isDefined)
// ✅ SQL自动优化
spark.sql("""
SELECT /*+ BROADCAST(users) */ *
FROM orders JOIN users ON orders.user_id = users.id
""")
附录: 常用监控命令
// 查看RDD分区数
rdd.getNumPartitions
// 查看分区器
rdd.partitioner
// 查看每个分区元素数量
rdd.glom().map(_.length).collect()
// 缓存RDD
rdd.cache()
rdd.persist(StorageLevel.MEMORY_AND_DISK)
// 查看血统
rdd.toDebugString
// 查看执行计划(DataFrame/SQL)
df.explain()
df.explain(true) // 详细信息
参考资源
- Spark官方文档 - Tuning Guide
- Spark SQL Performance Tuning
- 《Spark权威指南》
- 《Spark性能优化指南》
985

被折叠的 条评论
为什么被折叠?



