Apache Spark Java Example: Machine Learning (Logistic Regression)
This article walks through implementing a logistic regression model for binary classification with Apache Spark's MLlib. Using the classic breast cancer dataset, we cover the complete workflow from data preparation to model deployment.
Project Overview
We will build a classifier that predicts whether a breast tumor is malignant or benign, covering the following steps:
- Data loading and exploration
- Data preprocessing and feature engineering
- Model training and evaluation
- Hyperparameter tuning
- Model deployment and prediction
Complete Implementation
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.Arrays;
import java.util.List;
public class BreastCancerClassification {
public static void main(String[] args) {
// 1. Create the SparkSession
SparkSession spark = SparkSession.builder()
.appName("Breast Cancer Classification")
.master("local[*]") // local mode; use a cluster master in production
.config("spark.sql.shuffle.partitions", "8")
.getOrCreate();
try {
// 2. Load the dataset
Dataset<Row> data = loadData(spark);
// 3. Explore the data
exploreData(data);
// 4. Preprocess the data
Dataset<Row> preprocessedData = preprocessData(data);
// 5. Split into training and test sets
Dataset<Row>[] splits = preprocessedData.randomSplit(new double[]{0.8, 0.2}, 42);
Dataset<Row> trainingData = splits[0];
Dataset<Row> testData = splits[1];
// 6. Build the logistic regression model
LogisticRegression lr = new LogisticRegression()
.setLabelCol("label")
.setFeaturesCol("scaledFeatures")
.setMaxIter(100)
.setRegParam(0.01)
.setElasticNetParam(0.8);
// 7. Create the Pipeline
Pipeline pipeline = new Pipeline().setStages(
new PipelineStage[]{lr}
);
// 8. Train the model
PipelineModel model = pipeline.fit(trainingData);
// 9. Evaluate the model
evaluateModel(model, testData);
// 10. Hyperparameter tuning
PipelineModel bestModel = hyperparameterTuning(pipeline, trainingData);
// 11. Evaluate the tuned model
evaluateModel(bestModel, testData);
// 12. Deploy the model and run predictions
deployModel(bestModel, testData);
} catch (Exception e) {
System.err.println("Execution failed: " + e.getMessage());
e.printStackTrace();
} finally {
spark.stop();
}
}
/**
 * Load the breast cancer dataset
 */
private static Dataset<Row> loadData(SparkSession spark) {
// Dataset layout: ID, diagnosis (M = malignant, B = benign), 30 numeric features
StructType schema = new StructType(new StructField[]{
DataTypes.createStructField("id", DataTypes.IntegerType, false),
DataTypes.createStructField("diagnosis", DataTypes.StringType, false),
DataTypes.createStructField("radius_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("texture_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("perimeter_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("area_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("smoothness_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("compactness_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("concavity_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("concave_points_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("symmetry_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("fractal_dimension_mean", DataTypes.DoubleType, false),
DataTypes.createStructField("radius_se", DataTypes.DoubleType, false),
DataTypes.createStructField("texture_se", DataTypes.DoubleType, false),
DataTypes.createStructField("perimeter_se", DataTypes.DoubleType, false),
DataTypes.createStructField("area_se", DataTypes.DoubleType, false),
DataTypes.createStructField("smoothness_se", DataTypes.DoubleType, false),
DataTypes.createStructField("compactness_se", DataTypes.DoubleType, false),
DataTypes.createStructField("concavity_se", DataTypes.DoubleType, false),
DataTypes.createStructField("concave_points_se", DataTypes.DoubleType, false),
DataTypes.createStructField("symmetry_se", DataTypes.DoubleType, false),
DataTypes.createStructField("fractal_dimension_se", DataTypes.DoubleType, false),
DataTypes.createStructField("radius_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("texture_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("perimeter_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("area_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("smoothness_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("compactness_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("concavity_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("concave_points_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("symmetry_worst", DataTypes.DoubleType, false),
DataTypes.createStructField("fractal_dimension_worst", DataTypes.DoubleType, false)
});
// Load the data from a CSV file
Dataset<Row> data = spark.read()
.format("csv")
.option("header", "true")
.option("inferSchema", "false")
.schema(schema)
.load("data/breast_cancer.csv");
return data;
}
/**
 * Exploratory data analysis
 */
private static void exploreData(Dataset<Row> data) {
System.out.println("================ 数据探索 ================");
// 1. 显示数据摘要
System.out.println("数据集结构:");
data.printSchema();
// 2. 显示样本统计信息
System.out.println("\n数据摘要:");
data.describe().show();
// 3. 类别分布
System.out.println("\n诊断结果分布:");
data.groupBy("diagnosis").count().show();
// 4. 缺失值检查
System.out.println("\n缺失值统计:");
for (String col : data.columns()) {
long missingCount = data.filter(data.col(col).isNull()).count();
if (missingCount > 0) {
System.out.println(col + ": " + missingCount + " 个缺失值");
}
}
// 5. 特征相关性分析(示例)
System.out.println("\n特征相关性示例:");
data.select("radius_mean", "perimeter_mean", "area_mean").show(5);
}
/**
 * Data preprocessing
 */
private static Dataset<Row> preprocessData(Dataset<Row> data) {
System.out.println("\n================ 数据预处理 ================");
// 1. 转换目标变量:M(恶性)=1, B(良性)=0
Dataset<Row> labeledData = data.withColumn("label",
functions.when(data.col("diagnosis").equalTo("M"), 1.0).otherwise(0.0)
);
// 2. 选择特征列
List<String> featureCols = Arrays.asList(
"radius_mean", "texture_mean", "perimeter_mean", "area_mean",
"smoothness_mean", "compactness_mean", "concavity_mean", "concave_points_mean",
"symmetry_mean", "fractal_dimension_mean", "radius_se", "texture_se",
"perimeter_se", "area_se", "smoothness_se", "compactness_se", "concavity_se",
"concave_points_se", "symmetry_se", "fractal_dimension_se", "radius_worst",
"texture_worst", "perimeter_worst", "area_worst", "smoothness_worst",
"compactness_worst", "concavity_worst", "concave_points_worst", "symmetry_worst",
"fractal_dimension_worst"
);
// 3. Assemble the feature vector
VectorAssembler assembler = new VectorAssembler()
.setInputCols(featureCols.toArray(new String[0]))
.setOutputCol("features");
// 4. Standardize the features
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures")
.setWithStd(true)
.setWithMean(true);
// 5. Build the preprocessing Pipeline
Pipeline preprocessingPipeline = new Pipeline()
.setStages(new PipelineStage[]{assembler, scaler});
// 6. Run the preprocessing
PipelineModel preprocessingModel = preprocessingPipeline.fit(labeledData);
Dataset<Row> preprocessedData = preprocessingModel.transform(labeledData);
// 7. Keep only the columns we need
preprocessedData = preprocessedData.select("label", "scaledFeatures");
System.out.println("Sample of preprocessed data:");
preprocessedData.show(5);
return preprocessedData;
}
/**
 * Model evaluation
 */
private static void evaluateModel(PipelineModel model, Dataset<Row> testData) {
System.out.println("\n================ 模型评估 ================");
// 1. 在测试集上进行预测
Dataset<Row> predictions = model.transform(testData);
// 2. 显示预测结果
System.out.println("预测结果示例:");
predictions.select("label", "prediction", "probability").show(10);
// 3. 计算评估指标
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setRawPredictionCol("rawPrediction");
double auc = evaluator.evaluate(predictions);
System.out.printf("模型AUC = %.4f%n", auc);
// 4. 计算准确率
long correct = predictions.filter("label == prediction").count();
long total = predictions.count();
double accuracy = (double) correct / total;
System.out.printf("准确率 = %.4f (%d/%d)%n", accuracy, correct, total);
// 5. 混淆矩阵
Dataset<Row> confusionMatrix = predictions.groupBy("label", "prediction").count();
System.out.println("混淆矩阵:");
confusionMatrix.show();
// 6. 计算精确率、召回率和F1分数
long tp = predictions.filter("label = 1 AND prediction = 1").count();
long fp = predictions.filter("label = 0 AND prediction = 1").count();
long fn = predictions.filter("label = 1 AND prediction = 0").count();
double precision = (double) tp / (tp + fp);
double recall = (double) tp / (tp + fn);
double f1 = 2 * (precision * recall) / (precision + recall);
System.out.printf("精确率 = %.4f%n", precision);
System.out.printf("召回率 = %.4f%n", recall);
System.out.printf("F1分数 = %.4f%n", f1);
}
/**
 * Hyperparameter tuning
 */
private static PipelineModel hyperparameterTuning(Pipeline pipeline, Dataset<Row> trainingData) {
System.out.println("\n================ 超参数调优 ================");
// 1. 创建参数网格
ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(pipeline.getStages()[0].asInstanceOf[LogisticRegression].regParam(),
new double[]{0.001, 0.01, 0.1, 1.0})
.addGrid(pipeline.getStages()[0].asInstanceOf[LogisticRegression].elasticNetParam(),
new double[]{0.0, 0.5, 1.0})
.addGrid(pipeline.getStages()[0].asInstanceOf[LogisticRegression].maxIter(),
new int[]{50, 100, 200})
.build();
// 2. Create the cross-validator
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setRawPredictionCol("rawPrediction");
CrossValidator crossValidator = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(5) // 5-fold cross-validation
.setParallelism(4); // number of models evaluated in parallel
// 3. Run the cross-validation
CrossValidatorModel cvModel = crossValidator.fit(trainingData);
// 4. Extract the best model
PipelineModel bestModel = (PipelineModel) cvModel.bestModel();
// 5. Print the best parameters (the fitted stage is a LogisticRegressionModel, not the estimator)
LogisticRegressionModel bestLR = (LogisticRegressionModel) bestModel.stages()[0];
System.out.println("Best parameters:");
System.out.println(" - regParam: " + bestLR.getRegParam());
System.out.println(" - elasticNetParam: " + bestLR.getElasticNetParam());
System.out.println(" - maxIter: " + bestLR.getMaxIter());
return bestModel;
}
/**
 * Model deployment and prediction
 */
private static void deployModel(PipelineModel model, Dataset<Row> testData) {
System.out.println("\n================ 模型部署 ================");
// 1. 保存模型
model.write().overwrite().save("models/breast_cancer_lr_model");
System.out.println("模型已保存到 models/breast_cancer_lr_model");
// 2. 加载模型
PipelineModel loadedModel = PipelineModel.load("models/breast_cancer_lr_model");
// 3. 模拟新数据预测
System.out.println("\n模拟新数据预测:");
// 创建新数据 (实际应用中从外部源获取)
List<Row> newData = Arrays.asList(
RowFactory.create(1.0, Vectors.dense(17.99, 10.38, 122.8, 1001.0, 0.1184, 0.2776, 0.3001, 0.1471, 0.2419, 0.07871, 1.095, 0.9053, 8.589, 153.4, 0.006399, 0.04904, 0.05373, 0.01587, 0.03003, 0.006193, 25.38, 17.33, 184.6, 2019.0, 0.1622, 0.6656, 0.7119, 0.2654, 0.4601, 0.1189)),
RowFactory.create(0.0, Vectors.dense(13.54, 14.36, 87.46, 566.3, 0.09779, 0.08129, 0.06664, 0.04781, 0.1885, 0.05766, 0.2699, 0.7886, 2.058, 23.56, 0.008462, 0.0146, 0.02387, 0.01315, 0.0198, 0.0023, 15.11, 19.26, 99.7, 711.2, 0.144, 0.1773, 0.239, 0.1288, 0.2977, 0.07259))
);
StructType schema = new StructType(new StructField[]{
DataTypes.createStructField("label", DataTypes.DoubleType, false),
DataTypes.createStructField("scaledFeatures", DataTypes.createVectorType(30), false)
});
Dataset<Row> newDataDF = spark.createDataFrame(newData, schema);
// 进行预测
Dataset<Row> predictions = loadedModel.transform(newDataDF);
// 显示预测结果
predictions.select("label", "prediction", "probability").show();
// 4. 解释预测结果
System.out.println("\n预测结果解释:");
for (Row row : predictions.collectAsList()) {
double label = row.getDouble(0);
double prediction = row.getDouble(1);
Vector probability = row.getAs(2);
String diagnosis = prediction == 1.0 ? "恶性(M)" : "良性(B)";
double confidence = probability.toArray()[prediction == 1.0 ? 1 : 0];
System.out.printf("真实诊断: %s, 预测诊断: %s, 置信度: %.2f%%%n",
label == 1.0 ? "恶性(M)" : "良性(B)",
diagnosis,
confidence * 100
);
}
}
}
Machine Learning Core Concepts
1. Logistic Regression Fundamentals
Logistic regression is a statistical method for binary classification; it maps the output of a linear model into the (0, 1) interval through the sigmoid function:
$$P(y=1 \mid x) = \frac{1}{1 + e^{-(w^T x + b)}}$$
where:
- $w$ is the weight vector
- $b$ is the bias term
- $x$ is the feature vector
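With elastic-net regularization, which the code above enables through setRegParam and setElasticNetParam, the objective Spark minimizes is the regularized negative log-likelihood. Writing $p_i = P(y=1 \mid x_i)$, with $\lambda$ = regParam and $\alpha$ = elasticNetParam:
$$\min_{w,b}\; -\frac{1}{n}\sum_{i=1}^{n}\Big[\,y_i \log p_i + (1-y_i)\log(1-p_i)\,\Big] \;+\; \lambda\Big(\alpha \lVert w \rVert_1 + \frac{1-\alpha}{2} \lVert w \rVert_2^2\Big)$$
So setElasticNetParam(0.8) in the code weights the penalty 80% toward L1 (sparse weights) and 20% toward L2.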
2. Spark MLlib Architecture
MLlib's DataFrame-based API is built around a small set of abstractions: a Transformer turns one DataFrame into another (for example, a fitted model adding a prediction column), an Estimator's fit() method learns from a DataFrame and returns a Transformer, and a Pipeline chains Estimators and Transformers into a single workflow that can be fit, evaluated, saved, and reloaded as one unit.
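A minimal sketch of that contract, assuming the Dataset<Row> named data with its diagnosis column as loaded earlier (StringIndexer is an Estimator; the StringIndexerModel it returns is a Transformer):
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Estimator: configured first, then fit on data
StringIndexer indexer = new StringIndexer()
    .setInputCol("diagnosis")
    .setOutputCol("diagnosisIndex");
// fit() learns the label-to-index mapping and returns a Transformer
StringIndexerModel indexerModel = indexer.fit(data);
// Transformer: applies the learned mapping to any compatible DataFrame
Dataset<Row> indexed = indexerModel.transform(data);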
3. Key Feature Engineering Techniques
A. Feature Vectorization
VectorAssembler assembler = new VectorAssembler()
.setInputCols(featureCols.toArray(new String[0])) // setInputCols expects a String[]
.setOutputCol("features");
- Combines multiple numeric columns into a single feature vector
- Required input format for most Spark ML algorithms
B. Feature Standardization
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures")
.setWithStd(true)
.setWithMean(true);
- Rescales each feature to zero mean and unit standard deviation (see the formula below)
- Speeds up model convergence and can improve performance
- Especially important for distance-based algorithms
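Concretely, with setWithMean(true) and setWithStd(true), each feature value $x$ is mapped to a z-score using the mean $\mu$ and standard deviation $\sigma$ computed on the data the scaler was fit on (which is why StandardScaler is an Estimator that must be fit before it can transform):
$$z = \frac{x - \mu}{\sigma}$$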
C. Other Feature Processing Techniques
- Missing-value imputation:
Imputer imputer = new Imputer()
.setInputCols(new String[]{"radius_mean"})
.setOutputCols(new String[]{"radius_mean_imputed"}) // output columns must not already exist
.setStrategy("mean"); // or "median"
- Categorical feature encoding:
StringIndexer indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex");
OneHotEncoder encoder = new OneHotEncoder()
.setInputCol("categoryIndex")
.setOutputCol("categoryVec");
- Feature selection:
ChiSqSelector selector = new ChiSqSelector()
.setFeaturesCol("features")
.setLabelCol("label")
.setNumTopFeatures(20);
4. Model Evaluation Metrics
| Metric | Formula | Meaning |
|---|---|---|
| Accuracy | $\frac{TP+TN}{TP+TN+FP+FN}$ | Fraction of predictions that are correct |
| Precision | $\frac{TP}{TP+FP}$ | Fraction of predicted positives that are truly positive |
| Recall | $\frac{TP}{TP+FN}$ | Fraction of actual positives that are correctly identified |
| F1 score | $2 \times \frac{Precision \times Recall}{Precision + Recall}$ | Harmonic mean of precision and recall |
| AUC | Area under the ROC curve | Ability to separate positive and negative samples |
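As a quick worked example (hypothetical counts, not results from this dataset): with $TP=50$, $FP=5$, $FN=10$, $TN=49$ on 114 test samples,
$$\text{Accuracy} = \frac{50+49}{114} \approx 0.868,\quad \text{Precision} = \frac{50}{55} \approx 0.909,\quad \text{Recall} = \frac{50}{60} \approx 0.833,\quad F1 = 2 \times \frac{0.909 \times 0.833}{0.909 + 0.833} \approx 0.869$$
For this problem, recall deserves the most attention: a false negative means a missed malignant tumor.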
5. Hyperparameter Tuning Techniques
A. Grid Search
ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[]{0.01, 0.1, 1.0})
.addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
.build();
- Enumerates the specified parameter values
- Exhaustively searches every combination
B. Cross-Validation
CrossValidator crossValidator = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(5); // 5-fold cross-validation
- Splits the data into K folds
- Trains on K-1 folds and validates on the remaining one, rotating through all folds
- Reduces the risk of overfitting to a single validation split (the score it optimizes is given below)
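Each parameter combination is scored by averaging the evaluator's metric $m_k$ (AUC here) over the $K$ folds:
$$\text{CV score} = \frac{1}{K}\sum_{k=1}^{K} m_k$$
Note the cost: the grid in the full implementation earlier has $4 \times 3 \times 3 = 36$ combinations, so 5-fold cross-validation fits $36 \times 5 = 180$ models, which is why setParallelism is worth setting.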
C. Train-Validation Split
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setTrainRatio(0.8); // 80% train, 20% validation
- Faster alternative to cross-validation (each combination is fit only once)
- Well suited to large datasets
Production Best Practices
1. Cluster Configuration Tuning
spark-submit \
--class BreastCancerClassification \
--master yarn \
--deploy-mode cluster \
--num-executors 10 \
--executor-cores 4 \
--executor-memory 8G \
--conf spark.sql.shuffle.partitions=100 \
--conf spark.memory.fraction=0.8 \
--conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
your-application.jar
2. Feature Store
A feature store can manage feature-engineering metadata across models (FeatureStore below is an illustrative in-house helper, not a Spark MLlib class):
// Save feature metadata
FeatureStore.saveFeatureMetadata(featureCols, "breast_cancer_features");
// Load features
List<String> featureCols = FeatureStore.loadFeatureMetadata("breast_cancer_features");
3. Model Monitoring
ModelMetricsTracker and ModelMonitor below are illustrative in-house helpers, not Spark MLlib classes:
// Record model metrics
ModelMetricsTracker.trackModelMetrics(
"breast_cancer_lr", // model name
"1.0", // version
Map.of(
"AUC", auc,
"Accuracy", accuracy,
"Precision", precision,
"Recall", recall
)
);
// Detect drift between serving data and training data
double driftScore = ModelMonitor.calculateDriftScore(testData, trainingData, featureCols);
4. Model Serving
A. Real-Time API Service
// MLServer is an illustrative serving wrapper; Spark itself ships no REST model server
MLServer mlServer = new MLServer()
.setPort(8080)
.setModelPath("models/breast_cancer_lr_model")
.start();
// Example request: POST /predict
// {"features": [17.99, 10.38, 122.8, ...]}
B. Batch Prediction
// Load the new data
Dataset<Row> newData = spark.read().parquet("hdfs:///new_data");
// Load the model
PipelineModel model = PipelineModel.load("models/breast_cancer_lr_model");
// Run the predictions
Dataset<Row> predictions = model.transform(newData);
// Save the results
predictions.write().parquet("hdfs:///predictions");
Extended Application Scenarios
1. Multiclass Classification
// Use multinomial logistic regression
LogisticRegression lr = new LogisticRegression()
.setFamily("multinomial")
.setLabelCol("label")
.setFeaturesCol("features");
2. Ensemble Learning
// Random forest
RandomForestClassifier rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(100);
// Gradient-boosted trees
GBTClassifier gbt = new GBTClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setMaxIter(50);
3. Deep Learning
// DeepLearning4j integration (sketch: DL4J's Spark API takes a JavaSparkContext
// and a TrainingMaster rather than a SparkSession, and reads features and labels
// from DataSet RDDs rather than from DataFrame columns)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam())
.list()
.layer(new DenseLayer.Builder().nIn(numFeatures).nOut(50).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
.nIn(50).nOut(2).activation(Activation.SIGMOID).build())
.build();
TrainingMaster<?, ?> trainingMaster = new ParameterAveragingTrainingMaster.Builder(32).build();
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, trainingMaster);
4. Anomaly Detection
// Isolation forest (note: not part of core Spark MLlib; requires a third-party
// package such as spark-iforest)
IsolationForest iforest = new IsolationForest()
.setFeaturesCol("features")
.setPredictionCol("prediction")
.setContamination(0.01);
Performance Optimization Strategies
1. Data Parallelism
// Increase the number of partitions
data = data.repartition(200);
// Use an efficient columnar format
data.write().parquet("hdfs:///data.parquet");
2. Memory Management
// Cache intermediate results
trainingData.persist(StorageLevel.MEMORY_AND_DISK());
// Enable off-heap memory
conf.set("spark.memory.offHeap.enabled", "true");
conf.set("spark.memory.offHeap.size", "2g");
3. Algorithm Optimization
// Reduce dimensionality with feature selection
ChiSqSelector selector = new ChiSqSelector()
.setNumTopFeatures(20)
.setFeaturesCol("features")
.setLabelCol("label")
.setOutputCol("selectedFeatures");
4. Distributed Training Acceleration
// Parameter-averaging distributed training with DL4J
// (sc is a JavaSparkContext; conf is the MultiLayerConfiguration from above)
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf,
new ParameterAveragingTrainingMaster.Builder(4)
.averagingFrequency(5)
.workerPrefetchNumBatches(2)
.build());
Medical Domain Extensions
1. Patient Risk Stratification
// Stratify patients by predicted probability; probability is a Vector column,
// so convert it to an array first (vector_to_array, Spark 3.0+)
predictions.withColumn("risk_level",
when(vector_to_array(col("probability")).getItem(1).gt(0.8), "high risk")
.when(vector_to_array(col("probability")).getItem(1).gt(0.5), "medium risk")
.otherwise("low risk")
);
2. Feature Importance Analysis
// Get the logistic regression feature weights
LogisticRegressionModel lrModel = (LogisticRegressionModel) model.stages()[0];
Vector coefficients = lrModel.coefficients();
// Show each feature's weight (a ranking sketch follows)
for (int i = 0; i < featureCols.size(); i++) {
System.out.printf("%s: %.4f%n", featureCols.get(i), coefficients.apply(i));
}
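Since weights of a model trained on standardized features can be positive (pushing toward malignant) or negative (pushing toward benign), ranking by absolute value gives a clearer importance ordering. A small sketch, reusing the featureCols and coefficients variables above:
// Print the ten largest coefficients by magnitude
java.util.stream.IntStream.range(0, featureCols.size())
.boxed()
.sorted((a, b) -> Double.compare(
Math.abs(coefficients.apply(b)), Math.abs(coefficients.apply(a))))
.limit(10)
.forEach(i -> System.out.printf("%s: %.4f%n",
featureCols.get(i), coefficients.apply(i)));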
3. Model Explanation
// SHAP values explain individual predictions; SHAPExplainer below is illustrative,
// since Spark MLlib has no built-in SHAP support (use an external SHAP library)
SHAPExplainer explainer = new SHAPExplainer(model);
Vector explanation = explainer.explain(features);
// Collect each feature's contribution
Map<String, Double> featureContributions = new HashMap<>();
for (int i = 0; i < featureCols.size(); i++) {
featureContributions.put(featureCols.get(i), explanation.apply(i));
}
Summary
This breast cancer classification project demonstrated the complete workflow for logistic regression with Spark MLlib:
- Data preparation: loading and exploring the breast cancer dataset
- Feature engineering: feature vectorization and standardization
- Model training: building and training a logistic regression model
- Model evaluation: assessing performance with multiple metrics
- Hyperparameter tuning: optimizing parameters with cross-validation
- Model deployment: saving the model and predicting on new data
Strengths of Spark MLlib:
- Distributed computing: handles large-scale datasets
- Complete pipelines: supports end-to-end machine learning workflows
- Rich algorithms: offers a wide range of classification, regression, and clustering algorithms
- Production ready: supports model deployment and monitoring
In real medical applications, these techniques extend to:
- Early disease diagnosis
- Patient risk assessment
- Treatment outcome prediction
- Medical resource optimization
By combining domain knowledge with machine learning, Spark gives healthcare teams powerful data analysis capabilities that can improve diagnostic accuracy and treatment outcomes.