from numpy import *
def loadSimpData():
    """Return the built-in toy data set.

    Returns:
        A tuple ``(datMat, classLabels)`` where ``datMat`` is a 5x2 numpy
        ``matrix`` of feature values and ``classLabels`` is a plain list of
        the five matching labels (each +1.0 or -1.0).
    """
    samples = [
        [1., 2.1],
        [2., 1.1],
        [1.3, 1.],
        [1., 1.],
        [2., 1.],
    ]
    labels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return matrix(samples), labels
def loadDataSet(fileName):
    """Parse a file of tab-delimited floats into feature rows and labels.

    The number of fields is taken from the first line of the file. For every
    non-blank line, the first ``numFeat - 1`` fields become one feature row
    and the last field becomes that row's label.

    Args:
        fileName: Path of the text file to read.

    Returns:
        A tuple ``(dataMat, labelMat)``: a list of feature rows (lists of
        floats) and a list of float labels, in file order. Both are empty
        for an empty file.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
        IndexError: If a line has fewer fields than the first line.
    """
    dataMat = []
    labelMat = []
    numFeat = None
    # A single pass inside a context manager: the file is opened once and is
    # always closed, even if a line fails to parse.
    with open(fileName) as fr:
        for line in fr:
            if numFeat is None:
                # Field count comes from the raw first line.
                numFeat = len(line.split('\t'))
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            curLine = stripped.split('\t')
            dataMat.append([float(curLine[i]) for i in range(numFeat - 1)])
            labelMat.append(float(curLine[-1]))
    return dataMat, labelMat
def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    """Label every row +1.0 or -1.0 by thresholding a single feature column.

    Args:
        dataMatrix: Two-dimensional sample data, one row per sample.
        dimen: Index of the column to compare against the threshold.
        threshVal: Threshold value.
        threshIneq: ``'lt'`` marks rows whose value is ``<= threshVal`` as
            -1.0; any other value marks rows whose value is ``> threshVal``
            as -1.0.

    Returns:
        An ``(m, 1)`` array of 1.0 / -1.0 predictions, where ``m`` is the
        number of rows in ``dataMatrix``.
    """
    featureColumn = dataMatrix[:, dimen]
    if threshIneq == 'lt':
        negativeMask = featureColumn <= threshVal
    else:
        negativeMask = featureColumn > threshVal
    # Start with every row in the +1 class, then flip the masked rows.
    predictions = ones((shape(dataMatrix)[0], 1))
    predictions[negativeMask] = -1.0
    return predictions
def buildStump(dataArr, classLabels, D, numSteps=10.0, verbose=True):
    """Search for the decision stump with the lowest weighted error.

    For every feature column, candidate thresholds are stepped from one step
    below the column minimum up to the column maximum, and both inequality
    directions (``'lt'`` and ``'gt'``) are tried via ``stumpClassify``. The
    weighted error of a candidate is ``D.T * errArr``, where ``errArr`` holds
    1 for each misclassified sample and 0 otherwise.

    Args:
        dataArr: Sample data convertible to an ``(m, n)`` matrix.
        classLabels: Sequence of ``m`` labels, compared against the
            +1.0 / -1.0 predictions of ``stumpClassify``.
        D: Column vector of ``m`` sample weights (must support ``.T``).
        numSteps: Number of threshold steps across each feature's range.
        verbose: When true, print one progress line per candidate stump.

    Returns:
        A tuple ``(bestStump, minError, bestClasEst)``:
        ``bestStump`` is a dict with keys ``'dim'``, ``'thresh'`` and
        ``'ineq'``; ``minError`` is the lowest value of ``D.T * errArr``
        found (``inf`` if no candidate was accepted); ``bestClasEst`` is the
        prediction column of the best stump (an all-zero ``(m, 1)`` matrix if
        no candidate was accepted).
    """
    # asmatrix is used instead of the `mat` alias, which NumPy 2.0 removed.
    dataMatrix = asmatrix(dataArr)
    labelMat = asmatrix(classLabels).T
    m, n = shape(dataMatrix)
    bestStump = {}
    bestClasEst = asmatrix(zeros((m, 1)))
    minError = inf  # start the error sum at +infinity
    minErrVal = inf  # plain-scalar mirror of minError for comparing/printing
    for i in range(n):  # loop over all dimensions
        rangeMin = dataMatrix[:, i].min()
        rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        for j in range(-1, int(numSteps) + 1):  # loop over all thresholds in this dimension
            # The threshold does not depend on the inequality direction.
            threshVal = rangeMin + float(j) * stepSize
            for inequal in ['lt', 'gt']:  # go over less-than and greater-than
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                errArr = asmatrix(ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                weightedError = D.T * errArr  # total error weighted by D
                # Extract a real scalar so formatting and comparison do not
                # rely on implicit (deprecated) array-to-float conversion.
                errVal = weightedError.item()
                if verbose:
                    print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f,minerr=%0.3f" % (i, threshVal, inequal, errVal, minErrVal))
                if errVal < minErrVal:
                    minErrVal = errVal
                    minError = weightedError  # error rate of the best stump so far
                    bestClasEst = predictedVals.copy()  # predictions of the best stump so far
                    bestStump['dim'] = i  # remember the best feature column
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst
# Demo run: fit a single decision stump on the toy data set using uniform
# sample weights, then print the chosen stump, its weighted error and its
# per-sample predictions.
data, label = loadSimpData()
# One equal weight per sample; asmatrix replaces the `mat` alias that
# NumPy 2.0 removed.
D = asmatrix(ones((len(label), 1)) / len(label))
bestStump, minError, bestClasEst = buildStump(data, label, D)
print(bestStump)
print(minError)
print(bestClasEst)
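# Decision stump (single-level decision tree) implementation for AdaBoost.
# (Source page footer: latest recommended article published 2024-08-05 15:35:00.)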