Decision Tree Implementation Code
import numpy as np

class DecisionNode(object):
    """A node in the decision tree: an internal split or a leaf."""
    def __init__(self, feature_i=None, threshold=None,
                 value=None, true_branch=None, false_branch=None):
        self.feature_i = feature_i        # index of the feature to split on
        self.threshold = threshold        # threshold value for the split
        self.value = value                # prediction stored at a leaf node
        self.true_branch = true_branch    # subtree for samples that meet the threshold
        self.false_branch = false_branch  # subtree for the remaining samples
class DecisionTree(object):
    def __init__(self, min_sample_split=2, min_impurity=1e-7,
                 max_depth=float("inf")):
        self.root = None
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.max_depth = max_depth
        # Function to calculate impurity (set by subclasses)
        self._impurity_calculation = None
        # Function to determine the value of a leaf node (set by subclasses)
        self._leaf_value_calculation = None
    def fit(self, X, y):
        self.root = self._build_tree(X, y)
    def _build_tree(self, X, y, current_depth=0):
        largest_impurity = 0
        best_criteria = None    # Feature index and threshold
        best_sets = None        # Subsets of the data

        # Ensure y is a column vector so it can be concatenated with X
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)
        X_y = np.concatenate((X, y), axis=1)

        n_samples, n_features = np.shape(X)

        if n_samples >= self.min_sample_split and current_depth <= self.max_depth:
            # Evaluate every (feature, threshold) candidate split
            for feature_i in range(n_features):
                feature_values = X[:, feature_i]
                unique_values = np.unique(feature_values)
                for threshold in unique_values:
                    Xy1, Xy2 = divide_on_feature(X_y, feature_i, threshold)
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        # Labels of the two subsets (y sits in the last columns)
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]
                        impurity = self._impurity_calculation(y, y1, y2)
                        # Keep the split with the largest impurity reduction
                        if impurity > largest_impurity:
                            largest_impurity = impurity
                            best_criteria = {'feature_i': feature_i,
                                             'threshold': threshold}
                            best_sets = {'leftX': Xy1[:, :n_features],
                                         'lefty': Xy1[:, n_features:],
                                         'rightX': Xy2[:, :n_features],
                                         'righty': Xy2[:, n_features:]}

        if largest_impurity > self.min_impurity:
            # Recursively build the two subtrees
            true_branch = self._build_tree(best_sets['leftX'], best_sets['lefty'],
                                           current_depth + 1)
            false_branch = self._build_tree(best_sets['rightX'], best_sets['righty'],
                                            current_depth + 1)
            return DecisionNode(feature_i=best_criteria['feature_i'],
                                threshold=best_criteria['threshold'],
                                true_branch=true_branch,
                                false_branch=false_branch)

        # No split improved impurity enough: make a leaf node
        leaf_value = self._leaf_value_calculation(y)
        return DecisionNode(value=leaf_value)
    def predict_value(self, x, tree=None):
        if tree is None:
            tree = self.root
        # Leaf node: return the stored prediction
        if tree.value is not None:
            return tree.value
        # Otherwise descend into the branch chosen by the split feature
        feature_value = x[tree.feature_i]
        branch = tree.false_branch
        if isinstance(feature_value, (int, float, np.integer, np.floating)):
            # Numeric feature: take the true branch if the threshold is met
            if feature_value >= tree.threshold:
                branch = tree.true_branch
        elif feature_value == tree.threshold:
            # Categorical feature: take the true branch on an exact match
            branch = tree.true_branch
        return self.predict_value(x, branch)
    def predict(self, X):
        y_pred = []
        for x in X:
            y_pred.append(self.predict_value(x))
        return y_pred
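The listing above calls a helper divide_on_feature that is not defined in this section. Here is a minimal sketch consistent with how it is called (numeric thresholds split by >=, categorical values by equality); the author's actual helper may differ:

def divide_on_feature(X, feature_i, threshold):
    # Split the rows of X into two groups depending on whether
    # the value of feature feature_i meets the threshold
    if isinstance(threshold, (int, float, np.integer, np.floating)):
        split_func = lambda sample: sample[feature_i] >= threshold
    else:
        split_func = lambda sample: sample[feature_i] == threshold
    X_1 = np.array([sample for sample in X if split_func(sample)])
    X_2 = np.array([sample for sample in X if not split_func(sample)])
    return X_1, X_2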
Depending on whether you want ID3 or CART, you can customize the body of self._impurity_calculation and inherit from the class above:
class ClassificationTree(DecisionTree):
    # ID3: use information gain as the impurity measure
    def _calculate_information_gain(self, y, y1, y2):
        # Information gain = parent entropy minus the size-weighted
        # entropy of the two child subsets
        p = len(y1) / len(y)
        entropy = calculate_entropy(y)
        info_gain = entropy - p * calculate_entropy(y1) \
                    - (1 - p) * calculate_entropy(y2)  # calculate_entropy sketched below
        return info_gain

    def _majority_vote(self, y):
        # Leaf value: the most common label among the samples
        most_common = None
        max_count = 0
        for label in np.unique(y):
            # Count occurrences of samples with this label
            count = len(y[y == label])
            if count > max_count:
                most_common = label
                max_count = count
        return most_common

    def fit(self, X, y):
        self._impurity_calculation = self._calculate_information_gain
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)
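The calculate_entropy helper was omitted above; a minimal sketch of the standard Shannon entropy over the label column:

def calculate_entropy(y):
    # Shannon entropy of the label distribution: -sum(p * log2(p))
    y = y.flatten()
    entropy = 0
    for label in np.unique(y):
        p = len(y[y == label]) / len(y)
        entropy += -p * np.log2(p)
    return entropy

For CART-style regression, the same subclassing pattern applies with variance reduction as the impurity measure and the mean as the leaf value. A sketch under those assumptions (the RegressionTree class below is not part of the original code):

class RegressionTree(DecisionTree):
    def _calculate_variance_reduction(self, y, y1, y2):
        # Variance reduction: parent variance minus the size-weighted
        # variance of the two child subsets
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        return np.var(y) - (frac_1 * np.var(y1) + frac_2 * np.var(y2))

    def _mean_of_y(self, y):
        # Leaf value: the mean of the target values
        value = np.mean(y, axis=0)
        return value if len(value) > 1 else value[0]

    def fit(self, X, y):
        self._impurity_calculation = self._calculate_variance_reduction
        self._leaf_value_calculation = self._mean_of_y
        super(RegressionTree, self).fit(X, y)

Usage loosely mirrors the scikit-learn convention: construct the tree, call fit(X_train, y_train), then predict(X_test).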