机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

2017年12月14日22:21:35 89 26,604 °C
摘要

本篇文章将会讲解CART算法的实现和树的剪枝方法,通过测试不同的数据集,学习CART算法和树剪枝技术。

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

一、前言

本篇文章将会讲解CART算法的实现和树的剪枝方法,通过测试不同的数据集,学习CART算法和树剪枝技术。

二、将CART(Classification And Regression Trees)算法用于回归

在之前的文章,我们学习了决策树的原理和代码实现,使用使用决策树进行分类。决策树不断将数据切分成小数据集,直到所有目标标量完全相同,或者数据不能再切分为止。决策树是一种贪心算法,它要在给定时间内做出最佳选择,但不关心能否达到全局最优。

1、ID3算法的弊端

回忆一下,决策树的树构建算法是ID3。ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。也就是说,如果一个特征有4种取值,那么数据将被切分成4份。一旦按某特征切分后,该特征在之后的算法执行过程中将不会再起作用,所以有观点认为这种切分方式过于迅速。

除了切分过于迅速外,ID3算法还存在另一个问题,它不能直接处理连续型特征。只有事先将连续型特征离散化,才能在ID3算法中使用。但这种转换过程会破坏连续型变量的内在特性。

2、CART算法

与ID3算法相反,CART算法正好适用于连续型特征。CART算法使用二元切分法来处理连续型变量。而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。

CART算法有两步:

  • 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量大;自上而下从根开始建立节点,在每个节点处要选择一个最好的属性来分裂,使得子节点中的训练集尽量的纯。不同的算法使用不同的指标来定义"最好":
  • 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准。

决策树剪枝我们先不管,我们看下决策树生成。

在决策树的文章中,我们先根据信息熵的计算找到最佳特征切分数据集构建决策树。CART算法的决策树生成也是如此,实现过程如下:

  • 使用CART算法选择特征
  • 根据特征切分数据集合
  • 构建树

3、根据特征切分数据集合

我们先找软柿子捏,CART算法这里涉及到算法,实现起来复杂些,我们先挑个简单的,即根据特征切分数据集合。编写代码如下:

运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

我们先创建一个单位矩阵,然后根据切分规则,对数据矩阵进行切分。可以看到binSplitDataSet函数根据特定规则,对数据矩阵进行切分。

现在OK了,我们已经可以根据特征和特征值对数据进行切分了,mat0存放的是大于指定特征值的矩阵,mat1存放的是小于指定特征值的矩阵。接下来,我们就看看如何使用CART算法选择最佳分类特征。

4、CART算法

假设X与Y分别为输入和输出变量,并且Y是连续变量,给定训练数据集:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

其中,D表示整个数据集合,n为特征数。

一个回归树对应着输入空间(即特征空间)的一个划分以及在划分的单元上的输出值。假设已将输入空间划分为M个单元R1,R2,...Rm,并且在每个单元Rm上有一个固定的输出值Cm,于是回归树模型可表示为:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

这样就可以计算模型输出值与实际值的误差:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

我们希望每个单元上的Cm,可以是的这个平方误差最小化。易知,当Cm为相应单元的所有实际值的均值时,可以到最优:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

那么如何生成这些单元划分?

假设,我们选择变量 xj 为切分变量,它的取值 s 为切分点,那么就会得到两个区域:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

当j和s固定时,我们要找到两个区域的代表值c1,c2使各自区间上的平方差最小:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

前面已经知道c1,c2为区间上的平均:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

那么对固定的 j 只需要找到最优的s,然后通过遍历所有的变量,我们可以找到最优的j,这样我们就可以得到最优对(j,s),并得到两个区间。

这样的回归树通常称为最小二乘回归树(least squares regression tree)。

上述过程表示的算法步骤为:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

除此之外,我们再定义两个参数,tolS和tolN,分别用于控制误差变化限制和切分特征最少样本数。这两个参数的意义是什么呢?就是防止过拟合,提前设置终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作,在下一小节会进行进一步讲解。

老规矩,先看下我们的测试数据集。

数据集下载地址:数据集下载

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

如上图所示,数据是2维的。先看下数据的分布情况,编写代码如下:

运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,这是一个很简单的数据集,我们先利用这个数据集测试我们的CART算法。

现在,编写代码如下:

运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,切分的最佳特征为第1列特征,最佳切分特征值为0.48813,这个特征值怎么选出来的?就是根据误差估计的大小,我们选择的这个特征值可以使误差最小化。

切分的特征和特征值我们已经选择好了,接下来就是利用选出的这两个变量创建回归树了。

创建方法很简单,我们根据切分的特征和特征值切分出两个数据集,然后将两个数据集分别用于左子树的构建和右子树的构建,直到无法找到切分的特征为止。因此,我们可以使用递归实现这个过程,编写代码如下:

运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

从上图可知,这棵树只有两个叶结点。

我们换一个复杂一点的数据集,分段常数数据集。

数据集下载地址:数据集下载

先看下数据:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

第一列的数据都是1.0,为了可视化方便,我们将第1列作为x轴数据,第2列作为y轴数据。对数据进行可视化,编写代码如下:

运行结果如图下所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,这个数据集是分段的。我们针对此数据集创建回归树。代码同上,运行结果如下图所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,该数的结构中包含5个叶结点。

现在为止,已经完成回归树的构建,但是需要某种措施来检查构建过程是否得当。这个技术就是剪枝(tree pruning)技术。

三、树剪枝

一棵树如果结点过多,表明该模型可能对数据进行了“过拟合”。

通过降低树的复杂度来避免过拟合的过程称为剪枝(pruning)。上小节我们也已经提到,设置tolS和tolN就是一种预剪枝操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。本节将分析后剪枝的有效性,但首先来看一下预剪枝的不足之处。

1、预剪枝

预剪枝有一定的局限性,比如我们现在使用一个新的数据集。

数据集下载地址:数据集下载

用上述代码绘制数据集看一下:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,对于这个数据集与我们使用的第一个数据集很相似,但是区别在于y的数量级差100倍,数据分布相似,因此构建出的树应该也是只有两个叶结点。但是我们使用默认tolS和tolN参数创建树,你会发现运行结果如下所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,构建出的树有很多叶结点。产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平均值,或许也能得到仅有两个叶结点组成的树:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,将参数tolS修改为10000后,构建的树就是只有两个叶结点。然而,显然这个值,需要我们经过不断测试得来,显然通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。因为对于一个很多维度的数据集,你也不知道构建的树需要多少个叶结点。

可见,预剪枝有很大的局限性。接下来,我们讨论后剪枝,即利用测试集来对树进行剪枝。由于不需要用户指定参数,后剪枝是一个更理想化的剪枝方法。

2、后剪枝

使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶结点,用测试集来判断这些叶结点合并是否能降低测试集误差。如果是的话就合并。

为了演示后剪枝,我们使用ex2.txt文件作为训练集,而使用的新数据集ex2test.txt文件作为测试集。

测试集下载地址:数据集下载

现在我们使用ex2.txt训练回归树,然后利用ex2test.txt对回归树进行剪枝。我们需要创建三个函数isTree()、getMean()、prune()。其中isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果。换句话说,该函数用于判断当前处理的结点是否是叶结点。第二个函数getMean()是一个递归函数,它从上往下遍历树直到叶结点为止。如果找到两个叶结点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值)。而第三个函数prune()则为后剪枝函数。编写代码如下:

运行结果如下如所示:

机器学习实战教程(十三):树回归基础篇之CART算法与树剪枝

可以看到,树的大量结点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。

现在,可能你会问了,这叶结点只是简单的数值。这也没有拟合数据啊?回归树到底啥样啊?别急,下篇文章继续讲解。

四、总结

  • CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。
  • 一颗过拟合的树常常十分复杂,剪枝技术的出现就是为了解决这个问题。两种剪枝方法分别是预剪枝和后剪枝,预剪枝更有效但需要用户定义一些参数。
  • 下篇文章将继续讲解回归树。
  • 如有问题,请留言。如有错误,还望指正,谢谢!

 

PS: 如果觉得本篇本章对您有所帮助,欢迎关注、评论、赞!

本文出现的所有代码和数据集,均可在我的github上下载,欢迎Follow、Star:点击查看

参考资料:

  • [1] 机器学习实战第八章内容
  • [2] 统计学习方法第五章内容
weinxin
微信公众号
分享技术,乐享生活:微信公众号搜索「JackCui-AI」关注一个在互联网摸爬滚打的潜行者。
君子和而不同,小人同而不和。--- 孔子
Jack Cui

发表评论取消回复

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

目前评论:89   其中:访客  51   博主  38

    • avatar biancheng 来自天朝的朋友 火狐浏览器 Windows 7 广东省深圳市 电信 2

      博主,机器学习的啥时候接着更?

        • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 辽宁省沈阳市 联通

          @biancheng 年前还会更一篇,现在主要看深度学习的东西呢。年后有时间,把剩下的继续更完。

        • avatar frey 来自天朝的朋友 谷歌浏览器 Windows 10 重庆市沙坪坝区 电信 3

          博主,问下CART决策树,每划分一次,会消耗掉一个特征吗,比如我第一次划分身高大于1.7,后几次还有可能划分身高大于1.5吗。

            • avatar frey 来自天朝的朋友 谷歌浏览器 Windows 10 重庆市沙坪坝区 电信 3

              @frey 用它做分类的时候

              • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 7 辽宁省沈阳市 东北大学四舍(女生)

                @frey 用掉的特征,在后续分类中,不会再使用的。

              • avatar 小角色 来自天朝的朋友 谷歌浏览器 Windows 7 山东省济南市 济南大学 1

                CART算法没有看懂,博主可以举个例子吗?

                  • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 辽宁省沈阳市 联通

                    @小角色 CART算法 可以看下书。

                  • avatar Drew 来自天朝的朋友 火狐浏览器 Windows 10 湖南省长沙市 电信 2

                    博主,请问一下这个是怎么解决的?第二个代码块中加载数据函数loadDataSet()中运行时会出现如下错误,ValueError: could not convert string to float: ‘0.036098 0.155096’,我打印了第一个curLine = [‘0.036098 0.155096’]

                      • avatar Jack Cui Admin 这家伙可能用了岛国的代理 谷歌浏览器  Android 7.1.1 MIX 2 Build/NMF26X 日本 东京都渋谷区GMO互联网公司

                        @Drew 提示已经很明显了,你解析有问题

                          • avatar Drew 来自天朝的朋友 火狐浏览器 Windows 10 湖南省长沙市 电信 2

                            @Jack Cui 恩恩,我知道,不过我是拿博主提供的代码编译的,然后报错了,所以才想问下

                        • avatar Justin 来自天朝的朋友 谷歌浏览器 Mac OS X 10_13_4 北京市 电信 1

                          博主,机器学习相关的文章还会继续更新吗?还是开始转神经网络了?

                            • avatar Jack Cui Admin 这家伙可能用了岛国的代理 谷歌浏览器  Android 7.1.1 MIX 2 Build/NMF26X 日本 东京都渋谷区GMO互联网公司

                              @Justin 我想等有时间了,把后续的继续补全。

                                • avatar Justin 来自天朝的朋友 谷歌浏览器 Mac OS X 10_13_4 北京市 电信 1

                                  @Jack Cui 嗯,期待更新

                              • avatar aspire 来自天朝的朋友 火狐浏览器 Windows 7 吉林省长春市 吉林大学 1

                                写得真是太棒了,博主去哪个公司实习了?

                                  • avatar Jack Cui Admin 来自天朝的朋友 火狐浏览器 Windows 7 北京市 百度网讯科技联通节点

                                    @aspire 呃,百度搬砖中。

                                  • avatar 小张 来自天朝的朋友 谷歌浏览器 Windows 7 上海市 鹏博士长城宽带 2

                                    请问 先生现在是什么工作啊? 小弟开年到现在一直在学习机器学习方面,只到最近把你这机器学习10篇文章全部学会了,感觉非常实在的东西,很感谢你,比起之前看的视频课程有用多了,想问一下如果在机器学习项目经验方面,就会这些的话,就业时候公司会要我吗?算法和数学基础都有的,另外后面楼主还提供些实际工作中的项目让我们参考参考吗?

                                      • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 联通

                                        @小张 我现在是视觉算法岗,感谢支持。机器学习想积累经验,可以试试参加各项机器学习相关的比赛,例如阿里天池,对于学生来讲,这是一个不错增加实战经验的方法。

                                      • avatar 小张 来自天朝的朋友 谷歌浏览器 Windows 7 上海市 鹏博士长城宽带 2

                                        请问下一篇文章 回归树 什么时候出来啊?

                                          • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 百度网讯科技联通节点

                                            @小张 近期必更新,给自己立flag。

                                              • avatar Quester 来自天朝的朋友 谷歌浏览器 Windows 10 四川省成都市 四川大学 0

                                                @Jack Cui 博主你立的flag啥时候能实现啊! 找不到像你写的这么好的很难受

                                                  • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器  Android 8.0.0 MIX 2 Build/OPR1.170623.027 辽宁省沈阳市 联通GSM/WCDMA/LTE共用出口

                                                    @Quester 惭愧惭愧啊,最近太忙了。这周末我先写着,争取八月前出一篇先。

                                                    • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 百度网讯科技联通节点

                                                      @Quester 来打自己脸了,文章写一半没写完,准备秋招呢,后面找好了,应该能出了。

                                                        • avatar Jack curry 来自天朝的朋友 谷歌浏览器 Windows 10 山西省 电信 4

                                                          @Jack Cui 博主牛逼看来百度稳稳的。

                                                  • avatar 小张 来自天朝的朋友 谷歌浏览器 Windows 7 上海市 鹏博士长城宽带 2

                                                    def regErr(dataSet):
                                                    #目标变量的方差
                                                    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
                                                    请问:最后一列的方差,不应该是 除以 行数吗,为什么是乘, 还是说这个方差的计算是 所有列各自的方差的和 的意思啊
                                                    还有就是最小二次回归,整体逻辑还不是很懂;为什么不停的吧数据分成2类呢? 又不是分类, 还有就是新数据怎么预测呢?求楼主悉心指导呀

                                                      • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 百度网讯科技联通节点

                                                        @小张 求的是总方差,单独运行np.var看看你就知道了。就是去拟合一个函数,让它最接近真实数据。新数据预测,学了这些章了应该知道啊,就是给模型送入数据就好了,按照训练的数据格式。

                                                      • avatar tilo 来自天朝的朋友 谷歌浏览器 Windows 7 北京市 移动 2

                                                        博主,一直关注你的文章~
                                                        问个问题望解答。这篇文章里最开始是进行了特征切分,那么在ID3里面,我如果也做了特征切分,然后通过计算信息增益的方式选择特征节点,那么理论上也可以处理连续值?因为你之前决策树文章里面,用sklearn生成决策树的时候,我看里面条件是<=0.5之类的,应该也是做了特征的切分吧,如果sklearn用ID3来算,应该也是会做这个操作的咯

                                                          • avatar tilo 来自天朝的朋友 谷歌浏览器 Windows 7 北京市 移动 2

                                                            @tilo 我好像傻了,,连续变量都已经划分出特征和特征值了,我为啥还想着要用信息增益,,