本文假设您已明白决策树原理及CART生成算法
随机森林的算法核心思想有二:采样 和 完全分裂。采样又分为行采样和列采样,这里的行与列对应的就是样本与特征。完全分裂指的是决策树每一次分裂扩展节点时,能分裂必须分裂,分裂依据可以是信息增益或者增益率。
- 对于行采样,模型从M条数据集中随机采样m条数据,一般情况下m取M的平方根大小,分别作为每一棵决策树的训练集。行采样保证了每棵决策树使用的训练集各不相同,这在一定程度上抑制了over-fitting问题。
- 对于列采样,每一棵决策树都从M个特征中随机挑选m个特征作为节点分裂特征来计算,一般情况下m也取M的平方根大小。列采样具体又分为两种方式,一种是全局列采样,即同一棵树的建树过程均采用同一批采样特征;另一种是局部列采样,即每一次节点分裂的时候均单独随机挑选m个特征进行扩展。列采样进一步保证了随机森林不会出现over-fitting问题。
随机森林的最终输出由每一棵决策树的结果共同决定。如果是分类树则通过投票产生最终分类,如果是回归树则取所有结果的平均值。
随机森林优缺点:
优点
1、 在当前的很多数据集上,相对其他算法有着很大的优势,表现良好
2、它能够处理很高维度(feature很多)的数据,并且不用做特征选择( PS:特征子集是随机选择的)
3、在训练完后,它能够给出哪些feature比较重要( PS:http://blog.youkuaiyun.com/keepreder/article/details/47277517)
4、在创建随机森林的时候,对generlization error使用的是无偏估计,模型泛化能力强
5、训练速度快,容易做成并行化方法(PS:训练时树与树之间是相互独立的)
6、 在训练过程中,能够检测到feature间的互相影响
7、 实现比较简单
8、 对于不平衡的数据集来说,它可以平衡误差。
9、如果有很大一部分的特征遗失,仍可以维持准确度。
缺点:
1、随机森林已经被证明在某些噪音较大的分类或回归问题上会过拟
2、对于有不同取值的属性的数据,取值划分较多的属性会对随机森林产生更大的影响,所以随机森林在这种数据上产出的属性权值是不可信的。
随机森林生成算法:
根据下列算法而建造每棵树 :
1.用N来表示训练用例(样本)的个数,M表示特征数目。
2.输入特征数目m,用于确定决策树上一个节点的决策结果;其中m应远小于M。
3.从N个训练用例(样本)中以有放回抽样的方式,取样N次,形成一个训练集(即bootstrap取样),并用未抽到的用例(样本)作预测,评估其误差。
4.对于每一个节点,随机选择m个特征,决策树上每个节点的决定都是基于这些特征确定的。根据这m个特征,计算其最佳的分裂方式。
5.每棵树都会完整成长而不会剪枝,这有可能在建完一棵正常树状分类器后会被采用)。
以下代码可编译通过!
from random import seed
from random import randint
from csv import reader
# 建立一棵CART树
'''试探分枝'''
def data_split(index, value, dataset):
left, right = list(), list()
for row in dataset:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right
'''计算基尼指数'''
def calc_gini(groups, class_values):
gini = 0.0
total_size = 0
for group in groups:
total_size += len(group)
for group in groups:
size = len(group)
if size == 0:
continue
for class_value in class_values:
proportion = [row[-1] for row in group].count(class_value) / float(size)
gini += (size / float(total_size)) * (proportion * (1.0 - proportion))
return gini
'''找最佳分叉点'''
def get_split(dataset, n_features):
class_values = list(set(row[-1] for row in dataset))
b_index, b_value, b_score, b_groups = 999, 999, 999, None
features = list()
while len(features) < n_features:
index = randint(0, len(dataset[0]) - 2) # 往features添加n_features个特征(n_feature等于特征数的根号),特征索引从dataset中随机取
if index not in features:
features.append(index)
for index in features:
for row in dataset:
groups = data_split(index, row[index], dataset)
gini = calc_gini(groups, class_values)
if gini < b_score:
b_index, b_value, b_score, b_groups = index, row[index], gini, groups
return {'index': b_index, 'value': b_value, 'groups': b_groups} # 每个节点由字典组成
'''多数表决'''
def to_terminal(group):
outcomes = [row[-1] for row in group]
return max(set(outcomes), key=outcomes.count)
'''分枝'''
def split(node, max_depth, min_size, n_features, depth):
left, right = node['groups']
del (node['groups'])
if not left or not right:
node['left'] = node['right'] = to_terminal(left + right) # 叶节点不好理解
return
if depth >= max_depth:
node['left'], node['right'] = to_terminal(left), to_terminal(right)
return
if len(left) <= min_size:
node['left'] = to_terminal(left)
else:
node['left'] = get_split(left, n_features)
split(node['left'], max_depth, min_size, n_features, depth + 1)
if len(right) <= min_size:
node['right'] = to_terminal(right)
else:
node['right'] = get_split(right, n_features)
split(node['right'], max_depth, min_size, n_features, depth + 1)
'''建立一棵树'''
def build_one_tree(train, max_depth, min_size, n_features):
root = get_split(train, n_features)
split(root, max_depth, min_size, n_features, 1)
return root
'''用一棵树来预测'''
def predict(node, row):
if row[node['index']] < node['value']:
if isinstance(node['left'], dict):
return predict(node['left'], row)
else:
return node['left']
else:
if isinstance(node['right'], dict):
return predict(node['right'], row)
else:
return node['right']
# 随机森林类
class randomForest:
def __init__(self,trees_num, max_depth, leaf_min_size, sample_ratio, feature_ratio):
self.trees_num = trees_num # 森林的树的数目
self.max_depth = max_depth # 树深
self.leaf_min_size = leaf_min_size # 建立树时,停止的分枝样本最小数目
self.samples_split_ratio = sample_ratio # 采样,创建子集的比例(行采样)
self.feature_ratio = feature_ratio # 特征比例(列采样)
self.trees = list() # 森林
'''有放回的采样,创建数据子集'''
def sample_split(self, dataset):
sample = list()
n_sample = round(len(dataset) * self.samples_split_ratio)
while len(sample) < n_sample:
index = randint(0, len(dataset) - 2)
sample.append(dataset[index])
return sample
'''建立随机森林'''
def build_randomforest(self, train):
max_depth = self.max_depth
min_size = self.leaf_min_size
n_trees = self.trees_num
n_features = int(self.feature_ratio * (len(train[0])-1))#列采样,从M个feature中,选择m个(m<<M)
for i in range(n_trees):
sample = self.sample_split(train)
tree = build_one_tree(sample, max_depth, min_size, n_features)
self.trees.append(tree)
return self.trees
'''随机森林预测的多数表决'''
def bagging_predict(self, onetestdata):
predictions = [predict(tree, onetestdata) for tree in self.trees]
return max(set(predictions), key=predictions.count)
'''计算建立的森林的精确度'''
def accuracy_metric(self, testdata):
correct = 0
for i in range(len(testdata)):
predicted = self.bagging_predict(testdata[i])
if testdata[i][-1] == predicted:
correct += 1
return correct / float(len(testdata)) * 100.0
# 数据处理
'''导入数据'''
def load_csv(filename):
dataset = list()
with open(filename, 'r') as file:
csv_reader = reader(file)
for row in csv_reader:
if not row:
continue
dataset.append(row)
return dataset
'''划分训练数据与测试数据'''
def split_train_test(dataset, ratio=0.2):
#ratio = 0.2 # 取百分之二十的数据当做测试数据
num = len(dataset)
train_num = int((1-ratio) * num)
dataset_copy = list(dataset)
traindata = list()
while len(traindata) < train_num:
index = randint(0,len(dataset_copy)-1)
traindata.append(dataset_copy.pop(index))
testdata = dataset_copy
return traindata, testdata
# 测试
if __name__ == '__main__':
seed(1) #每一次执行本文件时都能产生同一个随机数
filename = 'sonar-all-data.csv'
dataset = load_csv(filename)
traindata,testdata = split_train_test(dataset, ratio=0.2)
max_depth = 20 #调参(自己修改) #决策树深度不能太深,不然容易导致过拟合
min_size = 1
sample_ratio = 1
trees_num = 20
feature_ratio=0.3
myRF = randomForest(trees_num, max_depth, min_size, sample_ratio, feature_ratio)
myRF.build_randomforest(traindata)
acc = myRF.accuracy_metric(testdata[:-1])
print('模型准确率:',acc,'%')