随机森林回归模型
1. 模型简介
随机森林(Random Forest)是一种强大的集成学习算法,它通过构建多个决策树,并结合所有决策树的预测结果来提高模型的准确性和泛化能力。随机森林通常用于回归和分类任务,在处理高维数据、特征选择和异常值检测方面表现出色。
在这段代码中,我们使用了随机森林回归模型来进行回归预测任务,目标是通过输入特征预测连续值输出。
2. 决策树回归模型
随机森林由多个**决策树(Decision Tree)**组成,因此理解决策树的工作原理是理解随机森林的基础。
-
决策树回归是一种基于树状结构的模型,它通过递归地对数据集进行二分分裂,找到特征值与目标值之间的最佳分割点。
-
回归树分裂规则:
-
在每个节点选择一个特征,并基于该特征找到一个分割点,将数据分成两部分,使得这两部分的目标值方差最小化。
-
选择分割点 ( s ) 的标准是最小化节点的不纯度(Impurity),通常使用**均方误差(MSE)**作为度量指标:
MSE = 1 n ∑ i = 1 n ( y i − y ^ ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y})^2 MSE=n1i=1∑n(yi−y^)2其中:
- ( y_i ) 是样本的真实值。
- ( \hat{y} ) 是样本的预测值(即节点中的均值)。
-
3. 随机森林的核心原理
随机森林通过结合多个决策树来构建一个强大的预测模型。其核心思想包括样本随机性和特征随机性:
-
样本随机性(Bootstrap Aggregating, Bagging):
- 在训练每棵决策树时,随机森林从原始数据集中随机抽取样本,允许重复抽样(自助法)。
- 抽样得到的样本占总样本的 63.2%,其余的样本称为袋外样本(Out-of-Bag, OOB),用于估算模型的泛化误差。
-
特征随机性:
- 每次节点分裂时,随机森林从所有特征中随机选择一部分特征用于分裂。这一策略增加了模型的多样性,降低了过拟合的风险。
4. 随机森林回归的预测过程
在随机森林回归模型中,预测过程如下:
-
对于输入样本 ( x ),随机森林的每棵决策树 ( T_i ) 会给出一个预测结果 ( \hat{y}_i )。
-
随机森林的最终预测值是所有决策树预测结果的平均值:
y ^ = 1 n ∑ i = 1 n y ^ i \hat{y} = \frac{1}{n} \sum_{i=1}^{n} \hat{y}_i y^=n1i=1∑ny^i其中:
- ( n ) 是决策树的数量。
- ( \hat{y}_i ) 是第 ( i ) 棵决策树的预测结果。
这种集成策略通过对多个弱学习器(决策树)结果进行平均,降低了单个模型的方差,提升了模型的稳定性和预测性能。
5. 模型评估指标
在回归任务中,我们通常使用以下几种指标来评估模型的性能:
-
均方根误差(Root Mean Squared Error, RMSE):
RMSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2} RMSE=n1i=1∑n(yi−y^i)2- RMSE 衡量的是预测值与真实值之间的平方误差,值越小表示模型效果越好。
-
决定系数(R², Coefficient of Determination):
R 2 = 1 − ∑ i = 1 n ( y i − y ^ i ) 2 ∑ i = 1 n ( y i − y ˉ ) 2 R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2} R2=1−∑i=1n(yi−yˉ)2∑i=1n(yi−y^i)2- ( \bar{y} ) 是目标值的平均值。
- ( R^2 ) 越接近 1,表示模型的拟合效果越好。
-
平均绝对误差(Mean Absolute Error, MAE):
MAE = 1 n ∑ i = 1 n ∣ y i − y ^ i ∣ \text{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i| MAE=n1i=1∑n∣yi−y^i∣- MAE 是预测值与真实值之间的平均绝对误差,值越小表示模型预测精度越高。
6. 特征重要性
随机森林可以评估每个特征对模型预测结果的贡献,称为特征重要性(Feature Importance)。
- 特征重要性计算:
- 随机森林通过**置换重要性(Permutation Importance)**来计算特征的重要性。
- 在袋外样本上,将某个特征的值随机打乱,再重新进行预测,计算打乱前后的误差差异。如果误差增加明显,说明该特征对模型贡献较大。
三、结合代码讲解
代码中使用了 TreeBagger
函数来实现随机森林回归,并对模型进行预测和评估。以下是关键代码段的解释:
-
构建随机森林模型:
net = TreeBagger(trees, x_train, y_train, 'OOBPredictorImportance', 'on', ... 'Method', 'regression', 'OOBPrediction', 'on', 'minleaf', leaf);
trees
指定决策树的数量,minleaf
设置最小叶子节点大小。'OOBPrediction', 'on'
开启袋外样本预测,用于估算模型误差。'OOBPermutedPredictorDeltaError'
开启特征重要性计算。
-
模型预测与评估:
re1 = predict(net, x_train); re2 = predict(net, x_test); error1 = sqrt(mean((pre1 - Y_train).^2)); error2 = sqrt(mean((pre2 - Y_test).^2)); R1 = 1 - norm(Y_train - pre1)^2 / norm(Y_train - mean(Y_train))^2; R2 = 1 - norm(Y_test - pre2)^2 / norm(Y_test - mean(Y_test))^2;
- 使用
predict
函数对训练集和测试集进行预测。 - 计算 RMSE、R² 和 MAE 等评估指标,衡量模型性能。
- 使用
-
特征重要性可视化:
figure; bar(import, 'green'); xlabel('特征'); ylabel('重要性'); title('特征重要性图');
- 使用
bar
函数绘制特征重要性图,展示各特征对模型预测的贡献。
- 使用
总结
随机森林是一种鲁棒性强、性能优越的回归模型,适用于高维数据和复杂非线性关系。通过结合多个决策树的预测结果,随机森林降低了单个模型的方差,提高了整体模型的稳定性。该代码中,使用随机森林对回归任务进行了建模,并提供了预测、误差评估和特征重要性分析,展示了模型在处理回归问题时的优势和效果。
如果有需要进一步改进的地方,可以考虑:
- 调整决策树数量
trees
和叶子节点大小minleaf
。 - 使用并行计算加速训练过程。
- 引入超参数优化(如交叉验证)来选择最佳模型参数。
Matlab代码手把手教运行
为了帮助更多的萌新更快上手数学建模建等竞赛,这里直接手把手教会如何直接使用本文中的基于随机森林的回归预测模型代码:
代码全文:
close all
clear
clc
%随机数种子固定结果
rng(2222)
%% 导入数据
res = readmatrix('回归数据.xlsx');
%% 数据归一化 索引
X = res(:,1:end-1);
Y = res(:,end);
x = mapminmax(X', 0, 1);
%保留归一化后相关参数
[y, psout] = mapminmax(Y', 0, 1);
%% 划分训练集和测试集
num = size(res,1);%总样本数
k = input('是否打乱样本(是:1,否:0):');
if k == 0
state = 1:num; %不打乱样本
else
state = randperm(num); %打乱样本
end
ratio = 0.8; %训练集占比
train_num = floor(num*ratio);
x_train = x(:,state(1: train_num))';
y_train = y(state(1: train_num))';
x_test = x(:,state(train_num+1: end))';
y_test = y(state(train_num+1: end))';
%% 训练模型
trees = 100; % 决策树数目
leaf = 3; % 最小叶子数,过小容易过拟合
wuc = 'on'; % 打开误差图
Importance = 'on'; % 计算特征重要性
net = TreeBagger(trees, x_train, y_train, 'OOBPredictorImportance', Importance,...
'Method','regression', 'OOBPrediction', wuc, 'minleaf', leaf);
import = net.OOBPermutedPredictorDeltaError; % 重要性
%% 预测
re1 = predict(net, x_train);
re2 = predict(net, x_test );
%% 数据反归一化
%实际值
Y_train = Y(state(1: train_num));
Y_test = Y(state(train_num+1:end));
%预测值
pre1 = mapminmax('reverse', re1, psout);
pre2 = mapminmax('reverse', re2, psout);
%% 均方根误差
error1 = sqrt(mean((pre1 - Y_train).^2));
error2 = sqrt(mean((pre2 - Y_test).^2));
%% 相关指标计算
% R2
R1 = 1 - norm(Y_train - pre1)^2 / norm(Y_train - mean(Y_train))^2;
R2 = 1 - norm(Y_test - pre2)^2 / norm(Y_test - mean(Y_test ))^2;
% MAE
mae1 = mean(abs(Y_train - pre1 ));
mae2 = mean(abs(pre2 - Y_test ));
disp('训练集预测精度指标如下:')
disp(['训练集数据的R2为:', num2str(R1)])
disp(['训练集数据的MAE为:', num2str(mae1)])
disp(['训练集数据的RMSE为:', num2str(error1)])
disp('测试集预测精度指标如下:')
disp(['测试集数据的R2为:', num2str(R2)])
disp(['测试集数据的MAE为:', num2str(mae2)])
disp(['测试集数据的RMSE为:', num2str(error2)])
figure
plot(1: train_num, Y_train, 'r-^', 1: train_num, pre1, 'b-+', 'LineWidth', 1)
legend('真实值','预测值')
xlabel('样本点')
ylabel('预测值')
title('训练集预测结果对比')
%%画图
figure
plot(1: num-train_num, Y_test, 'r-^', 1: num-train_num, pre2, 'b-+', 'LineWidth', 1)
legend('真实值','预测值')
xlabel('样本点')
ylabel('预测值')
title('测试集预测结果对比')
%% 训练集百分比误差图
figure
plot((pre1 - Y_train )./Y_train, 'b-o', 'LineWidth', 1)
legend('百分比误差')
xlabel('样本点')
ylabel('误差')
title('训练集百分比误差曲线')
%% 测试集百分比误差图
figure
plot((pre2 - Y_test )./Y_test, 'b-o', 'LineWidth', 1)
legend('百分比误差')
xlabel('样本点')
ylabel('误差')
title('测试集百分比误差曲线')
%% 拟合图
figure;
plotregression(Y_train, pre1, '训练集', ...
Y_test, pre2, '测试集');
set(gcf,'Toolbar','figure');
%% 绘制误差曲线
figure
plot(1: trees, oobError(net), 'r--O', 'LineWidth', 1)
legend('误差迭代曲线')
xlabel('决策树(迭代次数)')
ylabel('误差')
grid
%% 绘制特征重要性
figure
bar(import,'green')
yticks([])
xlabel('特征')
ylabel('重要性')
回归数据:
x1 | x2 | x3 | x4 | x5 | y |
---|---|---|---|---|---|
3.036226 | 5613.881 | 217.475 | 68.43483 | 16.82093 | 48.47907612 |
4.391658 | 8450.278 | 231.6231 | 60.47683 | 10.7205 | 64.12185185 |
5.248597 | 8634.475 | 232.9854 | 68.74011 | 11.54856 | 67.67148734 |
5.121236 | 7253.694 | 244.4692 | 65.36061 | 14.36191 | 59.32753057 |
4.649021 | 8647.125 | 268.3052 | 63.80132 | 12.04857 | 67.98570391 |
4.267209 | 5445.033 | 257.7701 | 61.68997 | 7.055625 | 46.47506598 |
4.804512 | 5848.875 | 234.0968 | 67.14467 | 7.656134 | 47.18442579 |
3.825803 | 6583.543 | 270.7575 | 61.44453 | 16.65654 | 54.0487296 |
4.787534 | 9546.304 | 225.8279 | 67.89351 | 11.43665 | 69.90372309 |
5.580797 | 5481.803 | 287.9043 | 62.20432 | 11.26927 | 47.32012957 |
3.658919 | 9482.116 | 278.8093 | 61.20217 | 12.37418 | 71.68429695 |
3.829985 | 9128.034 | 225.2187 | 64.81423 | 16.70461 | 68.18735427 |
5.295785 | 6804.695 | 282.6947 | 67.13434 | 13.61871 | 58.0309376 |
5.554308 | 5340.983 | 277.2833 | 69.39466 | 7.804352 | 48.77931811 |
3.159579 | 8717.524 | 209.9133 | 66.73862 | 16.308 | 67.56410903 |
前几列为数据特征,最后一列为因变量,类似制作表格导入即可,无需修改代码!