机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

2017年11月8日18:42:19 103 39,989 °C
摘要

本文从Logistic回归的原理开始讲起,补充了书上省略的数学推导。本文可能会略显枯燥,理论居多,Sklearn实战内容会放在下一篇文章。自己慢慢推导完公式,还是蛮开心的一件事。

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

一、前言

本文从Logistic回归的原理开始讲起,补充了书上省略的数学推导。本文可能会略显枯燥,理论居多,Sklearn实战内容会放在下一篇文章。自己慢慢推导完公式,还是蛮开心的一件事。

二、Logistic回归与梯度上升算法

Logistic回归是众多分类算法中的一员。通常,Logistic回归用于二分类问题,例如预测明天是否会下雨。当然它也可以用于多分类问题,不过为了简单起见,本文暂先讨论二分类问题。首先,让我们来了解一下,什么是Logistic回归。

1、Logistic回归

假设现在有一些数据点,我们利用一条直线对这些点进行拟合(该线称为最佳拟合直线),这个拟合过程就称作为回归,如下图所示:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

Logistic回归是分类方法,它利用的是Sigmoid函数阈值在[0,1]这个特性。Logistic回归进行分类的主要思想是:根据现有数据对分类边界线建立回归公式,以此进行分类。其实,Logistic本质上是一个基于条件概率的判别模型(Discriminative Model)。

所以要想了解Logistic回归,我们必须先看一看Sigmoid函数 ,我们也可以称它为Logistic函数。它的公式如下:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

整合成一个公式,就变成了如下公式:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

下面这张图片,为我们展示了Sigmoid函数的样子。

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

z是一个矩阵,θ是参数列向量(要求解的),x是样本列向量(给定的数据集)。θ^T表示θ的转置。g(z)函数实现了任意实数到[0,1]的映射,这样我们的数据集([x0,x1,...,xn]),不管是大于1或者小于0,都可以映射到[0,1]区间进行分类。hθ(x)给出了输出为1的概率。比如当hθ(x)=0.7,那么说明有70%的概率输出为1。输出为0的概率是输出为1的补集,也就是30%。

如果我们有合适的参数列向量θ([θ0,θ1,...θn]^T),以及样本列向量x([x0,x1,...,xn]),那么我们对样本x分类就可以通过上述公式计算出一个概率,如果这个概率大于0.5,我们就可以说样本是正样本,否则样本是负样本。

举个例子,对于"垃圾邮件判别问题",对于给定的邮件(样本),我们定义非垃圾邮件为正类,垃圾邮件为负类。我们通过计算出的概率值即可判定邮件是否是垃圾邮件。

那么问题来了!如何得到合适的参数向量θ?

根据sigmoid函数的特性,我们可以做出如下的假设:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

式即为在已知样本x和参数θ的情况下,样本x属性正样本(y=1)和负样本(y=0)的条件概率。理想状态下,根据上述公式,求出各个点的概率均为1,也就是完全分类都正确。但是考虑到实际情况,样本点的概率越接近于1,其分类效果越好。比如一个样本属于正样本的概率为0.51,那么我们就可以说明这个样本属于正样本。另一个样本属于正样本的概率为0.99,那么我们也可以说明这个样本属于正样本。但是显然,第二个样本概率更高,更具说服力。我们可以把上述两个概率公式合二为一:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

合并出来的Loss,我们称之为损失函数(Loss Function)。当y等于1时,(1-y)项(第二项)为0;当y等于0时,y项(第一项)为0。为s了简化问题,我们对整个表达式求对数,(将指数问题对数化是处理数学问题常见的方法):

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

这个损失函数,是对于一个样本而言的。给定一个样本,我们就可以通过这个损失函数求出,样本所属类别的概率,而这个概率越大越好,所以也就是求解这个损失函数的最大值。既然概率出来了,那么最大似然估计也该出场了。假定样本与样本之间相互独立,那么整个样本集生成的概率即为所有样本生成概率的乘积,便可得到如下公式:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

其中,m为样本的总数,y(i)表示第i个样本的类别,x(i)表示第i个样本,需要注意的是θ是多维向量,x(i)也是多维向量。

综上所述,满足J(θ)的最大的θ值即是我们需要求解的模型。

怎么求解使J(θ)最大的θ值呢?因为是求最大值,所以我们需要使用梯度上升算法。如果面对的问题是求解使J(θ)最小的θ值,那么我们就需要使用梯度下降算法。面对我们这个问题,如果使J(θ) := -J(θ),那么问题就从求极大值转换成求极小值了,使用的算法就从梯度上升算法变成了梯度下降算法,它们的思想都是相同的,学会其一,就也会了另一个。本文使用梯度上升算法进行求解。

2、梯度上升算法

说了半天,梯度上升算法又是啥?J(θ)太复杂,我们先看个简单的求极大值的例子。一个看了就会想到高中生活的函数:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

来吧,做高中题。这个函数的极值怎么求?显然这个函数开口向下,存在极大值,它的函数图像为:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

求极值,先求函数的导数:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

令导数为0,可求出x=2即取得函数f(x)的极大值。极大值等于f(2)=4

但是真实环境中的函数不会像上面这么简单,就算求出了函数的导数,也很难精确计算出函数的极值。此时我们就可以用迭代的方法来做。就像爬坡一样,一点一点逼近极值。这种寻找最佳拟合参数的方法,就是最优化算法。爬坡这个动作用数学公式表达即为:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

其中,α为步长,也就是学习速率,控制更新的幅度。效果如下图所示:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

比如从(0,0)开始,迭代路径就是1->2->3->4->...->n,直到求出的x为函数极大值的近似值,停止迭代。我们可以编写Python3代码,来实现这一过程:

代码运行结果如下:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

结果很显然,已经非常接近我们的真实极值2了。这一过程,就是梯度上升算法。那么同理,J(θ)这个函数的极值,也可以这么求解。公式可以这么写:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

由上小节可知J(θ)为:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

sigmoid函数为:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

那么,现在我只要求出J(θ)的偏导,就可以利用梯度上升算法,求解J(θ)的极大值了。

那么现在开始求解J(θ)对θ的偏导,求解如下:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

其中:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

再由:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

可得:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

接下来,就剩下第三部分:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

综上所述:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

因此,梯度上升迭代公式为:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

知道了,梯度上升迭代公式,我们就可以自己编写代码,计算最佳拟合参数了。

三、Python3实战

1、数据准备

数据集已经为大家准备好,下载地址:

这就是一个简单的数据集,没什么实际意义。让我们先从这个简单的数据集开始学习。先看下数据集有哪些数据:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

这个数据有两维特征,因此可以将数据在一个二维平面上展示出来。我们可以将第一列数据(X1)看作x轴上的值,第二列数据(X2)看作y轴上的值。而最后一列数据即为分类标签。根据标签的不同,对这些点进行分类。

那么,先让我们编写代码,看下数据集的分布情况:

运行结果如下:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

从上图可以看出数据的分布情况。假设Sigmoid函数的输入记为z,那么z=w0x0 + w1x1 + w2x2,即可将数据分割开。其中,x0为全是1的向量,x1为数据集的第一列数据,x2为数据集的第二列数据。另z=0,则0=w0 + w1x1 + w2x2。横坐标为x1,纵坐标为x2。这个方程未知的参数为w0,w1,w2,也就是我们需要求的回归系数(最优参数)。

2、训练算法

在编写代码之前,让我们回顾下梯度上升迭代公式:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

将上述公式矢量化:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

根据矢量化的公式,编写代码如下:

运行结果如图所示:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

可以看出,我们已经求解出回归系数[w0,w1,w2]。

通过求解出的参数,我们就可以确定不同类别数据之间的分隔线,画出决策边界。

3、绘制决策边界

我们已经解出了一组回归系数,它确定了不同类别数据之间的分隔线。现在开始绘制这个分隔线,编写代码如下:

运行结果如下:

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

这个分类结果相当不错,从上图可以看出,只分错了几个点而已。但是,尽管例子简单切数据集很小,但是这个方法却需要大量的计算(300次乘法)。因此下篇文章将对改算法稍作改进,从而减少计算量,使其可以应用于大数据集上。

四、总结

Logistic回归的一般过程:

  • 收集数据:采用任意方法收集数据。
  • 准备数据:由于需要进行距离计算,因此要求数据类型为数值型。另外,结构化数据格式则最佳。
  • 分析数据:采用任意方法对数据进行分析。
  • 训练算法:大部分时间将用于训练,训练的目的是为了找到最佳的分类回归系数。
  • 测试算法:一旦训练步骤完成,分类将会很快。
  • 使用算法:首先,我们需要输入一些数据,并将其转换成对应的结构化数值;接着,基于训练好的回归系数,就可以对这些数值进行简单的回归计算,判定它们属于哪个类别;在这之后,我们就可以在输出的类别上做一些其他分析工作。

其他:

  • Logistic回归的目的是寻找一个非线性函数Sigmoid的最佳拟合参数,求解过程可以由最优化算法完成。
  • 本文讲述了Logistic回归原理以及数学推导过程。
  • 下篇文章将讲解Logistic回归的改进以及Sklearn实战内容。
  • 如有问题,请留言。如有错误,还望指正,谢谢!

PS: 如果觉得本篇本章对您有所帮助,欢迎关注、评论、赞!

本文出现的所有代码和数据集,均可在我的github上下载,欢迎Follow、Star:github.com/Jack-Cherish

参考文献:

 

weinxin
微信公众号
分享技术,乐享生活:微信公众号搜索「JackCui-AI」关注一个在互联网摸爬滚打的潜行者。
质胜文则野,文胜质则史,文质彬彬,然后君子。--- 孔子
Jack Cui

发表评论取消回复

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

目前评论:103   其中:访客  66   博主  37

    • avatar keep inner peace 来自天朝的朋友 谷歌浏览器 Windows 7 大连理工大学 电力电子研究所 4

      博主,请问我在绘制决策边界的时候,
      fig=plt.figure()
      ax=fig.add_subplot(111)
      ax.scatter(xcoord1,ycoord1,s=30,c=”green”,marker=”s”)
      ax.scatter(xcoord2,ycoord2,s=30,c=”red”)
      这是绘图的部分代码,为什么最后绘制出来的图形所有点都是红色,而且点的形状都是圆形,也就是说ax.scatter(xcoord1,ycoord1,s=30,c=”green”,marker=”s”)这个函数设置的参数没有起到作用,这是哪里出错误了吗?

        • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 7 辽宁省沈阳市 东北大学四舍(女生)

          @keep inner peace 是的,应该是出错了,两类数据区分开。

            • avatar keep inner peace 来自天朝的朋友 谷歌浏览器 Windows 7 大连理工大学 电力电子研究所 4

              @Jack Cui 谢谢博主指点,调试好了

          • avatar xuezha 来自天朝的朋友 谷歌浏览器 Windows 10 天津市 联通 2

            请问xcode ycode是什么 xcord1 = []; ycord1 = [] #正样本
            xcord2 = []; ycord2 = []

              • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器  Android 7.1.1 MIX 2 Build/NMF26X 辽宁省沈阳市 联通

                @xuezha 存放数据和标签的列表啊

              • avatar 习惯性一血 谷歌浏览器 Windows 7 亚太地区 0

                第二部分中,x_alpha改变符号后,结果就没法收敛了

                • avatar 刘思宇 来自天朝的朋友 谷歌浏览器 Windows 10 江苏省无锡市 移动 1

                  你好,很喜欢你这篇文章,只是里面有很多问题想不明白,首先梯度上升算法求最优参数Weight我能理解,但是y = (-weights[0] – weights[1] * x) / weights[2]是根据z=w0x0+w1x1+w2x2 其中w0为1来的我就不理解了,z=w0x0+w1x1+w2x2 这个式子是怎么得来的啊?weights[0]weights[1]weights[2]与w0w1w2是什么关系?w0*0=0为什么还让w0=1?想了两天了还是不明白。。。

                    • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 百度网讯科技联通节点

                      @刘思宇 看下书吧,这是矩阵运算,为了方便矩阵运算需要加个w0列。

                        • avatar 刘思宇 来自天朝的朋友 谷歌浏览器 Windows 10 江苏省无锡市 移动 1

                          @Jack Cui 好的,谢谢你

                          • avatar 刘思宇 来自天朝的朋友 谷歌浏览器 Windows 10 江苏省无锡市 移动 1

                            @Jack Cui 哈哈哈哈哈哈,我被自己蠢哭了,z=w0x0+w1x1+w2x2我把x0理解成了乘以0,x1理解成乘以1,难怪想不明白,再次感谢博主的提醒。

                              • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市 百度网讯科技联通节点

                                @刘思宇 没事,这个是网站编辑器的问题,不支持那种格式的输入。

                                  • avatar 刘思宇 来自天朝的朋友 谷歌浏览器 Windows 10 江苏省无锡市 移动 1

                                    @Jack Cui 嘿嘿,还是我自己不仔细

                            • avatar twpsuperman 来自天朝的朋友 QQ浏览器 Windows 7 重庆市 电信 4

                              for k in range(maxCycles):
                              请问博主有什么方法能够确定这里的迭代次数,迭代次数多大效果最佳呢?

                                • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 辽宁省沈阳市 联通

                                  @twpsuperman 这属于超参数,需要自己设定。可以通过交叉验证等方法找到合适的迭代次数。

                                • avatar 若游 来自天朝的朋友 搜狗浏览器 Windows 8.1 江苏省南京市 移动 4

                                  学长,ax.plot(x, y)处应该改为y.tranapose()吧,x为(60,1)矩阵,y为(1,60)矩阵

                                  • avatar 若游 来自天朝的朋友 搜狗浏览器 Windows 8.1 江苏省南京市 移动 4

                                    weight[1]属性是(1,1)矩阵,为什么可以和X相乘

                                    • avatar 若游 来自天朝的朋友 搜狗浏览器 Windows 8.1 江苏省南京市 移动 4

                                      一直报错ValueError: x and y must have same first dimension, but have shapes (60,) and (1, 60)

                                        • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 7 辽宁省沈阳市 东北大学三舍南(研究生)

                                          @若游 我没看懂你想问什么,你自己改了?

                                            • avatar 若游 来自天朝的朋友 搜狗浏览器 Windows 8.1 江苏省镇江市 电信 4

                                              @Jack Cui 按照你的代码放到pycharm中运行一直报错ValueError: x and y must have same first dimension, but have shapes (60,) and (1, 60)

                                                • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 7 辽宁省沈阳市 东北大学三舍南(研究生)

                                                  @若游 这个就是x,y的shape不同,你看下,是不是自己写错代码了。直接用git clone或者复制的代码。
                                                  如果还不行,reshape下试试。

                                                • avatar 若游 来自天朝的朋友 搜狗浏览器 Windows 8.1 江苏省镇江市 电信 4

                                                  @Jack Cui 学长,知道了,忘记加getA()了,所以x长度为60,而y是1个值

                                              • avatar Pinkman 来自天朝的朋友 谷歌浏览器 Windows 10 黑龙江省哈尔滨市 联通 3

                                                sigmoid函数里不是θ的转置*x吗?为什么代码是sigmoid(dataMatrix * weights) ?

                                                  • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Linux 辽宁省沈阳市 联通GSM/WCDMA/LTE共用出口

                                                    @Pinkman 这是这种表示方式,默认是列向量。转到能进行矩阵相乘的形式即可。

                                                  • avatar Justin 来自天朝的朋友 谷歌浏览器 Mac OS X 10_13_6 北京市 方正宽带 1

                                                    博主你好,我有个地方没看明白,就是使用梯度上升法迭代计算模型参数的时候,为什么直接迭代了max_cycles次,而没有监测目标函数J是否达到了最大值呢?

                                                      • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器  Android 8.0.0 MIX 2 Build/OPR1.170623.027 辽宁省沈阳市 联通GSM/WCDMA/LTE共用出口

                                                        @Justin 这就是个简单的demo,默认迭代max_cycles可以求出最优解,也算是相当于一个超参数。

                                                          • avatar Justin 来自天朝的朋友 谷歌浏览器 Mac OS X 10_13_6 北京市 电信通 1

                                                            @Jack Cui 嗯嗯,明白了,谢谢回复哈