以下是 Spark MLlib 中常用算法的核心解析及实战要点,涵盖分类、回归、聚类、协同过滤四大场景:
一、分类算法
1. Logistic Regression (逻辑回归)
- 场景:二分类/多分类(如欺诈检测、广告点击预测)
- Spark 实现特点:
- 支持 L1/L2/ElasticNet 正则化
- 分布式优化算法:L-BFGS/OWL-QN
- 代码示例:
import org.apache.spark.ml.classification.LogisticRegression val lr = new LogisticRegression() .setMaxIter(100) // 最大迭代次数 .setRegParam(0.01) // 正则化强度 .setElasticNetParam(0.5) // L1/L2混合比例 (0=L2, 1=L1) val model = lr.fit(trainData) val predictions = model.transform(testData)
2. Decision Trees (决策树)
- 场景:可解释性强的分类(如信贷风险评估)
- Spark 优化:
- 分布式计算信息增益(Gini/Entropy)
- 支持类别特征自动处理
- 关键参数:
import org.apache.spark.ml.classification.DecisionTreeClassifier val dt = new DecisionTreeClassifier() .setMaxDepth(5) // 树最大深度 .setImpurity("gini") // 分裂标准 (gini/entropy) .setMinInfoGain(0.01) // 最小信息增益阈值
3. Random Forest (随机森林)
- 场景:高精度分类(如图像识别、医疗诊断)
- Spark 优势:
- 并行训练多棵决策树
- 内置特征重要性评估
- 代码示例:
import org.apache.spark.ml.classification.RandomForestClassifier val rf = new RandomForestClassifier() .setNumTrees(50) // 树的数量 .setFeatureSubsetStrategy("sqrt") // 特征采样策略 .setSubsamplingRate(0.8) // 样本采样率 val model = rf.fit(trainData) println(model.featureImportances) // 输出特征重要性
二、回归算法
Linear Regression (线性回归)
- 场景:连续值预测(如房价预测、销量预估)
- Spark 特性:
- 支持带正则化的最小二乘
- 可处理大于内存的数据集
- 实战配置:
import org.apache.spark.ml.regression.LinearRegression val lr = new LinearRegression() .setSolver("l-bfgs") // 优化器 (l-bfgs/normal) .setLoss("huber") // 损失函数 (squared/huber) .setMaxIter(200)
三、聚类算法
K-Means
- 场景:无监督分组(如客户分群、异常检测)
- Spark 优化:
- K-Means|| 并行初始化算法
- 支持余弦相似度(文本聚类)
- 代码示例:
import org.apache.spark.ml.clustering.KMeans val kmeans = new KMeans() .setK(3) // 聚类中心数 .setSeed(42L) // 随机种子 .setDistanceMeasure("cosine") // 距离度量 (euclidean/cosine) val model = kmeans.fit(featureData) model.clusterCenters.foreach(println) // 打印聚类中心
四、协同过滤
ALS (交替最小二乘法)
- 场景:推荐系统(如电影推荐、商品推荐)
- Spark 创新:
- 隐式反馈支持(浏览/点击行为)
- 冷启动处理(
coldStartStrategy
)
- 实战配置:
import org.apache.spark.ml.recommendation.ALS val als = new ALS() .setRank(10) // 隐因子维度 .setMaxIter(15) .setImplicitPrefs(true) // 使用隐式反馈数据 .setColdStartStrategy("drop") // 处理未知用户/物品 .setUserCol("userId") .setItemCol("movieId") .setRatingCol("rating") val model = als.fit(ratingsData) val recommendations = model.recommendForAllUsers(5) // 为每个用户推荐Top5
关键注意事项
-
数据准备要求:
- 分类/回归:
label
列 +features
向量 - ALS:强制要求
userId
,itemId
,rating
列名
- 分类/回归:
-
性能调优技巧:
// 所有算法通用 .setMaxIter(100) // 增加迭代次数 .setRegParam(0.01) // 调整正则化避免过拟合 .setTol(1e-6) // 收敛阈值 // 树模型特有 .setMaxBins(32) // 增加离散化桶数提升精度
-
部署最佳实践:
// 保存完整Pipeline模型(含特征工程) pipelineModel.write.overwrite().save("/models/rec_sys") // 加载模型进行批量预测 val sameModel = PipelineModel.load("/models/rec_sys") val realTimePredictions = sameModel.transform(streamingDF)
算法选择指南
问题类型 | 首选算法 | 场景优势 |
---|---|---|
二分类 | Logistic Regression | 高维稀疏数据(文本分类) |
多分类 | Random Forest | 混合特征类型处理能力强 |
回归预测 | Linear Regression | 特征与目标线性关系明显时 |
用户分群 | K-Means | 海量数据快速聚类 |
推荐系统 | ALS | 显式/隐式反馈协同过滤 |
提示:实际项目中常组合使用,如用K-Means做用户分群后,对不同群体使用独立ALS模型推荐