PyTorch 训练 CIFAR10 分类器
原文链接:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
CIFAR10 数据集介绍
该数据集共有 60000 张彩色图像,这些图像是 3*32*32,即 3 通道彩色 32*32 的图片,分为 10 个类,每类 6000 张图
载入数据
使用 PyTorch 加载数据,本地没有会自动下载数据
torchvision 载入的数据是 PILImage 图片,将它转换为 Tensor,取值为 [-1, 1]
|
|
显示数据
|
|
定义网络
|
|
训练
训练过程的代码框架和 PyTorch 入门 一样,只是改变了一些参数
从上图可以看到正确率并不高,损失值不断波动,不再下降,通过更改 batch_size 正确率会上升几个百分点。这个时候想到修改网络。
优化
更改网络结构:
|
|
这个网络在训练 30 轮后正确率能达到 78%
参考 VGG16 网络:
|
|
在训练到十几轮的时候发现过拟合:
在全连接前加一层 dropout 可以防止,总共训练 100 轮,在差不多 50 轮后,开始出现过拟合现象,最后结果如下图,正确率最高 82%