Azure/mmlspark 项目实战:使用随机搜索进行乳腺癌分类模型的超参数调优

Azure/mmlspark 项目实战:使用随机搜索进行乳腺癌分类模型的超参数调优

SynapseML SynapseML 项目地址: https://gitcode.com/gh_mirrors/mm/mmlspark

前言

在机器学习项目中,选择合适的模型超参数对模型性能有着至关重要的影响。Azure/mmlspark项目提供了强大的自动化机器学习工具,可以帮助我们高效地进行超参数调优。本文将详细介绍如何使用mmlspark中的随机搜索方法对乳腺癌分类模型进行超参数优化。

一、超参数调优基础概念

超参数是机器学习算法在训练前需要设置的参数,与模型参数(训练过程中学习的参数)不同。常见的超参数包括:

  • 学习率
  • 正则化参数
  • 决策树的最大深度
  • 随机森林的树数量等

随机搜索是一种常用的超参数优化方法,它通过在定义的搜索空间中随机采样参数组合来寻找最优解,相比网格搜索更加高效。

二、环境准备与数据加载

首先我们需要准备Spark环境并加载乳腺癌数据集:

# 加载乳腺癌数据集
data = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/BreastCancer.parquet"
).cache()

# 将数据分为调优集和测试集
tune, test = data.randomSplit([0.80, 0.20])
tune.limit(10).toPandas()

这里我们将数据分为80%的调优集和20%的测试集,调优集用于超参数搜索,测试集用于最终模型评估。

三、模型选择与定义

我们将尝试三种不同的分类算法:

from synapse.ml.automl import TuneHyperparameters
from synapse.ml.train import TrainClassifier
from pyspark.ml.classification import (
    LogisticRegression,
    RandomForestClassifier,
    GBTClassifier,
)

# 定义基础模型
logReg = LogisticRegression()
randForest = RandomForestClassifier()
gbt = GBTClassifier()

# 转换为mmlspark的训练模型格式
smlmodels = [logReg, randForest, gbt]
mmlmodels = [TrainClassifier(model=model, labelCol="Label") for model in smlmodels]

选择的模型包括:

  1. 逻辑回归:线性分类模型
  2. 随机森林:集成树模型
  3. 梯度提升树(GBT):另一种集成树模型

四、构建超参数搜索空间

使用mmlspark的HyperparamBuilder定义每个模型的超参数搜索范围:

from synapse.ml.automl import *

paramBuilder = (
    HyperparamBuilder()
    # 逻辑回归的正则化参数范围0.1-0.3
    .addHyperparam(logReg, logReg.regParam, RangeHyperParam(0.1, 0.3))
    # 随机森林的树数量选项5或10
    .addHyperparam(randForest, randForest.numTrees, DiscreteHyperParam([5, 10]))
    # 随机森林的最大深度选项3或5
    .addHyperparam(randForest, randForest.maxDepth, DiscreteHyperParam([3, 5]))
    # GBT的最大分箱数范围8-16
    .addHyperparam(gbt, gbt.maxBins, RangeHyperParam(8, 16))
    # GBT的最大深度选项3或5
    .addHyperparam(gbt, gbt.maxDepth, DiscreteHyperParam([3, 5]))
)

searchSpace = paramBuilder.build()
randomSpace = RandomSpace(searchSpace)

这里我们为不同模型定义了不同类型的超参数:

  • RangeHyperParam:连续范围内的随机取值
  • DiscreteHyperParam:离散值集合中随机选择

五、执行随机搜索调优

使用TuneHyperparameters进行超参数调优:

bestModel = TuneHyperparameters(
    evaluationMetric="accuracy",  # 使用准确率作为评估指标
    models=mmlmodels,            # 要调优的模型列表
    numFolds=2,                  # 2折交叉验证
    numRuns=len(mmlmodels) * 2,  # 运行次数为模型数量的2倍
    parallelism=1,               # 并行度
    paramSpace=randomSpace.space(),  # 参数搜索空间
    seed=0,                      # 随机种子
).fit(tune)

关键参数说明:

  • numFolds:交叉验证折数,增加折数可以提高评估稳定性但会增加计算成本
  • numRuns:随机搜索的迭代次数,这里设置为模型数量的2倍
  • parallelism:并行任务数,可根据集群资源调整

六、评估最佳模型

获取并查看最佳模型信息:

print(bestModel.getBestModelInfo())  # 打印最佳模型信息
print(bestModel.getBestModel())     # 打印最佳模型对象

在测试集上评估模型性能:

from synapse.ml.train import ComputeModelStatistics

# 在测试集上进行预测
prediction = bestModel.transform(test)

# 计算模型指标
metrics = ComputeModelStatistics().transform(prediction)
metrics.limit(10).toPandas()

七、调优策略优化建议

  1. 搜索空间设计:根据领域知识合理设置参数范围,过宽的范围会降低搜索效率

  2. 评估指标选择:对于不平衡数据集,考虑使用F1-score或AUC代替准确率

  3. 运行次数设置:根据计算资源适当增加numRuns可以提高找到更优参数的概率

  4. 并行度调整:在集群资源充足时增加parallelism可以加速调优过程

结语

通过mmlspark的自动化超参数调优功能,我们能够高效地为乳腺癌分类任务找到最优的模型参数组合。随机搜索方法在保证搜索效率的同时,能够探索更广泛的参数空间,是实际项目中常用的调优策略。读者可以根据具体业务需求调整搜索策略和评估指标,以获得最佳的业务效果。

SynapseML SynapseML 项目地址: https://gitcode.com/gh_mirrors/mm/mmlspark

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柏滢凝Wayne

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值