Why is there a black block when generating a decision tree image? How do I fix it? Is .replace('\n', '') alone enough?

This article looks at decision tree visualization in Python: why Code 1 (which saves dot_data to a local file) introduces a black block into the rendered image, and how the IO-stream approach plus a couple of string replacements (removing the stray "\r\n"; statement and the newlines) produce a clean, block-free output image.

As usual, code first.

Code 1, the one with the black block: dot_data is written to a local file and read back.

import pydotplus
import os
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from PIL import Image


Iris = load_iris()
X = Iris.data[:, 2:]  # petal length and width
y = Iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')
tree_clf.fit(X, y)

os.makedirs("output/决策树/", exist_ok=True)  # 创建目录,如果不存在则创建
export_graphviz(
    tree_clf,
    out_file="output/决策树/Iris_tree.dot",
    feature_names=Iris.feature_names[2:],
    class_names=Iris.target_names,
    rounded=True,
    filled=True
)
graph = pydotplus.graph_from_dot_file("output/决策树/Iris_tree.dot")
dot_data = graph.to_string()
# dot_data = dot_data.replace('\n', '<br/>')
print(dot_data)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("output/决策树/Iris_tree.png")
# image = Image.open("output/决策树/Iris_tree.png")
# image.show()

Code 2, the one without the black block: it uses an in-memory IO stream instead of a file.

from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from io import StringIO  # stdlib equivalent of six.StringIO on Python 3
import pydotplus
import os

Iris = load_iris()
X = Iris.data[:, 2:]  # petal length and width
y = Iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')
tree_clf.fit(X, y)

# generate the DOT source for the decision tree into an in-memory buffer
dot_data = StringIO()
export_graphviz(
    tree_clf,
    out_file=dot_data,
    feature_names=Iris.feature_names[2:],
    class_names=Iris.target_names,
    rounded=True,
    filled=True
)

dot_data = dot_data.getvalue()
print(type(dot_data))
print(dot_data)
dot_data = dot_data.replace('\n', '')  # remove the black block in the top-right corner
print(type(dot_data))
print(dot_data)

graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("output/决策树/iris_decision_tree.png")

Comparing the two generated images, the one from Code 1 has a black block in the corner. (Figure: side-by-side comparison of the two output PNGs.)

Since I had already solved the black-block problem once before (that solution is Code 2), I tried the same thing in Code 1:

.replace('\n', '')

It didn't work. Why?

The fix first:

dot_data = dot_data.replace('"\\r\\n";', '')
dot_data = dot_data.replace('\n', '')
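
Put in context, these two lines slot into Code 1 right after the .dot file is read back; a minimal sketch reusing Code 1's paths and variable names:

graph = pydotplus.graph_from_dot_file("output/决策树/Iris_tree.dot")
dot_data = graph.to_string()
dot_data = dot_data.replace('"\\r\\n";', '')  # drop the stray statement first
dot_data = dot_data.replace('\n', '')         # then strip newlines, as in Code 2
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("output/决策树/Iris_tree.png")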

Now the reason why:

Compare the two dot_data strings (dot_data1 below is Code 1's first print(dot_data); dot_data2 is Code 2's dot_data):

You can see that the dot_data read back from the local file contains this odd extra statement:

"\r\n";
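
To spot it in your own run, print the tail of the string with repr(), which exposes the literal backslashes; a quick check (assuming dot_data was just produced by Code 1's graph.to_string()):

print(repr(dot_data[-30:]))     # the literal \r\n shows up near the closing brace
print('"\\r\\n";' in dot_data)  # True for the file round-trip, False for StringIO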

So the "\r\n"; statement has to be removed first, which is exactly what this line does:

dot_data = dot_data.replace('"\\r\\n";', '')

Then, just like the IO-stream version, remove the \n characters:

dot_data = dot_data.replace('\n', '')

Which gives the following transformation:

dot_data1 (read back from the local .dot file; note the "\r\n"; line just before the closing brace):
 digraph Tree {
node [color="black", fontname="helvetica", shape=box, style="filled, rounded"];
edge [fontname="helvetica"];
0 [fillcolor="#ffffff", label="petal width (cm) <= 0.8\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"];
1 [fillcolor="#e58139", label="entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"];
0 -> 1  [headlabel="True", labelangle=45, labeldistance="2.5"];
2 [fillcolor="#ffffff", label="petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"];
0 -> 2  [headlabel="False", labelangle="-45", labeldistance="2.5"];
3 [fillcolor="#4de88e", label="entropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"];
2 -> 3;
4 [fillcolor="#843de6", label="entropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"];
2 -> 4;
"\r\n";
}

dot_data2 (from the StringIO stream; no "\r\n"; statement):
 digraph Tree {
node [color="black", fontname="helvetica", shape=box, style="filled, rounded"];
edge [fontname="helvetica"];
0 [fillcolor="#ffffff", label="petal width (cm) <= 0.8\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"];
1 [fillcolor="#e58139", label="entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"];
0 -> 1  [headlabel="True", labelangle=45, labeldistance="2.5"];
2 [fillcolor="#ffffff", label="petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"];
0 -> 2  [headlabel="False", labelangle="-45", labeldistance="2.5"];
3 [fillcolor="#4de88e", label="entropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"];
2 -> 3;
4 [fillcolor="#843de6", label="entropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"];
2 -> 4;

}

dot_data3 (after the replacements: everything collapsed onto one line):
 digraph Tree {node [color="black", fontname="helvetica", shape=box, style="filled, rounded"];edge [fontname="helvetica"];0 [fillcolor="#ffffff", label="petal width (cm) <= 0.8\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"];1 [fillcolor="#e58139", label="entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"];0 -> 1  [headlabel="True", labelangle=45, labeldistance="2.5"];2 [fillcolor="#ffffff", label="petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"];0 -> 2  [headlabel="False", labelangle="-45", labeldistance="2.5"];3 [fillcolor="#4de88e", label="entropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"];2 -> 3;4 [fillcolor="#843de6", label="entropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"];2 -> 4;}

To get an output image without the black block, the input has to look like dot_data3.
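
If this cleanup is needed in more than one script, the two replacements wrap up neatly; a small sketch (the helper name clean_dot is made up here for illustration):

def clean_dot(dot_data):
    # the stray statement only appears when the DOT source was round-tripped through a file
    dot_data = dot_data.replace('"\\r\\n";', '')
    # stripping newlines is harmless in both cases
    return dot_data.replace('\n', '')

graph = pydotplus.graph_from_dot_data(clean_dot(dot_data))
graph.write_png("output/决策树/Iris_tree.png")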

The IO-stream version never contains \r\n, so it only needs:

dot_data = dot_data.replace('\n', '')

The version saved to a local file does contain \r\n, so it needs one extra step before removing the \n:

dot_data = dot_data.replace('"\\r\\n";', '')

Print your own dot_data, compare it against the listings above, and you will see how to avoid the black block.

Since dot_data is just a str, these operations are trivial. I hope everyone manages to kill off that black block!
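
One last note: if the .dot file on disk is not actually needed, sklearn's export_graphviz returns the DOT source directly as a str when out_file=None, so the \r\n problem never arises; a minimal sketch reusing the objects from the code above:

dot_data = export_graphviz(
    tree_clf,
    out_file=None,  # return the DOT source instead of writing a file
    feature_names=Iris.feature_names[2:],
    class_names=Iris.target_names,
    rounded=True,
    filled=True
)
graph = pydotplus.graph_from_dot_data(dot_data.replace('\n', ''))
graph.write_png("output/决策树/iris_decision_tree.png")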
