使用online-ml/river实现多类别分类任务详解
river 🌊 Online machine learning in Python 项目地址: https://gitcode.com/gh_mirrors/ri/river
多类别分类概述
多类别分类是机器学习中的一项核心任务,其目标是从一组固定的类别中预测出正确的类别。与二分类不同,多类别分类需要处理两个以上的类别选项。在online-ml/river框架中,分类器会为每个可能的类别输出一个概率分布,表示该样本属于各个类别的可能性。
数据集介绍
我们以图像分割数据集(ImageSegments)为例,该数据集包含2310个样本,每个样本有18个特征,需要将图像片段分类为7个类别:
- brickface(砖墙)
- sky(天空)
- foliage(树叶)
- cement(水泥)
- window(窗户)
- path(小路)
- grass(草地)
数据流处理
online-ml/river的一个显著特点是其流式数据处理能力。我们可以像处理数据流一样遍历整个数据集:
from river import datasets
dataset = datasets.ImageSegments()
for x, y in dataset:
# x是特征字典,y是类别标签
pass
查看第一个样本的特征和标签:
x, y = next(iter(dataset))
print(x) # 打印特征字典
print(y) # 打印类别标签'path'
模型初始化与训练
我们使用Hoeffding决策树作为分类器。Hoeffding树是一种适用于数据流环境的增量决策树算法,它能够在有限内存条件下高效学习。
from river import tree
model = tree.HoeffdingTreeClassifier()
初始状态下,模型对新样本的预测为空,因为它尚未学习任何数据模式:
print(model.predict_proba_one(x)) # 输出: {}
print(model.predict_one(x)) # 输出: None
增量学习
在线学习的特点是模型可以逐个样本进行学习。当我们用第一个样本更新模型后:
model.learn_one(x, y)
model.predict_proba_one(x) # 输出: {'path': 1.0}
此时模型只认识'path'这一个类别,因此会给出100%的概率。随着更多样本的输入,模型会逐步学习并识别所有7个类别。
渐进式验证
在线学习中常用的评估方法是渐进式验证(progressive validation),即在预测后立即用真实标签更新模型,并计算指标:
from river import metrics
model = tree.HoeffdingTreeClassifier()
metric = metrics.ClassificationReport()
for x, y in dataset:
y_pred = model.predict_one(x)
model.learn_one(x, y)
if y_pred is not None:
metric.update(y, y_pred)
print(metric)
online-ml/river提供了便捷的评估函数:
from river import evaluate
model = tree.HoeffdingTreeClassifier()
metric = metrics.ClassificationReport()
evaluate.progressive_val_score(dataset, model, metric)
评估结果分析
从分类报告中我们可以看到模型在各个类别上的表现:
- 草地(grass)和天空(sky)分类效果最好(F1分数>98%)
- 树叶(foliage)和窗户(window)分类效果较差(F1分数约30-53%)
- 整体准确率为77.61%
这种差异可能源于某些类别特征更明显,而有些类别间可能存在较大相似性。
在线学习的优势
- 适应新类别:模型能够动态适应数据流中出现的新类别
- 内存高效:不需要存储全部训练数据
- 实时更新:模型可以持续改进,适应数据分布变化
- 可扩展性:适合处理大规模或无限数据流
实际应用建议
- 对于类别不平衡的数据流,考虑使用带权重的评估指标
- 监控模型性能随时间的变化,检测概念漂移
- 根据业务需求选择合适的评估指标(精确率、召回率或F1分数)
- 尝试不同的在线学习算法比较性能
通过online-ml/river框架,开发者可以轻松构建和评估流式多类别分类系统,特别适合实时数据处理场景。
river 🌊 Online machine learning in Python 项目地址: https://gitcode.com/gh_mirrors/ri/river
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考