[ 统计学习  ]

统计学习方法:决策树

决策树可以用于分类,也可以用于回归,本文介绍的决策树主要是分类决策树

一、简述

分类决策树模型是基于特征对实例进行分类的树形结构。决策树可以转换为一个 if-then 规则的集合,也可以看作是定义在特征空间划分上的类的条件概率分布

决策树学习旨在构建一个与训练数据拟合良好,并且复杂度较小的决策树。因为从所有可能的决策树中直接选取最优决策树是 NP 完全问题,所以现实中采用启发式方法学习次优的决策树。

决策树的学习算法包括 3 部分:特征选择树的生成树的剪枝,常用算法有 ID3C4.5CART

二、特征选择

特征选择的目的在于选择能够对训练数据进行分类的特征。选择的关键在于依照某个准则。常用准则有:

(1) 信息增益

为便于说明,首先给出信息论中熵和条件熵的定义。

随机变量 $X$ 的 $H(X)$ 定义为

随机变量 $X$ 给定的条件下随机变量 $Y$ 的条件熵 $H(Y\mid X)$ 定义为:

当熵和条件熵中的概率由数据估计(特别是极大似然估计)得到时,所对应得到的熵和条件熵分别为经验熵经验条件熵。此时,如果有 0 概率,令 $0\log 0=0$ .


定义:样本集合 $D​$ 对特征 $A​$ 的信息增益

其中

是数据集 $D$ 的。而

是数据集 $D$ 对特征 $A$ 的条件熵

注:上面的符号表示

信息增益也叫类与特征的互信息

由于经验熵 $H(D)$ 表示对数据集 $D$ 进行分类的不确定性,而经验条件熵 $H(D\mid A)$ 表示在特征 $A$ 给定的条件下对数据集 $D$ 进行分类的不确定性,则它们的差,即信息增益,就表示由于特征 $A$ 而使得对数据集 $D$ 分类的不确定性减少的程度。显然,信息增益越大,分类能力越强。

ID3 算法使用信息增益。

(2) 信息增益比

信息增益比的大小是相对于训练数据集而言的:如果训练集的经验熵比较大,则信息增益比也会偏大,反之亦然。所以为了获得一个通用的标准,信息增益比派上了用场。

定义:特征 $A​$ 对数据集 $D​$ 的信息增益比 $g_R(D, A)​$ 定义为

C4.5 算法使用信息增益比。

(3) 基尼指数 (Gini impurity)

分类问题中,假设有 $K$ 个类,每个类的概率为 $p_k$,则概率分布的基尼指数定义为

对给定的数据集 $D$,其基尼指数为

如果根据特征 $A$ 是否等于 $a$ 来将数据集分成两部分

则在特征 $A$ 的条件下,集合 $D$ 的基尼指数定义为

基尼指数同样度量概率分布的不确定性,和熵类似。CART 算法使用基尼指数。

三、决策树的生成

本节将详细介绍 ID3C4.5 两种算法的树生成过程, CART 将在后面单独介绍。

(1) ID3

(2) C4.5

四、决策树的剪枝

决策树的剪枝通过极小化决策树整体的损失函数来完成。

设树的叶结点的个数为 $\vert T\vert$,$t$ 是树 $T$ 的叶结点,该叶结点有 $N_t$ 个样本点,其中 $k$ 类的样本点有 $N_{tk}$ 个,$k=1,\dots,K$,$H_t(T)$ 为叶结点 $t$ 上的经验熵,$\alpha\ge 0$ 是参数,则决策树学习的损失函数可以定义为:

其中经验熵为

若将损失函数第一项写成

则损失函数为

上式中,$C(T)$ 表示模型对训练数据集的预测误差。考虑一种极端情况:每个叶结点里都仅有一个类。这种情况说明这个决策树对训练集将完全没有误差——每个样本都被分到它原本的类中。从公式上来看,就是每个 $H_t(T)$ 都是 0,最终误差 $C(T)$ 当然也是 0。而第二项 $\alpha\vert T\vert$ 是正则化项,用于控制模型复杂度:我们更倾向于复杂度较小的模型,这样的模型泛化能力更强。而 $\alpha$ 起到调节两者(训练误差和模型复杂度)平衡的作用。

可以看到,决策树生成只关心最小化训练集的预测误差,而决策树剪枝还考虑了减小模型复杂度,提高泛化能力。因此,有这样一句话:

决策树生成学习局部的模型,而决策树剪枝学习整体的模型。

下面是剪枝算法:

五、CART 算法

CART 模型意为“分类与回归树” (Classification and Regression Tree),既能用来分类,也能用来回归。

CART 假设决策树为二叉树,递归地二分每个特征。主要由两个步骤构成:

  1. 决策树生成:生成的决策树要尽可能大
  2. 决策树剪枝

5.1 CART 生成

生成回归树应用最小化平方误差的准则;生成分类树应用最小化基尼指数的准则。

回归树生成

分类树生成

5.2 CART 剪枝

CART 剪枝算法由两步组成:

  1. 首先从生成算法产生的决策树 $T_0$ 底部开始不断剪枝,直到 $T_0$ 的根结点,形成一个子树序列 ${T_0,\dots,T_n}$;
  2. 然后通过交叉验证法在独立的验证数据集上对子树序列进行测试,从中选择最优的子树。

上面的算法是李航《统计学习方法》中给出的,但我认为有点问题。这个网页中给出的算法应该更准确:Cost-Complexity Pruning#Algorithm

参考