问题背景
今天在GEE上跑随机森林回归的时候,需要做一下交叉验证。然找遍全网,几乎全都是收费的,没办法,只好亲自写了个交叉验证,现在将代码共享出来给大家学习。
GEE交叉验证
不说废话,直接上代码。
// 在 training 要素集中增加一个 random 属性,值为 0 到 1 的随机数
var withRandom = crop.randomColumn({
columnName: 'random',
seed: 2,
distribution: 'uniform'
});
var k = 5; // 交叉验证的折数
// 准备 k 折数据集
var partitions = ee.List.sequence(0, k - 1).map(function(i) {
var split1 = ee.Number(i).divide(k)
var j = ee.Number(i).add(1)
var split2 = j.divide(k)
var testing = withRandom.filter(ee.Filter.gte('random', split1))
.filter(ee.Filter.lt('random', split2));
var training = withRandom.filter(ee.Filter.lt('random', split1))
.merge(withRandom.filter(ee.Filter.gte('random', split2)));
return ee.Dictionary({
'fold': i,
'training': training,
'testing': testing
});
});
// 函数:进行 k 折交叉验证
var foldResults = partitions.map(function(partition) {
partition = ee.Dictionary(partition);
var training = ee.FeatureCollection(partition.get('training'));
var testing = ee.FeatureCollection(partition.get('testing'));
// 训练随机森林模型
var classifier = ee.Classifier.smileRandomForest(100, null, 1, 0.5, null, 0).setOutputMode('REGRESSION')
.train({
features: training,
classProperty: 'totalPre',
inputProperties: properties,
});
// 使用模型进行预测
var predictions = testing.classify(classifier);
// 计算决定系数 R^2
var meanActual = ee.Number(testing.reduceColumns({
reducer: ee.Reducer.mean(),
selectors: ['totalPre'] // 修改为你的目标属性
}).get('mean'));
var ssTotal = ee.Number(testing.map(function(feature) {
var diff = ee.Number(feature.get('totalPre')).subtract(meanActual); // 修改为你的目标属性
return feature.set('sqTotal', diff.pow(2));
}).reduceColumns({
reducer: ee.Reducer.sum(),
selectors: ['sqTotal']
}).get('sum'));
var ssResidual = ee.Number(predictions.map(function(feature) {
var diff = ee.Number(feature.get('totalPre')).subtract(feature.get('classification')); // 修改为你的目标属性
return feature.set('sqResidual', diff.pow(2));
}).reduceColumns({
reducer: ee.Reducer.sum(),
selectors: ['sqResidual']
}).get('sum'));
var rSquared = ee.Number(1).subtract(ssResidual.divide(ssTotal));
return ee.Dictionary({
'fold': partition.get('fold'),
'rSquared': rSquared,
'classifier': classifier
});
});
这里我进行了五折交叉验证,大家如果有需要可以自行设置折数。在GEE中,凡是用到map函数的地方,客户端操作都是不允许的,在进行操作前,最好包装成服务端的变量。
结果查看
// 收集所有折的结果
var results = ee.FeatureCollection(foldResults.map(function(result) {
result = ee.Dictionary(result);
return ee.Feature(ee.Geometry.Point([0, 0]), {
'fold': result.get('fold'),
'rSquared': result.get('rSquared')
});
}));
此处可以查看交叉验证后每一折的结果,其中fold是折数,rSquared是决定系数。