1. cache()
方法
cache()
是 Spark 中最简单的缓存方法,它将 RDD 或 DataFrame 持久化到内存中。默认情况下,cache()
使用的是 MEMORY_ONLY
存储级别。
示例:缓存 RDD
scala
复制
val rdd = sc.parallelize(List(1, 2, 3, 4, 5))
rdd.cache()
// 触发缓存计算
rdd.count()
示例:缓存 DataFrame
scala
复制
val df = spark.read.json("path/to/json/file.json")
df.cache()
// 触发缓存计算
df.count()
2. persist()
方法
persist()
是一个更通用的缓存方法,它允许你指定缓存的存储级别。Spark 提供了多种存储级别,可以根据你的需求选择合适的缓存策略。
存储级别
Spark 提供了以下存储级别:
-
MEMORY_ONLY
:将数据存储在内存中,如果内存不足,则会丢弃旧数据。 -
MEMORY_AND_DISK
:将数据存储在内存中,如果内存不足,则将剩余数据存储到磁盘上。 -
MEMORY_ONLY_SER
:将数据序列化后存储在内存中,节省内存空间。 -
MEMORY_AND_DISK_SER
:将数据序列化后存储在内存中,如果内存不足,则将剩余数据存储到磁盘上。 -
DISK_ONLY
:仅将数据存储在磁盘上。 -
OFF_HEAP
:将数据存储在堆外内存中(需要配置spark.memory.offHeap.enabled
为true
)。
示例:使用 persist()
方法
scala
复制
val rdd = sc.parallelize(List(1, 2, 3, 4, 5))
rdd.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
// 触发缓存计算
rdd.count()
示例:缓存 DataFrame 并指定存储级别
scala
复制
val df = spark.read.json("path/to/json/file.json")
df.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
// 触发缓存计算
df.count()
3. 缓存的释放
缓存的数据会占用内存或磁盘空间,如果不再需要缓存的数据,可以手动释放它们。
释放 RDD 缓存
scala
复制
rdd.unpersist()
释放 DataFrame 缓存
scala
复制
df.unpersist()
4. 缓存的适用场景
4.1 多次使用相同数据
如果你的程序中同一个 RDD 或 DataFrame 被多次使用,缓存可以显著提高性能。例如:
scala
复制
val df = spark.read.json("path/to/json/file.json")
df.cache()
val result1 = df.filter($"age" > 20).count()
val result2 = df.filter($"age" < 20).count()
4.2 大规模迭代计算
在机器学习算法中,通常需要多次迭代处理相同的数据集。缓存可以避免重复读取数据,提高迭代效率。
4.3 数据预处理
如果数据预处理步骤复杂且耗时,可以将预处理后的数据缓存起来,以便后续使用。
5. 缓存的注意事项
5.1 缓存策略的选择
-
如果内存充足,可以选择
MEMORY_ONLY
或MEMORY_ONLY_SER
。 -
如果内存有限,可以选择
MEMORY_AND_DISK
或MEMORY_AND_DISK_SER
。 -
如果数据量非常大,可以选择
DISK_ONLY
,但会牺牲一定的性能。
5.2 缓存的大小
缓存的数据会占用内存或磁盘空间,因此需要合理控制缓存的大小,避免占用过多资源。
5.3 缓存的释放
缓存的数据不会自动释放,因此在不再需要时,应该手动调用 unpersist()
方法释放缓存。
6. 示例代码
以下是一个完整的示例代码,展示了如何使用 cache()
和 persist()
方法缓存 RDD 和 DataFrame,并释放缓存。
scala
复制
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
val spark = SparkSession.builder()
.appName("Spark Cache Example")
.master("local[*]")
.getOrCreate()
// 创建 RDD
val rdd = sc.parallelize(List(1, 2, 3, 4, 5))
// 缓存 RDD
rdd.cache()
// 触发缓存计算
rdd.count()
// 使用缓存的 RDD
val result1 = rdd.map(_ * 2).collect()
println(result1.mkString(", "))
// 释放 RDD 缓存
rdd.unpersist()
// 创建 DataFrame
val df = spark.read.json("path/to/json/file.json")
// 缓存 DataFrame
df.persist(StorageLevel.MEMORY_AND_DISK)
// 触发缓存计算
df.count()
// 使用缓存的 DataFrame
val result2 = df.filter($"age" > 20).collect()
result2.foreach(println)
// 释放 DataFrame 缓存
df.unpersist()
// 停止 SparkSession
spark.stop()