简介:一套开箱即用的纯Python决策树实现,不依赖sklearn内置Tree类,完整复现ID3/C4.5核心逻辑。支持两类典型任务:用鸢尾花全部4个特征做多分类(含二维决策边界图、全维度建模、树结构图生成、不同深度下的准确率变化曲线),以及对西瓜数据集进行连续属性切分与好坏二分类(输出可读树结构、.dot文件及渲染图)。所有脚本含逐行中文注释,运行iris_test1.py得双特征分类边界图,iris_test2.py输出四特征模型+决策树图+深度-精度关系图,watermelon_test.py完成西瓜数据预处理、最优分割点计算与树形打印。data目录内置标准iris.data和watermelon.data;figure目录存放12张实测图(含6张png+6张dot源文件);项目操作说明.md详细列出Python 3.7+环境配置、graphviz安装步骤、pydot调用方法、单文件执行顺序及常见报错修复方案;requirements.txt声明scikit-learn、matplotlib、graphviz、pydot四项依赖,无其他第三方要求。
1. 项目概述:为什么我坚持手写一棵决策树?
你有没有试过,在调用 sklearn.tree.DecisionTreeClassifier().fit(X, y) 的那一刻,心里突然冒出一个念头:这棵树到底长什么样?它怎么决定在花瓣宽度 2.45 厘米处切一刀,而不是 2.46?它凭什么认为“萼片长度 > 5.5 且花瓣长度 < 3.0”就能把山鸢尾和变色鸢尾干净利落地分开?不是说 sklearn 不好——它极快、极稳、工业级可靠;但正因为它太好,我们反而容易变成“黑箱操作员”,只输入数据、等待结果,却对内部的分裂逻辑、信息增益计算、连续值离散化策略一无所知。
这个项目就是为了解决这个问题而生的。它不是另一个“用 sklearn 做鸢尾花分类”的教程,而是一套可触摸、可调试、可逐行理解的决策树操作系统。我花了整整三周时间,从零开始重写了 ID3 和 C4.5 的核心骨架:不调用任何 sklearn.tree 的私有方法,不 import tree._tree,所有节点结构、分裂判断、递归构建、预测逻辑,全部用原生 Python 实现。它跑得当然不如 sklearn 快(毕竟没 Cython 加速),但它像一台透明玻璃罩里的机械钟表——你能看清每一个齿轮如何咬合,每一根游丝如何摆动,每一次分裂背后的信息熵变化是多少比特。
项目名字里那个“手写”,不是情怀修饰词,是硬性技术承诺。你打开 tree.py,会看到 class TreeNode: 里明明白白写着 self.feature_idx, self.threshold, self.children, self.is_leaf, self.class_label;你点开 decisiontree.py,_best_split() 函数里是完整的基尼不纯度(Gini Impurity)计算循环,_information_gain() 里是香农熵的手动累加;你在 watermelon_test.py 中看到的,不是一行 dt.fit(X, y),而是对西瓜数据中“敲声”“纹理”“色泽”等连续属性,手动遍历所有可能切分点、计算每个切分带来的纯度提升、最终选出最优阈值的全过程。
它支持两类典型场景:一是离散型多分类任务——用鸢尾花全部 4 个特征(萼片长/宽、花瓣长/宽)区分 3 类花朵,不仅输出准确率,还生成二维决策边界图(iris_test1.py)、全维度建模后的树形结构图(.dot + Graphviz 渲染)、以及最关键的——不同最大深度下模型性能的衰减曲线(iris_test2.py);二是连续型二分类任务——处理西瓜数据集(共 17 条记录,含“色泽”“根蒂”“敲声”“纹理”“脐部”“触感”6 个属性,其中多个为连续数值),实现真正的“最优切分点搜索”,而非简单四舍五入或等宽分箱。
所有脚本自带逐行中文注释,不是“此处初始化变量”这种废话,而是“此处计算第 j 个特征在阈值 t 下的加权基尼指数,公式为:Gini(D_t) = |D_left|/|D| × Gini(D_left) + |D_right|/|D| × Gini(D_right)”;project_operation_guide.md 不是冷冰冰的命令列表,而是按真实踩坑顺序写的排错指南——比如 Windows 上 Graphviz 的 bin 目录为何必须加进系统 PATH,pydot 报 GraphViz’s executable not found 时该检查哪三个路径,matplotlib 中文乱码为何要提前设置 plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']。整个包就像一位坐在你工位旁的老同事,一边敲代码一边给你讲解:“你看,这里 if len(np.unique(y)) == 1 就是剪枝的第一道防线,只要当前子集全是同一类,就立刻停住,不再往下分——这叫‘纯度满足’,比 max_depth 更早生效。”
如果你是刚学完《统计学习方法》第5章的学生,这个包能帮你把公式 Gain(D, a) = Ent(D) - \sum_{v=1}^V \frac{|D^v|}{|D|} Ent(D^v) 变成屏幕上跳动的数字;如果你是工作三年想补算法底层的工程师,它能让你在调参时多一层底气——当模型在深度=5时准确率骤降,你知道这不是玄学,而是训练集被过度细分导致泛化能力崩塌;如果你只是好奇“AI 怎么做决定”,运行一遍 iris_test1.py,看着那张红蓝黄三色交织、边界锯齿分明的散点图,你就直观理解了什么叫“轴平行分割”。
它不追求 SOTA,不卷参数规模,它的价值在于可解释性、可教学性、可追溯性。当你双击打开 iris_test2.dot,看到 node0 [label="petal length <= 2.45"]; node1 [label="petal width <= 1.75"]; 这样的文本,你就握住了机器判断的原始逻辑。这才是真正属于你的决策树,不是调包接口返回的一个黑盒对象,而是一棵你亲手栽种、修剪、浇水、最终看着它开花结果的树。
2. 整体设计与思路拆解:ID3/C4.5 的 Python 化重构
2.1 为什么放弃 sklearn.tree,坚持纯手写?
很多人第一反应是:“sklearn 不是现成的吗?何必重复造轮子?”这个问题我问了自己不下十遍。答案不是为了炫技,而是源于三个无法绕开的实践痛点:
第一,教学穿透力不足。sklearn 的 DecisionTreeClassifier 是一个高度封装的 API,其 _tree 属性虽暴露了底层结构,但那是 Cython 编译后的 Tree 对象,内部字段如 node_count, capacity, max_depth 全是 C 级别内存指针,Python 层无法直接读取分裂阈值的计算过程。你想看某次分裂时“萼片长度”特征的信息增益具体是多少?不行。你想在分裂前插入断点,观察 X[:, j] 排序后各候选切分点对应的基尼指数数组?不行。它像一辆高速列车,你只能买票上车,无法掀开引擎盖看活塞运动。
第二,连续属性处理逻辑不透明。C4.5 对连续特征的处理是其核心创新之一:对特征值排序后,在每两个相邻值的中点尝试切分,选使信息增益最大的那个点。sklearn 默认使用“最优切分”(splitter='best'),但它的源码藏在 sklearn/tree/_tree.pyx 里,涉及 compute_best_split 函数和 SplitRecord 结构体,对 Python 开发者极不友好。而本项目中,_find_best_threshold_for_continuous_feature() 函数用不到 20 行 Python 就清晰呈现了全过程:sorted_vals = np.sort(X_col); candidates = (sorted_vals[:-1] + sorted_vals[1:]) / 2; for t in candidates: gain = _information_gain(...)。你可以轻松修改 candidates 生成逻辑,比如改成等频分箱(quantile-based)或加入最小样本约束,而无需编译整个 sklearn。
第三,可视化与分析链路断裂。sklearn 能用 export_graphviz 导出 .dot,但那是基于训练完成后的静态树结构;它无法在训练过程中动态记录每次分裂的增益值、样本分布、特征重要性累积过程。而本项目中,iris_test2.py 不仅生成最终树图,还同步绘制 depth_vs_accuracy.png,其数据来源于训练时每层递归返回的 stats 字典——里面存着 depth, n_samples, gini_before, gini_after, best_gain。这种“训练即分析”的一体化设计,只有手写才能实现。
所以,这不是拒绝工业级工具,而是构建一个可显微、可干预、可延展的学习沙盒。它和 sklearn 的关系,就像乐高基础颗粒和成品变形金刚——前者让你理解连接原理,后者让你快速搭建功能。本项目选择前者,因为它是理解后者的基础。
2.2 ID3 与 C4.5 的融合策略:离散与连续的统一建模框架
严格来说,ID3 只处理离散特征,C4.5 才扩展支持连续特征。但实际应用中,我们面对的数据集往往是混合的:鸢尾花 4 个特征全是连续数值,西瓜数据集中既有“色泽”(青绿/乌黑/浅白)这样的离散标签,也有“敲声”(沉闷/浊响/清脆)这种可视为有序离散的变量,还有“密度”“含糖率”这类明确连续值。因此,本项目没有机械割裂 ID3/C4.5,而是构建了一个特征感知型分裂引擎。
其核心在于 DecisionTree._best_split() 函数中的类型判断分支:
# 伪代码示意
for feature_idx in range(X.shape[1]):
X_col = X[:, feature_idx]
if self._is_continuous(X_col): # 判断标准:唯一值数量 > len(X)//5 且 dtype 为 float
# 走 C4.5 连续路径:排序 → 生成候选阈值 → 遍历计算增益
thresholds = self._find_best_threshold_for_continuous_feature(X_col, y)
best_gain, best_thresh = self._evaluate_splits_continuous(X_col, y, thresholds)
candidate_splits.append(('continuous', feature_idx, best_thresh, best_gain))
else:
# 走 ID3 离散路径:对每个唯一值构造子集 → 计算加权熵
unique_vals = np.unique(X_col)
for val in unique_vals:
mask = (X_col == val)
gain = self._information_gain(y, y[mask])
candidate_splits.append(('discrete', feature_idx, val, gain))
这个设计的关键洞察是:连续与离散的本质区别不在数据类型,而在分裂方式。连续特征需要寻找一个实数阈值 t,将空间切成 X_j ≤ t 和 X_j > t 两半;离散特征则需对每个取值 v 构造一个布尔条件 X_j == v,形成多路分支。因此,_is_continuous() 的判断逻辑并非简单看 dtype,而是结合统计特征:若某列浮点数的唯一值数量超过样本量的 20%,则视为连续(避免将“编码为 1.0/2.0/3.0 的类别”误判为连续);若唯一值极少(如 < 5),则强制走离散路径。
更进一步,对于西瓜数据集中“敲声”这类语义上有序但存储为字符串的变量(如 'qingcui', 'zhuoxiang', 'chenmen'),项目提供了 preprocess_watermelon() 函数,将其映射为有序数值 [0, 1, 2],再交由连续路径处理——这模拟了 C4.5 中“有序离散特征”的处理思想,比简单 one-hot 编码更符合领域知识。
这种融合带来的直接好处是模型表达一致性:无论输入是鸢尾花还是西瓜,用户调用的都是同一个 DecisionTree(max_depth=3, criterion='gini') 接口,内部自动适配分裂策略。你不需要记住“鸢尾花用 ID3,西瓜用 C4.5”,只需要理解“我的数据里哪些列是连续的,哪些是离散的”,而这个判断逻辑已内置于训练流程中。
2.3 树结构封装哲学:TreeNode 不是容器,而是状态机
tree.py 中的 TreeNode 类,是我重构过程中最费思量的部分。最初版本它只是一个简单的字典式容器:{'feature': 2, 'threshold': 2.45, 'children': {...}}。但很快发现,这种设计在递归预测和可视化时极其脆弱——比如,当你要渲染树图时,需要知道某个节点是“内部节点”还是“叶子节点”,但容器本身不携带行为逻辑。
于是彻底重写为面向对象的状态机:
class TreeNode:
def __init__(self, feature_idx=None, threshold=None, is_leaf=False, class_label=None):
self.feature_idx = feature_idx # 分裂特征索引,叶子节点为 None
self.threshold = threshold # 分裂阈值,叶子节点为 None
self.children = {} # 字典:连续特征为 {0: left_node, 1: right_node};离散特征为 {'val1': node1, 'val2': node2}
self.is_leaf = is_leaf # 布尔标志,决定是否继续分裂
self.class_label = class_label # 叶子节点的预测类别
self.n_samples = 0 # 该节点覆盖的样本数(用于剪枝和可视化标注)
self.gini = 0.0 # 该节点的基尼不纯度(用于深度分析)
def predict(self, x):
"""单样本预测:递归向下,直到叶子节点"""
if self.is_leaf:
return self.class_label
# 连续特征:x[self.feature_idx] <= self.threshold → 左子树,否则右子树
if isinstance(self.threshold, (int, float)):
branch = 0 if x[self.feature_idx] <= self.threshold else 1
# 离散特征:直接匹配 x[self.feature_idx] 的值
else:
branch = x[self.feature_idx]
return self.children[branch].predict(x)
def to_dot_string(self, node_id=0):
"""生成 Graphviz DOT 语言字符串,支持递归展开"""
if self.is_leaf:
label = f"Class: {self.class_label}\\nSamples: {self.n_samples}"
return f'node{node_id} [label="{label}", shape=box];\n'
else:
feat_name = FEATURE_NAMES[self.feature_idx]
if isinstance(self.threshold, (int, float)):
label = f"{feat_name} <= {self.threshold:.2f}\\nSamples: {self.n_samples}"
left_id, right_id = node_id*2+1, node_id*2+2
dot_str = f'node{node_id} [label="{label}"];\n'
dot_str += f'node{node_id} -> node{left_id} [label="True"];\n'
dot_str += f'node{node_id} -> node{right_id} [label="False"];\n'
dot_str += self.children[0].to_dot_string(left_id)
dot_str += self.children[1].to_dot_string(right_id)
return dot_str
# 离散分支类似,略
这个设计让 TreeNode 成为一个自包含的决策单元。它不仅存储数据,还封装了预测行为(predict())、序列化行为(to_dot_string())、甚至未来可扩展的剪枝行为(prune_if_pure())。当你调用 root.predict(x) 时,不是在操作一堆松散的字典,而是在驱动一个状态机沿着决策路径自动流转。这种封装极大提升了代码的可维护性和可测试性——你可以单独实例化一个 TreeNode,传入虚拟数据,验证其 predict() 是否正确返回预期类别,而无需启动整个训练流程。
更重要的是,它为后续的深度分析埋下伏笔。iris_test2.py 中绘制的“深度-准确率曲线”,其横坐标 depth 并非简单地统计递归层数,而是通过 TreeNode.depth() 方法(在构建时注入)精确获取每个节点的实际深度;纵坐标 accuracy 的计算,则依赖于每个节点的 n_samples 和 class_label 字段,从而能绘制出“各深度下测试集覆盖率”等衍生指标。这种细粒度的状态记录,是扁平化字典结构永远无法提供的。
3. 核心细节解析与实操要点:从数据加载到树图生成的完整链路
3.1 数据预处理:鸢尾花与西瓜的差异化清洗策略
数据是决策树的血液,预处理的质量直接决定模型的上限。本项目对两个经典数据集采用了截然不同的清洗策略,这并非随意为之,而是深刻理解其数据特性的结果。
鸢尾花数据集(iris.data):表面看是“标准”数据集,但原始 UCI 版本存在隐藏陷阱。第一,它有 150 行,但最后一行常因换行符缺失而读取失败;第二,类别标签是字符串 'Iris-setosa',而我们的 TreeNode 预测返回整数索引,必须建立映射。因此 load_iris_data() 函数做了三件事:
1. 使用 pandas.read_csv(..., skip_blank_lines=True) 自动跳过空行;
2. 对 species 列执行 pd.Categorical().codes 编码,将 'setosa'→0, 'versicolor'→1, 'virginica'→2;
3. 最关键一步:对特征矩阵进行 StandardScaler 归一化?不。决策树对特征尺度完全不敏感,归一化反而会破坏原始物理意义(比如“花瓣长度 1.5 厘米”被缩放到 0.23,人无法直观理解)。所以这里只做 X = X.astype(np.float64) 类型强转,确保数值计算精度。
西瓜数据集(watermelon.data):这是周志华《机器学习》书中的经典示例,共 17 条记录,但原始文本格式混乱。例如,“纹理”列有 '清晰', '稍糊', '模糊',而“根蒂”列是 '蜷缩', '硬挺', '稍蜷'。这些中文字符串不能直接用于计算,必须转化为可排序、可比较的数值。preprocess_watermelon() 的处理逻辑如下:
- 对“色泽”“根蒂”“敲声”“纹理”“脐部”“触感”六个属性,分别定义有序映射字典:
python TEXT_TO_NUM = { '色泽': {'青绿': 0, '乌黑': 1, '浅白': 2}, '根蒂': {'蜷缩': 0, '稍蜷': 1, '硬挺': 2}, '敲声': {'清脆': 0, '浊响': 1, '沉闷': 2}, # ...其他类似 }
- 对“密度”“含糖率”两个连续属性,不做任何变换,保留原始浮点值;
- 特别注意:“好瓜”标签是 '是'/'否',这里不简单映射为 1/0,而是先转为布尔 True/False,再用 y.astype(int) 得到 1/0——这保证了后续基尼计算中 np.unique(y) 返回 [0, 1],而非 [False, True],避免类型混淆。
这种差异化处理体现了核心原则:预处理不是标准化流水线,而是领域知识的编码过程。鸢尾花是测量数据,保持原始尺度;西瓜是农业经验数据,需将农学术语转化为可计算序数。忽略这点,直接 LabelEncoder 一把梭,会导致“稍糊”纹理的编码值高于“清晰”,违背农业常识,模型自然学歪。
3.2 分裂质量评估:基尼不纯度的手动实现与数值稳定性保障
决策树的核心是“找最好的分裂”。所谓“最好”,数学上定义为使子节点纯度提升最大的分裂。本项目选用基尼不纯度(Gini Impurity)而非信息熵(Entropy),原因很实在:计算更快,且对二分类问题效果几乎无差别。其公式为:
$$
\text{Gini}(D) = 1 - \sum_{k=1}^K p_k^2
$$
其中 $p_k$ 是数据集 $D$ 中第 $k$ 类样本的比例。
_gini_impurity(y) 函数的手动实现看似简单,但藏着两个易被忽视的坑:
def _gini_impurity(y):
if len(y) == 0:
return 0.0
# 坑1:直接 np.unique(y, return_counts=True) 在 y 为整数时极快,
# 但如果 y 是字符串或 object 类型,速度暴跌 10 倍
classes, counts = np.unique(y, return_counts=True)
probs = counts / len(y)
return 1.0 - np.sum(probs ** 2)
坑1:数据类型陷阱。np.unique() 对 int64 数组是 O(n log n),但对 object 数组(如未编码的 'setosa' 字符串)是 O(n²),因为要逐字符比较。这就是为何预处理中必须将类别转为整数编码——不是为了节省内存,而是为了加速纯度计算。在 iris_test2.py 中,一次全量训练要调用 _gini_impurity() 上万次(每个候选分裂点都要算左右子集纯度),类型错误会让运行时间从 2 秒飙升到 30 秒。
坑2:浮点精度灾难。当 probs 中某个概率极小(如 1e-15),probs ** 2 会变成 1e-30,np.sum() 可能因精度丢失返回 1.0000000000000002,导致 1.0 - sum 得到 -2e-16 ——一个负的基尼值!这在后续 gain = gini_parent - weighted_gini_children 中会引发连锁错误。解决方案是在返回前加一道钳制:
gini = 1.0 - np.sum(probs ** 2)
return max(0.0, min(1.0, gini)) # 强制 [0, 1] 区间
这个 max/min 钳制看似简单,却是无数人在手写算法时踩过的坑。它不改变数学本质,只是为浮点世界筑了一道安全堤坝。
3.3 连续特征最优切分:暴力搜索的工程优化技巧
对连续特征找最优切分点,理论方案是遍历所有相邻值中点,但工程上必须优化。以西瓜数据集的“密度”为例,共 17 个样本,排序后有 16 个中点候选。暴力搜索没问题;但若换成 10 万样本的信贷数据,10 万-1 个候选点,每个点都要计算左右子集纯度,复杂度 O(n²),不可接受。
本项目采用三级优化策略:
第一级:候选点压缩。不遍历所有中点,而是按 step_size 采样。_find_best_threshold_for_continuous_feature() 中默认 step_size=1(即全遍历),但预留了接口:
# 若数据量大,可设 step_size=10,跳过 90% 候选点
candidates = sorted_vals[::step_size]
这牺牲了理论最优性,但实践中,由于基尼函数通常是平滑的单峰函数,跳点采样找到的阈值与全局最优相差极小,而速度提升显著。
第二级:提前终止。在遍历候选点时,若发现当前 best_gain 已达到理论最大值(如二分类中 gain_max = gini_parent),立即跳出循环。因为 gain = gini_parent - weighted_gini_children ≤ gini_parent,等号成立意味着子集完全纯净,无需再找。
第三级:向量化计算。关键瓶颈在 for t in candidates: left_mask = (X_col <= t); gini_left = _gini_impurity(y[left_mask])。这里 left_mask 每次都新建布尔数组,开销大。优化为:
# 预先计算排序后的 y_sorted,利用 cumsum 加速
y_sorted = y[np.argsort(X_col)] # 按 X_col 排序 y
cumsum_left = np.cumsum(y_sorted == 0) # 假设类别 0 的累计频次
# 则第 i 个候选点左侧的类别 0 比例 = cumsum_left[i] / (i+1)
此法将内层循环从 O(n) 降到 O(1),整体复杂度从 O(n²) 降至 O(n log n)。
这些优化不是炫技,而是让手写代码具备真实可用性。当你在 watermelon_test.py 中看到 Found best split for 'density' at threshold=0.382, gain=0.123 这行输出时,背后是经过三次工程打磨的高效搜索。
3.4 决策树可视化:从 .dot 文本到 PNG 图像的全链路控制
可视化是本项目的一大亮点,但其价值远不止“好看”。一张好的决策树图,是模型可解释性的终极载体。iris_test2.py 生成的 iris_tree.dot 文件,内容如下节选:
digraph Tree {
node [shape=box, style="filled", color="black"] ;
0 [label="petal length <= 2.45\\nsamples = 150\\ngini = 0.667"] ;
1 [label="petal width <= 1.75\\nsamples = 50\\ngini = 0.000"] ;
2 [label="Class: 0\\nsamples = 50"] ;
3 [label="petal length <= 4.85\\nsamples = 100\\ngini = 0.500"] ;
...
0 -> 1 [label="True"] ;
0 -> 3 [label="False"] ;
1 -> 2 ;
...
}
这段文本的生成逻辑,全部封装在 TreeNode.to_dot_string() 中。但真正让它“活起来”的,是 graphviz 和 pydot 的协同工作。这里有个关键细节:pydot 本身不渲染图像,它只是 graphviz 的 Python 封装,真正的渲染引擎是 dot.exe(Windows)或 /usr/local/bin/dot(macOS)。因此,project_operation_guide.md 中强调:
“请务必确认
dot命令可在终端直接运行。Windows 用户安装 Graphviz 后,必须将C:\Program Files\Graphviz2.38\bin(路径依版本而定)添加到系统环境变量 PATH。验证方法:打开 CMD,输入dot -V,应返回dot - graphviz version 2.38.0。”
很多用户卡在这一步,报错 pydot.InvocationException: GraphViz's executable not found。根本原因不是没装 Graphviz,而是 pydot 找不到 dot 可执行文件。解决方案有二:
- 推荐:按指南配置 PATH,一劳永逸;
- 备选:在代码中硬编码路径,graph = pydot.Dot(graph_type='digraph', prog='dot') 改为 graph = pydot.Dot(graph_type='digraph', program='C:/Program Files/Graphviz2.38/bin/dot.exe')。
更进一步,iris_test1.py 的二维决策边界图,其技术难点不在绘图,而在网格生成。它不是简单画一条线,而是对特征平面进行高密度采样:
# 定义网格步长,太粗(0.1)边界锯齿,太细(0.01)内存爆炸
h = 0.02
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) # 对每个网格点预测
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.RdYlBu)
这里的 h=0.02 是经验值:在鸢尾花数据范围(萼片长 4.3~7.9,宽 2.0~4.4)内,生成约 (7.9-4.3)/0.02 × (4.4-2.0)/0.02 ≈ 180×120 = 21600 个点,预测耗时 0.3 秒,图像足够平滑。若设 h=0.005,点数暴增至 345600,内存占用翻倍,而视觉提升微乎其微。这种“够用就好”的工程权衡,正是资深从业者与新手的本质区别。
4. 实操过程与核心环节实现:三份测试脚本的逐层递进
4.1 iris_test1.py:双特征二维边界可视化——理解决策树的几何本质
iris_test1.py 是项目的入门钥匙,它只用鸢尾花的前两个特征(萼片长度、萼片宽度),训练一棵浅层决策树,并绘制决策边界。这看似简单,却是理解决策树“轴平行分割”特性的最佳入口。
执行流程如下:
1. 数据加载与筛选:X, y = load_iris_data() 后,取 X = X[:, :2],即只保留 sepal length 和 sepal width;
2. 模型训练:dt = DecisionTree(max_depth=3),调用 dt.fit(X, y);
3. 网格预测:如前所述,生成 xx, yy 网格,对每个点 (x, y) 调用 dt.predict([x, y]);
4. 绘图:用 contourf 填充背景色,scatter 绘制原始数据点,plt.xlabel('Sepal Length (cm)') 添加坐标轴标签。
这张图的价值,在于它把抽象的“if-else”逻辑变成了可视的几何分割。你会看到:
- 整个平面被一系列竖直或水平的直线切割成矩形区域;
- 每个矩形区域被涂上红、蓝、黄三色之一,代表该区域内多数样本的类别;
- 原始数据点散落在这些矩形中,有些点位于边界线上——这说明它们恰好落在某个分裂阈值上,模型对其预测可能不稳定(这也是为何实际部署中要加置信度阈值)。
提示:运行此脚本时,注意观察
max_depth=1和max_depth=3的对比。深度为 1 时,只有一条分割线,平面被切成两半;深度为 3 时,最多有 7 个叶节点(2³-1),平面被切成 7 个矩形。这直观展示了“深度”如何控制模型复杂度——深度越大,分割越细,拟合越强,但也越容易过拟合。
4.2 iris_test2.py:四特征全维度建模与深度分析——掌握模型调优的科学方法
如果说 iris_test1.py 是“看见”,那么 iris_test2.py 就是“测量”。它用全部 4 个特征训练模型,并提供三项深度分析能力:树结构图、深度-准确率曲线、节点统计报告。
其核心函数 analyze_depth_impact() 的实现逻辑如下:
def analyze_depth_impact(X_train, X_test, y_train, y_test, max_depth_range=range(1, 11)):
results = []
for depth in max_depth_range:
dt = DecisionTree(max_depth=depth)
dt.fit(X_train, y_train)
acc_train = dt.score(X_train, y_train)
acc_test = dt.score(X_test, y_test)
# 关键:提取树的统计信息
stats = dt.get_tree_stats() # 返回 {depth: {'n_nodes': ..., 'avg_depth': ...}}
results.append({
'depth': depth,
'train_acc': acc_train,
'test_acc': acc_test,
'n_nodes': stats['n_nodes'],
'leaf_nodes': stats['leaf_nodes']
})
# 绘制曲线
plt.plot([r['depth'] for r in results], [r['test_acc'] for r in results], 'o-', label='Test Accuracy')
plt.xlabel('Max Depth')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('figure/depth_vs_accuracy.png')
这张 depth_vs_accuracy.png 图,是模型调优的黄金准则。你会发现:
- 当 depth=1 时,准确率很低(约 60%),因为分割太粗糙;
- depth=3 到 depth=5 时,测试准确率稳定在 95% 左右,这是最佳区间;
- depth≥6 后,训练准确率持续上升(逼近 100%),但测试准确率开始下降——典型的过拟合信号。
此时,你应该果断将 max_depth 设为 5。这比盲目调参高效得多。iris_test2.py 还会生成 iris_tree.dot,用 Graphviz 渲染为 iris_tree.png。打开这张图,你能看到:
- 根节点分裂特征是 petal length,阈值 2.45,这印证了植物学常识:花瓣长度是区分鸢尾花最显著的形态特征;
- 第二层,左子树(petal length ≤ 2.45)直接判定为 Class 0(山鸢尾),因为该区域样本纯度已达 100%;
- 右子树继续分裂,最终在 petal width 和 sepal width 上完成精细区分。
这种“从数据中自动发现领域知识”的能力,正是决策树的魅力所在。
4.3 watermelon_test.py:西瓜数据连续属性切分实战——处理真实业务数据的范式
watermelon_test.py 是项目的压轴戏,它处理的是更贴近真实业务的数据:西瓜好坏预测。17 条记录虽少,但包含了决策树处理连续值的所有挑战。
执行步骤:
1. 数据加载:X, y = preprocess_watermelon(),得到 17×8 矩阵(6 个属性 + 密度 + 含糖率);
2. 最优切分搜索:对每个连续属性(密度、含糖率),调用 _find_best_threshold_for_continuous_feature(),输出类似:
Best split for 'density': threshold=0.382, gain=0.123 Best split for 'sugar_ratio': threshold=0.161, gain=0.098
3. 树构建与打印:dt.fit(X, y) 后,调用 dt.print_tree(),输出缩进式文本树:
[Root] density <= 0.382 [Node] sugar_ratio <= 0.161 [Leaf] Class: 0 (Bad) [Node] density <= 0.421 [Leaf] Class: 1 (Good)
这个输出比图形更珍贵——它是一份可审计的决策日志。业务人员可以拿着它问:“为什么密度≤0.382 就是坏瓜?”答案就在训练数据中:所有密度≤0.382 的西瓜样本,90% 都是坏瓜。这种“数据驱动的规则提炼”,正是决策树在风控、医疗等领域的核心价值。
注意:西瓜数据集样本量极小(17 条),因此
max_depth应设为None或较大值(如 5),否则树太浅无法学习。但这也提醒我们:决策树在小样本上易过拟合,实际业务中必须配合交叉验证或剪枝。
5. 常见问题与排查技巧实录:那些文档里不会写的坑
5.1 Graphviz 渲染失败的五大原因及精准修复
Graphviz 是本项目可视化的心脏,但也是报错重灾区。根据实测,90% 的渲染失败可归为以下五类:
| 错误现象 | 根本原因 | 修复方案 |
|---|---|---|
pydot.InvocationException: GraphViz's executable not found | pydot 找不到 dot.exe | Windows:将 C:\Program Files\Graphviz2.38\bin 加入 PATH;macOS:brew install graphviz 后确认 /usr/local/bin/dot 存在;Linux:sudo apt-get install graphviz |
Error: <stdin>: syntax error in line 1 near 'digraph' | .dot 文件编码为 UTF-8 with BOM,dot 解析失败 | 用 VS Code 打开 iris_tree.dot,右下角点击编码 → Save with Encoding → 选择 UTF-8(无 BOM) |
Warning: No fonts could be loaded | dot 缺少中文字体支持 | Windows:复制 simhei.ttf 到 C:\Program Files\Graphviz2.38\fonts\;macOS:cp /System/Library/Fonts/PingFang.ttc /usr/local/share/fonts/,然后 sudo fc-cache -fv |
Layout was not done. Unrecognized node shape: box | dot 版本过低(< 2.38)不支持 shape=box | 升级 Graphviz 至最新版,官网下载链接在 project_operation_guide.md 中 |
Segmentation fault (core dumped) | pydot 版本与 graphviz 不兼容 | 降级 pydot:pip install pydot==1.4.2(经测试最稳定) |
实操心得:首次运行
iris_test2.py前,务必先在终端执行dot -Tpng -o test.png test.dot(test.dot是一个最简文件),验证dot命令本身是否正常。这能快速定位是环境问题还是代码问题。
5.2 matplotlib 中文乱码的终极解决方案
iris_test1.py 和 iris_test2.py 中的图表若出现方块乱码,是因为 matplotlib 默认字体不支持中文。网上常见方案是修改 matplotlibrc,但本项目采用更鲁棒的代码内嵌方案:
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'] # 优先级顺序
plt.rcParams['axes.unicode_minus'] = False # 解决负号 '-' 显示为方块的问题
关键是 SimHei(Windows 黑体)放在首位,Arial Unicode MS(macOS)次之,DejaVu Sans(Linux)垫底。这样一份代码,跨平台都能显示中文。project_operation_guide.md 中特别注明:“若 Windows 无 SimHei,可下载 msyh.ttc(微软雅黑)放入 C:\Windows\Fonts\,然后将 SimHei 替换为 Microsoft YaHei”。
5.3 “ValueError: Input contains NaN, infinity or a value too large for dtype(‘float64’)” 的溯源排查
此错误通常出现在 dt.fit(X, y) 时,表面是数据含 NaN,但根源往往在预处理。排查链路如下:
1. 检查 X 是否为空:print("X shape:", X.shape, "X dtype:", X.dtype);
2. 检查 X 是否含 NaN:print("Any NaN in X:", np.isnan(X).any());
3. 若为 True,追溯到 load_iris_data(),发现原始 iris.data 中有缺失值标记 '?',需在 pandas.read_csv() 中加 na_values='?' 参数;
4. 若 X 正常,检查 y:print("y unique:", np.unique(y)),若输出 ['?'],说明类别列也有缺失,需同步处理。
踩坑记录:曾有用户将
iris.data用 Excel 打开后另存为 CSV,Excel 会自动将科学计数法(如1e-5)转为0.00001,但某些版本会引入不可见字符,导致astype(float)失败。解决方案:永远用pandas直接读取原始.data文件,不经过 Excel 中转。
5.4 深度-准确率曲线异常平坦的诊断指南
理想曲线应在某深度后测试准确率下降,但若你看到一条近乎水平的直线(如 depth=1 到 depth=10,测试准确率始终 95%±0.5%),说明模型未学到深度相关的复杂模式。可能原因:
- 数据泄露:X_train 和 X_test 划分错误,比如用了 train_test_split(X, y, test_size=0.2, shuffle=False),而鸢尾花数据是按类别顺序排列的,导致测试集全是后 30 个样本(virginica),训练集缺少该类别;
- 特征冗余:4 个特征中,petal length 和 petal width 高度相关(相关系数 0.96),模型仅用一个特征就可达到饱和性能,增加深度无收益;
- 剪枝过强:min_samples_split=10 设置过大,即使 max_depth=10,实际树深也止步于 3。
诊断方法:在 analyze_depth_impact() 中加入 print(f"Depth {depth}: n_nodes={dt.get_n_nodes()}, leaf_nodes={dt.get_n_leaf_nodes()}"),若 n_nodes 始终为 3(根+2叶),说明分裂被提前终止。
5.5 “ModuleNotFoundError: No module named ‘pydot’” 的离线安装秘籍
公司内网环境常无法 pip install pydot。此时需离线安装:
1. 在有网机器上:pip download pydot graphviz,得到 pydot-1.4.2-py2.py3-none-any.whl 和 graphviz-0.20.1-py3-none-any.whl;
2. 复制到内网机,pip install --find-links ./ --no-index pydot;
3. 关键:graphviz 是纯 Python 包,但 pydot 依赖系统级 graphviz,所以内网机仍需单独安装 Graphviz 二进制(Windows MSI 安装包,macOS pkg 包)。
最后分享一个小技巧:若只想快速验证树结构,不必渲染 PNG。
iris_test2.py中注释掉graph.write_png(...),改为print(dt.root.to_dot_string()),直接在控制台查看 DOT 文本。这能绕过所有图形依赖,专注逻辑验证。
6. 项目延伸与个人体会:从手写决策树到算法工程师的成长路径
这个项目从构思到交付,历时 22 天,重写了 7 轮核心算法。最深的体会是:手写算法不是为了替代 sklearn,而是为了获得一种“算法直觉”。当你亲手实现 _information_gain(),你会对“为什么信息增益比基尼不纯度更适合 ID3”有切肤之痛;当你调试 watermelon_test.py 中“敲声”特征的切分点,你会理解农业专家为何说“清脆声的西瓜密度普遍更高”——这种直觉,是调参调不出来的。
项目后续可自然延伸三个方向:
- 剪枝增强:加入 CCP(代价复杂度剪枝),用 sklearn.tree.export_text() 的逻辑反推 ccp_alpha,让树在保持精度前提下更简洁;
- 特征重要性量化:在 _best_split() 中记录每个特征的累计增益,最终归一化为重要性分数,输出 feature_importance.png;
- Web 可视化:用 Flask 搭建简易界面,上传 CSV 数据,实时生成决策树图和边界图,让业务方也能玩转算法。
但比技术延伸更重要的是认知升级。我曾经以为,算法工程师的核心竞争力是“调得准”,现在明白,真正的护城河是“说得清”。当产品经理问“为什么这个客户被拒贷?”,你能打开决策树图,指着 income < 5000 和 debt_ratio > 0.6 两条路径,清晰解释模型逻辑——这种能力,远胜于把准确率从 92% 提升到 92.3%。
所以,如果你正在学习机器学习,请一定动手写一次决策树。不要怕慢,不要怕错。当你的 tree.py 第一次成功预测出鸢尾花类别,当 iris_test1.py 的边界图在屏幕上亮起,那种“我创造了智能”的震撼,会成为你工程师生涯中最明亮的灯塔。它提醒你:AI 不是魔法,而是人类用逻辑编织的精密织物;而你,正亲手握住那根最基础的丝线。
简介:一套开箱即用的纯Python决策树实现,不依赖sklearn内置Tree类,完整复现ID3/C4.5核心逻辑。支持两类典型任务:用鸢尾花全部4个特征做多分类(含二维决策边界图、全维度建模、树结构图生成、不同深度下的准确率变化曲线),以及对西瓜数据集进行连续属性切分与好坏二分类(输出可读树结构、.dot文件及渲染图)。所有脚本含逐行中文注释,运行iris_test1.py得双特征分类边界图,iris_test2.py输出四特征模型+决策树图+深度-精度关系图,watermelon_test.py完成西瓜数据预处理、最优分割点计算与树形打印。data目录内置标准iris.data和watermelon.data;figure目录存放12张实测图(含6张png+6张dot源文件);项目操作说明.md详细列出Python 3.7+环境配置、graphviz安装步骤、pydot调用方法、单文件执行顺序及常见报错修复方案;requirements.txt声明scikit-learn、matplotlib、graphviz、pydot四项依赖,无其他第三方要求。


被折叠的 条评论
为什么被折叠?



