本系列博文主要根据开源的thorough-pytorch项目编写,感谢datawhalechina团队的dalao们分享学习经验
PyTorch训练模型
一个神经网络的典型训练过程如下:
- 定义包含一些可学习参数(或者叫权重)的神经网络
- 在输入数据集上迭代
- 通过网络处理输入
- 计算损失函数
loss
(输出和正确答案的距离) - 将梯度反向传播给网络的参数
- 更新权重(一般使用简单的规则,如
weight = weight - learning_rate * gradient
)
下面来分别介绍其中的关键流程在PyTorch上的实现