使用PyTorch训练MNIST数据集
前言
PyTorch 是一个基于 Python 的深度学习平台,它简单易用上手快的同时功能十分强大。
本篇文章首先将介绍 PyTorch 的基本数据结构 Tensor 的一些操作;随后给出神经网络中的 HelloWorld 例子:用最经典的卷积神经网络(LeNet5)训练手写数据集 MNIST
PyTorch 中的 Tensor
以下内容来自: https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html
Tensor 简单讲就是多维数组,用来表示各种维度的数据。
Tensor 的创建修改
- 创建一个未初始化的 5*3 的 Tensor
- 初始化一个随机 5*3 矩阵
|
|
- 初始化一个全零 5*3 矩阵,类型为 long
|
|
- 用数据初始化一个 tensor
|
|
|
|
- 从一个存在的 tensor 创建一个 tensor
- 返回大小
|
|
|
|
Tensor 的一些操作
运算(如加法)
+号
torch.add
1
torch.add(x, y)
提供一个参数保存结果
原地相加
改变维度,类似 numpy 中的 reshape
使用 tensor.view
改变 tensor
的大小。
|
|
返回值
对于只有一个元素的 tensor
通过 .item()
得到它的值
Tensor 转换成 Numpy
|
|
|
|
a 和 b 共享内存
Numpy 转换成 Tensor
CUDA Tensor
Tensor 可以移动到任意设备,通过 .to
方法
搭建 LeNet5 训练 MNIST 数据集
MNIST 数据集处理
虽然 PyTorch 中已经预置了 MNIST 数据集的处理代码,但是我们要有自己处理数据集的能力,特别是在学习阶段,所以本文会自己处理数据集,然后结合 PyTorch 的数据处理机制。
MNIST 数据集的结构
MNIST 数据集包含 60000 张训练用的图片,10000 张测试用的图片,每个图片均有对应的标签。每张图片的像素是 28 * 28,每个像素值的范围是 0 - 255,用 8 个比特表示。数据集有下面四个二进制文件,对应训练图片,训练标签,测试图片,测试标签:
train-images.idx3-ubyte
train-labels.idx1-ubyte
t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
图片(idx3)的格式:首先是 32 位的整数,是一个 magic 数字,接下来 32 位整数表示图片的数量,接下来的两个 32 位整数是分别是图片行数和列数,接下来是像素值,每个像素 8 位一个字节,取值 0-255。标签(idx1)的格式:首先是 32 位的整数,是一个 magic 数字,接下来 32 位整数表示标签的数量,接下来是标签的值,每个标签一个字节,取值 0-9。
更加详细的介绍可以看官方文档:http://yann.lecun.com/exdb/mnist/
下载地址:链接: https://pan.baidu.com/s/1Ve1mtx7UNq7im6xu0MVbdQ 提取码: c57g
读取数据集的代码
|
|
PyTorch 的数据处理方法
PyTorch 提供了一个 torch.utils.data.DataLoader
工具对数据进行批量(batch)化,打乱数据,并行处理的等。=DataLoader= 的定义如下:
|
|
DataLoader
中第一个参数是一个 DataSet
对象,=DataSet= 提供一个抽象的接口,我们可以继承它来处理自己的数据集。通过下面的例子来说明使用方法。
|
|
搭建 LeNet5 网络
LeNet5 网络这里就不介绍了,CNN 入门网络。网络图如下:
PyTorch 搭建网络有四种方式,下面是我喜欢的一种,其它三种可以自行搜索
|
|
上面的网络跟图片上的完全一样,图片的输入是 32*32,所以 C1 层加个 padding=2 。
定义训练和测试方法
代码参考:https://github.com/pytorch/examples/blob/master/mnist/main.py
|
|
训练过程
代码参考:https://github.com/pytorch/examples/blob/master/mnist/main.py
|
|
训练结果
训练 10 轮后,正确率到达 98.52%