机器学习系列二:决策树算法

本文介绍了决策树算法的基本概念,包括其结构特点、熵的概念及计算公式,并详细解析了ID3算法的工作原理。此外,还提供了使用Python和scikit-learn构建决策树的实际代码示例。

机器学习中分类和预测算法的评估:

  • 准确率
  • 速度
  • 强壮性
  • 可规模性
  • 可解释性

一、什么是决策树(decision tree)

决策树是一个类似于流程图的树结构:其中每个内部节点表示在一个属性上的测试,每个分支代表一个属性输出,而每个树节点代表类或类分布,树的最顶层是根节点。输入图片说明

熵(entropy):如何度量信息。

一条信息的信息量大小和它的不确定性有直接的关系,要搞清楚一件非常非常不确定的事情,或者是我们一无所知的事情,需要了解大量的信息=====》信息量的度量就等于不确定性的多少。

计算信息熵的公式: 输入图片说明

公式意思是每一个可发生情况的概率乘以一个以2为底的对数,所有情况相加起来。

变量的不确定性越大,熵也就越大,比特(bit)来衡量信息的多少。

二、决策树归纳算法(ID3算法)

选择属性判断节点

信息获取量:Gain(A) = Info(D)- Info_A(D)

通过A来作为节点分类获取了多少信息。下图是根据年龄、收入、是否为学生、信用度来判断是否买了电脑。

输入图片说明输入图片说明

根据年龄

输入图片说明输入图片说明输入图片说明输入图片说明

决策树的优点:

直观、便于理解,小规模数据集有效

决策树的缺点:

  1. 处理连续变量不好
  2. 类别较多时,错误增加的比较快
  3. 可规模性一般

下面是实例代码: 代码中的example.csv文件是根据图片上的内容整理的csv文件,如下: 输入图片说明

首先需要先安装scikit-learn、numpy等包。

pip install scikit-learn

# -*- coding:utf-8 -*-
from sklearn.feature_extraction import DictVectorizer
import csv
from sklearn import preprocessing
from sklearn import tree
from sklearn.externals.six import StringIO

#读取csv文件
allElectronicsData = open('example.csv', 'rb')
reader = csv.reader(allElectronicsData)
headers = reader.next()

# headers是csv文件的第一行数据
print(headers)
# featurelist是特征值,例如年龄、信用度等
featureList = []
# labelList是标签值,也就结果即是否买电脑
labelList = []

for row in reader:
    labelList.append(row[len(row) - 1])
    rowDict = {}
    for i in range(1, len(row) - 1):
        rowDict[headers[i]] = row[i]
    featureList.append(rowDict)
print(featureList)

allElectronicsData.close()

vec = DictVectorizer()
# 特征值的列表,是0,1的列表
dummyX = vec.fit_transform(featureList).toarray()
print("dummyX:" + str(dummyX))
print(type(dummyX))
print(vec.get_feature_names())
print("labelList:" + str(labelList))

lb = preprocessing.LabelBinarizer()
dummyY = lb.fit_transform(labelList)
print("dummyY:" + str(dummyY))

# 根据信息熵来构建树的节点
clf = tree.DecisionTreeClassifier(criterion="entropy")
# 建模
clf = clf.fit(dummyX, dummyY)
print("clf:" + str(clf))
# 生成一个dot树,可用 dot命令生成pdf文件
with open("allElectronicInformationGainOri.dot", "w") as f:
    f = tree.export_graphviz(clf, feature_names=vec.get_feature_names(), out_file=f)

# oneRowX是拷贝出一条数据,修改一些特征值,来进行测试
oneRowX = dummyX[0, :]
print("oneRowX:" + str(oneRowX))
newRowX = oneRowX

newRowX[0] = 1
newRowX[2] = 0
print("newRowX:" + str(newRowX))

# predicedY是预测的结果
predictedY = clf.predict(newRowX)
print("predictedY:" + str(predictedY))

生成的allElectronicInformationGainOri.dot文件可以用dot生成pdf文件

sudo apt-get install graphviz

dot -Tpdf allElectronicInformationGainOri.dot -o test.pdf

生成的pdf文件如下:

输入图片说明

转载于:https://my.oschina.net/zhangyangyang/blog/860498

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值