MNIST 数据集输出手写数字识别 - SUSTC-XLAB/crops GitHub Wiki
前言
mnist 数据集是学习图像识别的入门,通过它可以带你了解图像的结构以及CNN网络的基本构成。mnist数据集包含 7 万张黑底白字手写数字图片,其中 55000 张为训练集,5000 张为验证集,10000 张为测试集。每张图片大小为 28*28 像素,图片中纯黑色像素值为 0,纯白色像素值为 1。数据集的标签是长度为 10 的一维数组,数组中每个元素索引号表示对应数字出现的概率。
图像处理相关函数介绍
在将 mnist 数据集作为输入喂入神经网络时,需先将数据集中每张图片变为长度784 一维数组,将该数组作为神经网络输入特征喂入神经网络。
- 使用 input_data 模块中的 read_data_sets()函数加载 mnist 数据集: from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(’./data/’,one_hot=True) - 使用 mnist.train.next_batch()函数将数据输入神经网络:mnist.train.next_batch()函数包含一个参数 BATCH_SIZE,表示随机从训 练集中抽取 BATCH_SIZE 个样本输入神经网络,并将样本的像素值和标签分别赋给 xs 和 ys。在本例中,BATCH_SIZE 设置为 200,表示一次将 200 个样本的像素值和标签分别赋值给 xs 和 ys,故 xs 的形状为(200,784),对应的 ys 的形状为(200,10)。
- 实现“Mnist 数据集手写数字识别”的常用函数:
①tf.get_collection(“”)函数表示从 collection 集合中取出全部变量生成一个列表。
②tf.add( )函数表示将参数列表中对应元素相加。
③tf.cast(x,dtype)函数表示将参数 x 转换为指定数据类型。
④tf.equal( )函数表示对比两个矩阵或者向量的元素。若对应元素相等,则返回 True;若对应元素不相等,则返回 False。
⑤tf.reduce_mean(x,axis)函数表示求取矩阵或张量指定维度的平均值。若不指定第二个参数,则在所有元素中取平均值;若指定第二个参数为 0,则在第一维元素上取平均值,即每一列求平均值;若指定第二个参数为 1,则在第二维元素上取平均值,即每一行求平均值。
⑥tf.argmax(x,axis)函数表示返回指定维度 axis 下,参数 x 中最大值索引号。 ⑦os.path.join()函数表示把参数字符串按照路径命名规则拼接。 ⑧字符串.split( )函数表示按照指定“拆分符”对字符串拆分,返回拆分列表。 ⑨tf.Graph( ).as_default( )函数表示将当前图设置成为默认图,并返回一个上下文管理器。该函数一般与 with 关键字搭配使用,应用于将已经定义好的神经网络在计算图中复现。
模型各模块介绍
神经网络模型包括前向传播过程、反向传播过程、反向传播过程中用到的正则化、指数衰减学习率、滑动平均方法的设置、以及测试模块等。
- 前向传播过程(forward.py) 前向传播过程完成神经网络的过程中,需要定义神经网络中的参数 w 和偏置 b,定义由输入到输出的网络结构。通过定义函数 get_weight()实现对参数 w 的设置,包括参数 w 的形状和是否正则化的标志。同样,通过定义函数 get_bias()实现对偏置 b 的设置。
- 反向传播过程 反向传播过程中,用 tf.placeholder(dtype, shape)函数实现训练样本 x 和样本标签 y_占位,函数参数 dtype 表示数据的类型,shape 表示数据的形状;y 表示定义的前向传播函数 forward;loss 表示定义的损失函数,一般为预测值与样本标签的交叉熵(或均方误差)与正则化损失之和;train_step 表示利用优化算法对模型参数进行优化,常用优化算法 GradientDescentOptimizer 、AdamOptimizer、MomentumOptimizer 算法,在上述代码中使用的 GradientDes centOptimizer 优化算法。接着实例化 saver 对象,其中利用 tf.initialize _all_variables().run()函数实例化所有参数模型,利用 sess.run( )函数实现模型的训练优化过程,并每间隔一定轮数保存一次模型。
- 正则化、指数衰减学习率、滑动平均方法的设置
①正则化项 regularization
当在前向传播过程中即 forward.py 文件中,设置正则化参数 regularization 为1 时,则表明在反向传播过程中优化模型参数时,需要在损失函数中加入正则化项。
②指数衰减学习率 在训练模型时,使用指数衰减学习率可以使模型在训练的前期快速收敛接近较优解,又可以保证模型在训练后期不会有太大波动。 ③滑动平均 在模型训练时引入滑动平均可以使模型在测试数据上表现的更加健壮。
代码实现
实现手写体 mnist 数据集的识别任务,共分为三个模块文件,分别是描述网络结构的前向传播过程文件(mnist_forward.py)、 描述网络参数优化方法的反向传播过程文件( mnist_backward.py )、验证模型准确率的测试过程文件(mnist_test.py)。
-
前向传播过程 在前向传播过程中,需要定义网络模型输入层个数、隐藏层节点数、输出层个数,定义网络参数 w、偏置 b,定义由输入到输出的神经网络架构。实现手写体 mnist 数据集的识别任务前向传播过程如下:
由上述代码可知,在前向传播过程中,规定网络输入结点为 784 个(代表每张输入图片的像素个数), 隐藏层节点 500 个,输出节点 10 个(表示输出为数字 0-9的十分类) 。由输入层到隐藏层的参数 w1 形状为[784,500],由隐藏层到输出层的参数 w2 形状为[500,10],参数满足截断正态分布,并使用正则化,将每个参数的正则化损失加到总损失中。由输入层到隐藏层的偏置 b1 形状为长度为 500的一维数组,由隐藏层到输出层的偏置 b2 形状为长度为 10 的一维数组,初始化值为全 0。前向传播结构第一层为输入 x 与参数 w1 矩阵相乘加上偏置 b1,再经过 relu 函数,得到隐藏层输出 y1。前向传播结构第二层为隐藏层输出 y1 与参数 w2 矩阵相乘加上偏置 b2,得到输出 y。由于输出 y 要经过 softmax 函数,使其符合概率分布,故输出 y 不经过 relu 函数。
-
反向传播过程 反向传播过程实现利用训练数据集对神经网络模型训练,通过降低损失函数值,实现网络模型参数的优化,从而得到准确率高且泛化能力强的神经网络模型。 实现手写体 mnist 数据集的识别任务反向传播过程如下:
由上述代码可知,在反向传播过程中,首先引入 tensorflow、input_data、前向传播 mnist_forward 和 os 模块,定义每轮喂入神经网络的图片数、初始学习率、学习率衰减率、正则化系数、训练轮数、模型保存路径以及模型保存名称等相关信息。在反向传播函数 backword 中,首先读入 mnist,用 placeholder 给训练数据 x 和标签 y_占位,调用 mnist_forward 文件中的前向传播过程 forword()函数,并设置正则化,计算训练数据集上的预测结果 y,并给当前计算轮数计数器赋值,设定为不可训练类型。接着,调用包含所有参数正则化损失的损失函数loss,并设定指数衰减学习率 learning_rate。然后,使用梯度衰减算法对模型优化,降低损失函数,并定义参数的滑动平均。最后,在 with 结构中,实现所有参数初始化,每次喂入 batch_size 组(即 200 组)训练数据和对应标签,循环迭代 steps 轮,并每隔 1000 轮打印出一次损失函数值信息,并将当前会话加载到指定路径。最后,通过主函数 main(),加载指定路径下的训练数据集,并调用规定的 backward()函数训练模型。
-
测试过程 当训练完模型后,给神经网络模型输入测试集验证网络的准确性和泛化性。注意,所用的测试集和训练集是相互独立的。 实现手写体 mnist 数据集的识别任务测试传播过程如下:
在上述代码中,首先需要引入 time 模块、tensorflow、input_data、前向传播mnist_forward、反向传播 mnist_backward 模块和 os 模块,并规定程序 5 秒的循环间隔时间。接着,定义测试函数 test(),读入 mnist 数据集,利用 tf.Graph()复现之前定义的计算图,利用 placeholder 给训练数据 x 和标签 y_占位,调用mnist_forward 文件中的前向传播过程 forword()函数,计算训练数据集上的预测结果 y。接着,实例化具有滑动平均的 saver 对象,从而在会话被加载时模型中的所有参数被赋值为各自的滑动平均值,增强模型的稳定性,然后计算模型在测试集上的准确率。在 with 结构中,加载指定路径下的 ckpt,若模型存在,则加载出模型到当前对话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率,若模型不存在,则打印出模型不存在的提示,从而 test()函数完成。通过主函数 main(),加载指定路径下的测试数据集,并调用规定的 test 函数,进行模型在测试集上的准确率验证。
Reference
助教的Tensorflow笔记--北大mooc公开课
https://blog.csdn.net/simple_the_best/article/details/75267863