CIFAR-10 分类问题 (基于 ResNet18 预训练模型)
CIFAR-10 分类问题 (基于 ResNet18 预训练模型)
项目源码存放在 GitHub 库 isKage/cifar10-classification
PyTorch
Python
torch
nn
ResNet
ResNet18
pre-trained
针对 CIFAR-10 分类问题,搭建神经网络:AlexNet
、GoogLeNet
、ResNet
、ResNet18
。最后选择预训练后的 ResNet18
进行该问题的训练、验证和测试。包含自定义数据集 Dataset
类、自定义训练、验证和测试函数、自定义结果表格函数等。
kaggle: private score = 0.68100, ranked about 71. (just training once as using cpu)
device: cpu
项目目录
1 | ├── README.md |
1 下载至本地
在终端运行
1 | git clone https://github.com/isKage/cifar10-classification.git |
2 安装依赖和数据集
2.1 pip 安装依赖
在项目根目录下终端输入
1 | pip install -r requirements.txt |
2.2 kaggle 下载数据集
教程见 从 kaggle 下载数据集 (mac & win)。
3 本地配置 config.py
在 config.py
中配置相关参数。例如数据集路径。相关配置均已配置好,但需要自己配置数据集的位置。
在 _parse()
方法中,需修改 cifar 数据集的路径。例如我的配置:cifar-10
文件夹放在用户目录下的 AllData/competitions/
下。
1 | if config.real_or_try == "real": |
4 训练
注意,默认的数据集为模拟数据集,故如果想在完整数据集训练,在指定路径后还需传入参数
--real_or_try=real
,或者直接在config.py
中
修改默认
4.1 解压数据集
在第 3 步设置完成数据集下载的路径后,终端输入
1 | python main.py unzip |
即可解压数据集。
4.2 训练
使用 fire
库方便的在终端中进行训练、测试过程。可以在 config.py
中输入默认参数。例如:model
为选择模型,默认使用 "ResNet18"
模型,
会自动进行下载,下载的预训练模型参数保存在 checkpoints/
文件夹里。
在终端运行
1 | python main.py train |
可以使用 --<参数名>=参数值
在终端覆盖默认参数
1 | python main.py train model=AlexNet # 指定 AlexNet 为模型 |
4.3 可视化
终端运行
1 | tensorboard --logdir=./logs # http://localhost:6006/ |
打开浏览器观察训练过程可视化:
5 测试
终端运行
1 | python main.py test |
即可得到测试后的结果表格 result_example.csv
或 sampleSubmission.csv
(取决与使用的是模拟数据集还是完整的数据集)。
注意,测试完成后终端输入一下指令,对结果表格按照 id
进行排序。
1 | python main.py sort_csv |
最后可以将 sampleSubmission.csv
上传到 kaggle CIFAR-10 competition 。

6 友链
- 关注我的知乎账号 Zhuhu 不错过我的笔记更新。
- 我会在个人博客 isKage`Blog 更新相关项目和学习资料。