数据集
- 数据集：MNIST 数据，图片大小为 28×28，共 10 个类别。使用数据的原始特征，所以每个样本有 28×28=784 个特征。
- 图片中的每个像素值都经过二值化处理。
- 剪枝使用的是预剪枝。
代码
import cv2
import time
import logging
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
def binaryzation(img, threshold=50):
    """Binarize ``img`` in place with an inverted threshold and return it.

    Every value is first cast to ``uint8`` (as the original per-row
    ``cv2.threshold`` call did). It is then mapped to 1 if it is
    ``<= threshold`` and to 0 otherwise. This is the same rule as
    ``cv2.THRESH_BINARY_INV`` with ``maxval=1``, but applied to the whole
    array in one vectorized NumPy operation instead of a Python loop
    over rows.

    Args:
        img: Writable NumPy array, e.g. one flattened image per row. It is
            modified in place and may have any shape (1-D input included).
        threshold: Values strictly greater than this become 0; values less
            than or equal to it become 1. Defaults to 50.

    Returns:
        The same ``img`` object, so the call can be chained. Callers that
        ignore the return value keep working as before.
    """
    # Cast first so out-of-range / fractional inputs are thresholded exactly
    # as the original astype(np.uint8) + cv2.threshold pipeline saw them.
    binary = img.astype(np.uint8) <= threshold
    # Write through [...] so the caller's array (and its dtype) is updated
    # in place; True/False are stored as 1/0 in that dtype.
    img[...] = binary
    return img
class Tree(object):
def __init__(self, node_type, Class=None, feature=None):
self.node_type = node_type
self.Child = {
}
self.Class = Class
self.feature = feature
def add_tree(self, val, tree):
self.Child[val] = tree
def predict(self, features):
if self.node_type == 'leaf':
return self.<