基于PyTorch的深度学习基本工作流


一个完整的深度学习工作流程涉及处理数据、创建模型、优化参数和保存模型等步骤。
这篇blog提供了一个在FashionMNIST数据集上训练的一个图片分类网络实例,提供了一个较为规范的工作流程模板。
处理数据
PyTorch有两个处理数据的方法:Torch.utils.data.DataLoader和Torch.utils.data.Dataset。Dataset存储了样本及其相应的标签,而DataLoader则围绕Dataset包装了一个可迭代数据容器。
1 | import torch |
创建模型
1 | device = "cuda" if torch.cuda.is_available() else "cpu" |
优化参数
1 | # 定义损失函数,评价模型预测结果与真实结果间的差距 |
保存模型
1 | # 序列化内部状态字典(包含模型参数) |
载入模型并预测
1 | model = NeuralNetwork() |
总结
以上内容展示了完成一项深度学习基本任务所需要的全部工作流程, 在完成一项实际任务时,大致需要经过一下几个步骤:
确定数据对象,选择合适的批量大小,创建Data Loader完成数据的批量预加载(一般分为训练集和测试集,实际问题中为了调参往往会增设k折交叉验证)。
根据问题需要构造合适的网络模块,完成模型声明与实例化。
选择合适的损失函数和优化器,在合适的迭代次数内对模型进行参数更新,并不断通过验证集准确率与损失大小确定模型是否处于学习状态(即是否未发生欠拟合/过拟合)。
保存模型内部参数(保存序列化内部状态字典),需要使用时实例化模型并加载参数,即可进行预测任务。