Tensorflow实战(一):打响深度学习的第一枪 – 手写数字识别(Tensorboard可视化)

2018年1月23日22:34:49 53 41,102 °C
摘要

为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试。同时使用Tensorboard对训练过程进行可视化,算是打响学习Tensorflow的第一枪啦。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

一、前言

为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试。同时使用Tensorboard对训练过程进行可视化,算是打响学习Tensorflow的第一枪啦。

看本文之前,希望你已经具备机器学习和深度学习基础。

机器学习基础可以看我的系列博文:

https://cuijiahua.com/blog/ml/

深度学习基础可以看吴恩达老师的公开课:

http://mooc.study.163.com/smartSpec/detail/1001319001.htm

二、MNIST数据集简介

当我们学习新的编程语言时,通常第一个程序就是打印输出著名的“Hello World!”。在深度学习中,MNIST数据集就相当于Hello World。

MNIST是一个简单的计算机视觉数据集,它包含手写数字的图像集:

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

 

数据集:

  • train-images-idx3-ubyte 训练数据图像 (60,000)
  • train-labels-idx1-ubyte 训练数据label
  • t10k-images-idx3-ubyte 测试数据图像 (10,000)
  • t10k-labels-idx1-ubyte 测试数据label

每张图像是28 * 28像素:

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

我们的任务是使用上面数据训练一个可以准确识别手写数字的神经网络模型,并使用Tensorflow对训练过程各个参数的变化进行可视化。

三、Tensorboard简介

本文要使用到Tensorboard,先让我们看看它究竟是用来干什么的。

当使用Tensorflow训练大量深层的神经网络时,我们希望去跟踪神经网络的整个训练过程中的信息,比如迭代的过程中每一层参数是如何变化与分布的,比如每次循环参数更新后模型在测试集与训练集上的准确率是如何的,比如损失值的变化情况,等等。如果能在训练的过程中将一些信息加以记录并可视化得表现出来,是不是对我们探索模型有更深的帮助与理解呢?

Tensorflow官方推出了可视化工具Tensorboard,可以帮助我们实现以上功能,它可以将模型训练过程中的各种数据汇总起来存在自定义的路径与日志文件中,然后在指定的web端可视化地展现这些信息。

1、Tensorboard的数据形式:

Tensorboard可以记录与展示以下数据形式:

(1)标量Scalars

(2)图片Images

(3)音频Audio

(4)计算图Graph

(5)数据分布Distribution

(6)直方图Histograms

(7)嵌入向量Embeddings

2、Tensorboard的可视化过程:

(1)首先肯定是先建立一个graph,你想从这个graph中获取某些数据的信息

(2)确定要在graph中的哪些节点放置summary operations以记录信息

使用tf.summary.scalar记录标量

使用tf.summary.histogram记录数据的直方图

使用tf.summary.distribution记录数据的分布图

使用tf.summary.image记录图像数据

.....等等

(3)operations并不会去真的执行计算,除非你告诉他们需要去run,或者它被其他的需要run的operation所依赖。而我们上一步创建的这些summary operations其实并不被其他节点依赖,因此,我们需要特地去运行所有的summary节点。但是呢,一份程序下来可能有超多这样的summary 节点,要手动一个一个去启动自然是及其繁琐的,因此我们可以使用tf.summary.merge_all去将所有summary节点合并成一个节点,只要运行这个节点,就能产生所有我们之前设置的summary data。

(4)使用tf.summary.FileWriter将运行后输出的数据都保存到本地磁盘中

(5)运行整个程序,并在命令行输入运行tensorboard的指令,之后打开web端可查看可视化的结果

考虑多类情况。非onehot,标签是类似0 1 2 3...n这样。而onehot标签则是顾名思义,一个长度为n的数组,只有一个元素是1.0,其他元素是0.0。例如在n为4的情况下,标签2对应的onehot标签就是 0.0 0.0 1.0 0.0使用onehot的直接原因是现在多分类cnn网络的输出通常是softmax层,而它的输出是一个概率分布,从而要求输入的标签也以概率分布的形式出现,进而算交叉熵之类。

四、手写数字识别

现在,我们使用最基础的手写数字识别。

1、准备数据集、定义超参数等准备工作

(1)首先是导入需要使用的包:

(2)定义超参数

如果你问,这个超参数为啥要这样设定,如何选择最优的超参数?这个问题此处先不讨论,超参数的选择在机器学习建模中最常用的方法就是“交叉验证法”。而现在假设我们已经获得了最优的超参数,设置学利率为0.001,dropout的保留节点比例为0.9,最大循环次数为1000。

另外,还要设置两个路径,第一个是数据下载下来存放的地方,一个是summary输出保存的地方。

(3)GPU设置

这里使用GPU进行训练,如果使用cpu,可以略过此步。如果使用GPU建议进行设置。

上述代码的意思是使用GPU设备0,最多给GPU分配总共内存的百分之33,并且允许GPU按需申请内存。也就是说,假设一个程序使用一块GPU内存百分之10就够了,如果我们没有指定allow_growth=True,那么程序会直接占用GPU内存的百分之33,因为这个是我们给它分配的。如果我们连0.33,也就是GPU内存的百分之33都没有指定,那么程序会直接占用整个GPU设备0。虽然占用这么多没有用,但是我就占着,属于“占着茅坑不拉屎”。所以,为了充分利用资源,特别是一帮人使用一个服务器的时候,指定下这些参数就很有必要了。

(4)下载数据下载数据是直接调用了tensorflow提供的函数read_data_sets,输入两个参数,第一个是下载到数据存储的路径,第二个one_hot表示是否要将类别标签进行独热编码。它首先回去找制定目录下有没有这个数据文件,没有的话才去下载,有的话就直接读取。所以第一次执行这个命令,速度会比较慢,因为没有数据集,需要进行下载。

2、数据处理

(1)创建tensorflow默认会话:

为了使设置的GPU参数生效,我们需要在创建会话的时候传入这个config参数。

(2)创建输入数据的占位符,分别创建特征数据x,标签数据y_

在tf.placeholder()函数中传入了3个参数,第一个是定义数据类型为float32;第二个是数据的大小,特征数据是大小784的向量,标签数据是大小为10的向量,None表示不定死大小,到时候可以传入任何数量的样本;第3个参数是这个占位符的名称。

mnist下载好的数据集就是很多个1*784的向量,就是已经对28*28的图片进行了向量化处理。

(3)使用tf.summary.image保存图像信息

前面也说了,特征数据其实就是图像的像素数据拉升成一个1*784的向量,现在如果想在tensorboard上还原出输入的特征数据对应的图片,就需要将拉升的向量转变成28 * 28 * 1的原始像素了,于是可以用tf.reshape()直接重新调整特征数据的维度:

将输入的数据转换成[28 * 28 * 1]的shape,存储成另一个tensor,命名为image_shaped_input。
为了能使图片在tensorbord上展示出来,使用tf.summary.image将图片数据汇总给tensorbord。
tf.summary.image()中传入的第一个参数是命名,第二个是图片数据,第三个是最多展示的张数,此处为10张。

3、初始化参数并保存参数信息到summary

(1)初始化参数w和b

在构建神经网络模型中,每一层中都需要去初始化参数w,b,为了使代码简介美观,最好将初始化参数的过程封装成方法function。 创建初始化权重w的方法,生成大小等于传入的shape参数,标准差为0.1,遵循正态分布的随机数,并且将它转换成tensorflow中的variable返回。

创建初始换偏执项b的方法,生成大小为传入参数shape的常数0.1,并将其转换成tensorflow的variable并返回。

(2)记录训练过程参数变化

我们知道,在训练的过程在参数是不断地在改变和优化的,我们往往想知道每次迭代后参数都做了哪些变化,可以将参数的信息展现在tenorbord上,因此我们专门写一个方法来收录每次的参数信息。

4、构建神经网络层

(1)创建第一层隐藏层

创建一个构建隐藏层的方法,输入的参数有:

  • input_tensor:特征数据
  • input_dim:输入数据的维度大小
  • output_dim:输出数据的维度大小(=隐层神经元个数)
  • layer_name:命名空间
  • act=tf.nn.relu:激活函数(默认是relu)

调用隐层创建函数创建一个隐藏层:输入的维度是特征的维度784,隐藏层的神经元个数是500,也就是输出的维度。

(2)创建一个dropout层

随机关闭掉hidden1的一些神经元,并记录keep_prob,减少保存参数,防止过拟合。

(3)创建一个输出层

输入的维度是上一层的输出:500,输出的维度是分类的类别种类:10,激活函数设置为全等映射identity。(暂且先别使用softmax,会放在之后的损失函数中一起计算)

5、创造损失函数

使用tf.nn.softmax_cross_entropy_with_logits来计算softmax并计算交叉熵损失,并且求均值作为最终的损失值。

6、训练

首先,使用AdamOptimizer优化器训练模型,最小化交叉熵损失

然后,计算准确率,并用tf.summary.scalar记录准确率

7、所有变量初始化

将所有的summaries合并,并且将它们写到之前定义的log_dir路径

8、送入数据集

feed_dict用于获取数据,如果是train==true,也就是进行训练的时候,就从mnist.train中获取一个batch大小为100样本,并且设置dropout值为0.9。如果是不是train==false,则获取minist.test的测试数据,并且设置dropout为1,即保留所有神经元开启。

同时,每隔10步,进行一次测试,并打印一次测试数据集的准确率,然后将测试数据集的各种summary信息写进日志中。 其余的时候,都是在进行训练将训练集的summary信息并写到日志中。

整体程序浏览:

9、运行程序

运行整个程序,在程序中定义的summary node就会将要记录的信息全部保存在指定的logdir路径中了,训练的记录会存一份文件,测试的记录会存一份文件。

运行程序,如果使用GPU进行训练,等待几分钟应该就OK了。

运行效果如下图所示:

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

可以看到,随着迭代次数的增加,准确率也在提高。

与此同时,在运行的时候,我们就可以打开Tensorboard查看训练状态。使用如下指令:

上述指令logdir指定了存储log的路径,在程序里设置的路径。port指定了查看端口,此处设为8008。

运行上述指令后,我们就可以在浏览器查看Tensorboard了。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

如果是远程登陆,可以在地址栏输入服务器IP地址加端口号,例如:1.10.12.13:8008,如果是本地登陆,在地址栏输入localhost:8008即可。

于是我们可以从这个web端看到所有程序中定义的可视化信息了。

五、Tensorboard Web端解释

看到最上面橙色一栏的菜单,分别有7个栏目,都一一对应着我们程序中定义信息的类型。

1、SCALARS 

展示的是标量的信息,我程序中用tf.summary.scalars()定义的信息都会在这个窗口。 回顾本文程序中定义的标量有:准确率accuracy,dropout的保留率,隐藏层中的参数信息,已经交叉熵损失。这些都在SCLARS窗口下显示出来了。

点开accuracy,红线表示test集的结果,蓝线表示train集的结果,可以看到随着循环次数的增加,两者的准确度也在增加,直达1000次时会到达0.967左右。

蓝线有大幅度震动是因为batch的设置问题,在每个batch里,训练效果好,但是换了一个新batch准确率就会下降,但是整体趋势还是增加的。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

点开dropout,红线表示的测试集上的保留率始终是1,蓝线始终是0.9。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

点开layer1,查看第一个隐藏层的参数信息。 

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

以上,第一排是偏执项b的信息,随着迭代的加深,最大值越来越大,最小值越来越小,与此同时,也伴随着方差越来越大,这样的情况是我们愿意看到的,神经元之间的参数差异越来越大。因为理想的情况下每个神经元都应该去关注不同的特征,所以他们的参数也应有所不同。 

第二排是权值w的信息,同理,最大值,最小值,标准差也都有与b相同的趋势,神经元之间的差异越来越明显。w的均值初始化的时候是0,随着迭代其绝对值也越来越大。

点开layer2,查看第二层的参数信息。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

点开loss,可见损失的降低趋势。 

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

2、IMAGES 

在程序中我们设置了一处保存了图像信息,就是在转变了输入特征的shape,然后记录到了image中,于是在tensorflow中就会还原出原始的图片了: 

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

整个窗口总共展现了10张图片(根据代码中的参数10)

3、AUDIO 

这里展示的是声音的信息,但本案例中没有涉及到声音的。

4、GRAPHS 

这里展示的是整个训练过程的计算图graph,从中我们可以清洗地看到整个程序的逻辑与过程。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

单击某个节点,可以查看属性,输入,输出等信息。

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

单击节点上的“+”字样,可以看到该节点的内部信息。 

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

5、DISTRIBUTIONS 

这里查看的是神经元输出的分布,有激活函数之前的分布,激活函数之后的分布等。 

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

6、HISTOGRAMS 

也可以看以上数据的直方图:

Tensorflow实战(一):打响深度学习的第一枪 - 手写数字识别(Tensorboard可视化)

7、EMBEDDINGS 

展示的是嵌入向量的可视化效果,本案例中没有使用这个功能。

六、总结

  • 本文主要使用手写数字识别的小案例来讲解了如何初步使用Tensorflow的可视化工具Tensorboard。
  • 如有问题,请留言。如有错误,还望指正,谢谢!

 

PS: 如果觉得本篇本章对您有所帮助,欢迎关注、评论、赞!

本文出现的所有代码和数据集,均可在我的github上下载,欢迎Follow、Star:点击查看

 

参考资料:

  • http://blog.topspeedsnail.com/archives/10377
  • http://blog.csdn.net/sinat_33761963/article/details/62433234
weinxin
微信公众号
分享技术,乐享生活:微信公众号搜索「JackCui-AI」关注一个在互联网摸爬滚打的潜行者。
Jack Cui

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

目前评论:53   其中:访客  31   博主  22

    • avatar 田田 来自天朝的朋友 火狐浏览器5.0 Windows 10 四川省 移动 0

      tensorboad打开时乱码,是win10上做的,是怎么回事啊,能解答下吗,万分感谢
      (tensorflow) C:\Users\田田\.spyder-py3>tensorboard –logdir=./MNIST_LOG –port=8080
      TensorBoard 1.11.0 at http://DESKTOP-5HGQB94:8080 (Press CTRL+C to quit)
      W1021 13:26:55.740218 Thread-1 application.py:300] path /[[_traceDataUrl]] not found, sending 404

      W1021 13:26:57.697982 3184 application.py:300] path /[[getCompareSeqImageSrc(seqfeat.name, seqNumber)]] not found, sending 404

        • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 7 辽宁省沈阳市 东北大学三舍南(研究生)

          @田田 看你用的是火狐浏览器。换下浏览器吧,chrome的。具体参考:
          https://blog.csdn.net/u010899985/article/details/77948692

        • avatar celia 来自天朝的朋友 谷歌浏览器 Windows 8.1 江苏省 联通 3

          cannot import name ‘_message’ from ‘google.protobuf.pyext’ (c:\users\user-1\anaconda3\lib\site-packages\google\protobuf\pyext\__init__.py)
          我在用Tensorboard生成网址时总是出现这种情况,尝试了好多方法,这是怎么回事啊

            • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

              @celia protobuf版本问题,这个很难受,可以编译源码解决,你是用anaconda吗?

                • avatar celia 来自天朝的朋友 谷歌浏览器  MI 8 Build/OPM1.171019.026 江苏省南京市 联通 3

                  @Jack Cui 对我用的就是anaconda版本,这个能解决。。😭网上就没我这种错误。。。

                    • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

                      @celia 这个错误我见过,你用google搜,有的。我解决过,就是一个protobuf的依赖问题,你新建个虚拟环境测试下。或者conda search下版本看看,找到对应的版本。

                        • avatar celia 来自天朝的朋友 QQ浏览器  MI 8 Build/OPM1.171019.026 江苏省南京市 联通 3

                          @Jack Cui 好的我下午试试,直接搜protobuf的依赖问题吗

                          • avatar celia 来自天朝的朋友 QQ浏览器  MI 8 Build/OPM1.171019.026 江苏省南京市 联通 3

                            @Jack Cui 对,预算他没说,我估计2万左右吧

                            • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

                              @celia 2万买啥服务器啊,直接买个好点的电脑吧,上个titan xp.

                            • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

                              @celia 带着报错搜。

                                • avatar celia 来自天朝的朋友 QQ浏览器  MI 8 Build/OPM1.171019.026 江苏省南京市 联通 3

                                  @Jack Cui 大佬能问下,我们老师让我们看下电脑配置,找一台能跑深度的电脑,有啥推荐的吗

                                  • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

                                    @celia 深度学习服务器?预算多少啊?

                                    • avatar celia 来自天朝的朋友 谷歌浏览器  MI 8 Build/OPM1.171019.026 江苏省南京市 联通 3

                                      @Jack Cui 电脑有啥推荐啊

                                      • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_3 北京市 百度网讯科技联通节点

                                        @celia 电脑无所谓,现在差不多的就行。主要是显卡,显卡买好点的就ok了。

                                • avatar Steven 来自天朝的朋友 谷歌浏览器 Windows 7 山西省太原市 移动 1

                                  问下博主,下面的问题该怎么解决啊?博主用的是什么运行环境啊?python、tensorflow版本分别是多少?
                                  运行出错:
                                  read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
                                  Instructions for updating:
                                  Please use alternatives such as official/mnist/dataset.py from tensorflow/models.

                                    • avatar Jack Cui Admin 来自天朝的朋友 Safari浏览器 Mac OS X 10_14_4 北京市 百度网讯科技联通节点

                                      @Steven 你这个只是警告吧?

                                        • avatar Steven 来自天朝的朋友 谷歌浏览器 Windows 10 山西省晋城市 移动 1

                                          @Jack Cui 不是警告,就是运行错误。我是直接用python3.6,用pip安装的tensorflow。

                                            • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市丰台区 联通

                                              @Steven 哦哦,那这个没关系,应该就是换接口了,你找下你的版本的对应接口就行。我的版本找不到了,换电脑了,这是好久前的文章了。

                                        • avatar 蔓越莓曲奇 来自天朝的朋友 谷歌浏览器 Windows 10 浙江省温州市 移动 1

                                          ModuleNotFoundError: No module named ‘tensorflow.examples.tutorials’
                                          这种报错是为什么TensorFlow版本不对么

                                            • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Mac OS X 10_14_4 北京市 百度网讯科技联通节点

                                              @蔓越莓曲奇 这个问题可能是:环境没有配置好,看看site-packages有这个包没有,有的话,还import失败,说明环境有问题,可以用sys.append方法加下环境试试。
                                              在这就是版本不对,可以升级下tf试一试。

                                            • avatar cello Singapore 谷歌浏览器 Windows 10 新加坡 0

                                              谢谢dalao 想请教一下怎么用这个训练好的模型测试自己的图片呢?

                                                • avatar Jack Cui Admin 来自天朝的朋友 谷歌浏览器 Windows 10 北京市昌平区 联通

                                                  @cello 需要根据自己的图片,对网络进行调整。

                                                • avatar 谁锟斤拷锟洁春-锟斤拷锟斤拷茫 来自天朝的朋友 Safari浏览器 Mac OS X Lion 10_15_7 北京市 联通 0

                                                  TensorFlow哪个版本?

                                                  • avatar bbyte 来自天朝的朋友 Safari浏览器 Mac OS X Lion 10_15_7 安徽省合肥市 电信 1

                                                    如果使用 M1 MacBook 来训练的话,如何修改相关代码呢?

                                                    • avatar Sakura* 来自天朝的朋友 谷歌浏览器 Windows 10 黑龙江省齐齐哈尔市 电信 1

                                                      谷歌图像里面有直线怎么办,accuracy-1等图像都有,求解

                                                        • avatar Sakura* 来自天朝的朋友 谷歌浏览器  ELZ-AN00 Build/HONORELZ-AN00 黑龙江省齐齐哈尔市 电信 1

                                                          @Sakura* 我也是这个问题

                                                        • avatar Galaxy 来自天朝的朋友 谷歌浏览器 Windows 10 天津市 天津大学 0

                                                          File “D:\Galaxy\Documents\anaconda3\envs\tf\lib\site-packages\tensorflow\python\client\session.py”, line 262, in for_fetch
                                                          raise TypeError(f’Argument fetch = {fetch} has invalid type ‘
                                                          TypeError: Argument fetch = None has invalid type “NoneType”. Cannot be None
                                                          在运行 summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))之后报的错,求大佬指点

                                                          • avatar cabbage 来自天朝的朋友 谷歌浏览器 Windows 10 广东省深圳市 联通 0

                                                            依赖都是哪些,是什么版本的