Pytorch搭建神经网络(2)自动求导autograd、反向传播backward与计算图
基于《深度学习框架 Pytorch 入门与实践》陈云
参考 Github 的 pytorch-book 项目
参考 GitHub 的 pytorch-handbook 项目
笔记和代码存储在我的 GitHub 库中 github.com/isKage/pytorch-notes 。
torch.autograd
提供了一套自动求导方式,它能够根据前向传播过程自动构建计算图,执行反向传播。
1 autograd 的数学原理:计算图
计算图原理可以查看 cs231n 课程讲解:【计算图的原理非常重要!】或者见后文分析
英文官网 https://cs231n.github.io/
b站 课程整理 BV1nJ411z7fe 【反向传播章节】
b站 中文讲解 【子豪兄】精讲CS231N斯坦福计算机视觉公开课(2020最新)
2 autograd 的使用:requires_grad & backward
2.1 requires_grad 属性
只需要对Tensor增加一个 requires_grad=True
属性,Pytorch就会自动计算 requires_grad=True
属性的 Tensor,并保留计算图,从而快速实现反向传播。
1 2 3 4 5 6 7 8 9 10 11 x = torch.randn(2 , 3 , requires_grad=True ) x = torch.rand(2 , 3 ).requires_grad_() x = torch.randn(3 , 4 ) x.requires_grad = True print (x.requires_grad)
2.2 backward 反向传播
反向传播函数的使用:其中第一个参数 tensors
传入用于计算梯度的张量,格式和各个参数
1 torch.autograd.backward(tensors, grad_tensors=None , retain_graph=None , create_graph=False )
tensors
:用于计算梯度的Tensor,如torch.autograd.backward(y)
,等价于y.backward()
。
grad_tensors
:形状与tensors一致,对于y.backward(grad_tensors)
,grad_tensors相当于链式法则d z d x = d z d y × d y d x {\mathrm{d}z \over \mathrm{d}x}={\mathrm{d}z \over \mathrm{d}y} \times {\mathrm{d}y \over \mathrm{d}x} d x d z = d y d z × d x d y 中的d z d y {\mathrm{d}z} \over {\mathrm{d}y} d y d z 。【结合例子理解见后】
retain_graph
:计算计算图里每一个导数值时需要保留各个变量的值,retain_graph 为 True 时会保存。【结合例子理解见后】
2.2.1 requires_grad 属性的传递
例:a
需要求导,b
不需要,c
定义为 a + b
的元素加和
1 2 3 4 5 6 a = torch.randn(2 , 3 , requires_grad=True ) b = torch.zeros(2 , 3 ) c = (a + b).sum () a.requires_grad, b.requires_grad, c.requires_grad
2.2.2 is_leaf 叶子结点
对于计算图中的Tensor而言, is_leaf=True
的Tensor称为Leaf Tensor,也就是计算图中的叶子节点。
requires_grad=False
时,无需求导,故为叶子结点。
即使 requires_grad=True
但是由用户创建的时,此时它位于计算图的头部(叶子结点),它的梯度会被保留下来。
1 2 3 a.is_leaf, b.is_leaf, c.is_leaf
2.3 autograd 利用计算图计算导数
利用 autograd 计算导数,对于函数 y = x 2 e x y=x^2e^x y = x 2 e x ,它的导函数解析式为
d y d x = 2 x e x + x 2 e x \begin{equation}
\dfrac{d\ y}{d\ x} = 2xe^x + x^2e^x
\end{equation}
d x d y = 2 x e x + x 2 e x
定义计算 y 函数和计算解析式导数结果函数
1 2 3 4 5 6 7 8 9 10 def f (x ): y = x * x * torch.exp(x) return y def df (x ): df = 2 * x * torch.exp(x) + x * x * torch.exp(x) return df
1 2 3 4 5 6 x = torch.randn(2 , 3 , requires_grad=True ) y = f(x)
1 2 y.backward(gradient=torch.ones(y.size()))
1 2 3 4 5 6 7 print (x.grad) print (df(x))
x.grad & df(x)
二者是在数值上是一样的
3 反向传播与计算图
3.1 计算图原理:链式法则
根据链式法则
d z / d y = 1 , d z / d b = 1 dz/dy = 1,\ dz/db = 1 d z / d y = 1 , d z / d b = 1
d y / d w = x , d y / d x = w dy/dw = x,\ dy/dx = w d y / d w = x , d y / d x = w
d z / d x = d z / d y × d y / d x = 1 × w , d z / d w = d z / d y × d y / d w = 1 × x dz/dx = dz/dy \times dy/dx = 1 \times w,\ dz/dw = dz/dy \times dy/dw = 1 \times x d z / d x = d z / d y × d y / d x = 1 × w , d z / d w = d z / d y × d y / d w = 1 × x
只要存储结点的导数和值便可通过简单的乘法计算所有导数
按照上图构造
1 2 3 4 5 6 7 8 9 x = torch.ones(1 ) b = torch.rand(1 , requires_grad = True ) w = torch.rand(1 , requires_grad = True ) y = w * x z = y + b x.requires_grad, b.requires_grad, w.requires_grad, y.requires_grad, z.requires_grad
3.2 grad_fn 查看反向传播函数
grad_fn
可以查看这个结点的函数类型
1 2 3 4 z.grad_fn y.grad_fn w.grad_fn, x.grad_fn, b.grad_fn
grad_fn.next_functions
获取 grad_fn 的输入,返回上一步的反向传播函数
1 2 3 4 5 6 7 z.grad_fn.next_functions y.grad_fn.next_functions
3.3 retain_graph 的使用(仅叶子结点)
如果不指定 retain_graph=True
,则在反向传播后,会自动清除变量值。
例如:计算 w.grad
w 的梯度时,需要 x 的值 (d y / d w = x dy/dw = x d y / d w = x )
注意:x.requires_grad=False 不需要求导,故 x.grad
报错
1 2 3 z.backward(retain_graph=True ) print (w.grad)
1 2 3 4 z.backward() print (w.grad)
3.4 关闭反向传播
某一个节点 requires_grad
被设置为 True
,那么所有依赖它的节点 requires_grad
都是 True
。有时不需要对所有结点都反向传播(求导),从而来节省内存。
1 2 3 4 5 6 x = torch.ones(1 ) w = torch.rand(1 , requires_grad=True ) y = x * w x.requires_grad, w.requires_grad, y.requires_grad
下面我们来关闭关于 y
的反向传播
1 2 3 4 5 6 7 with torch.no_grad(): x = torch.ones(1 ) w = torch.rand(1 , requires_grad=True ) y = x * w x.requires_grad, w.requires_grad, y.requires_grad
法二:设置默认 torch.set_grad_enabled(False)
1 2 3 4 5 6 7 8 9 10 torch.set_grad_enabled(False ) x = torch.ones(1 ) w = torch.rand(1 , requires_grad = True ) y = x * w x.requires_grad, w.requires_grad, y.requires_grad torch.set_grad_enabled(True )
3.5 .data
从计算图取出Tensor的值
修改张量的数值,又不影响计算图,使用 tensor.data
方法
1 2 3 4 5 x = torch.ones(1 , requires_grad = True ) x_clone = x.data x.requires_grad, x_clone.requires_grad
3.6 存储非叶子结点的梯度
在计算图流程中,非叶子结点求导后其导数值便立刻被清除。可以使用 autograd.grad
或 hook
方法保留
1 2 3 4 5 x = torch.ones(1 , requires_grad = True ) w = torch.ones(1 , requires_grad = True ) y = w * x z = y.sum ()
1 2 3 z.backward() x.grad, w.grad, y.grad
若为叶子结点可以采用 z.backward(retain_graph=True)
的方式
1 2 3 4 5 6 7 8 x = torch.ones(1 , requires_grad = True ) w = torch.ones(1 , requires_grad = True ) y = x * w z = y.sum () torch.autograd.grad(z, y)
标准格式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def variable_hook (grad ): print ('y.grad:' , grad) x = torch.ones(1 , requires_grad = True ) w = torch.ones(1 , requires_grad = True ) y = x * w hook_handle = y.register_hook(variable_hook) z = y.sum () z.backward() hook_handle.remove()
4 案例:线性回归
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 import torchimport numpy as npfrom matplotlib import pyplot as plt%matplotlib inline def get_fake_data (batch_size=16 ): x = torch.rand(batch_size, 1 ) * 5 y = x * 2 + 3 + torch.randn(batch_size, 1 ) return x, y torch.manual_seed(1000 ) x, y = get_fake_data() w = torch.rand(1 , 1 , requires_grad=True ) b = torch.zeros(1 , 1 , requires_grad=True ) losses = np.zeros(200 ) lr = 0.005 EPOCHS = 200 for epoch in range (EPOCHS): x, y = get_fake_data(batch_size=32 ) y_pred = x.mm(w) + b.expand_as(y) loss = 0.5 * (y_pred - y) ** 2 loss = loss.sum () losses[epoch] = loss.item() loss.backward() ''' 取 .data 是因为每一轮是根据随机生成的 batch_size 个点训练,但我们希望存储的是全局参数 w, b ''' ''' 故每次依据样本点更新全局参数,而不是改批次的参数 ''' w.data.sub_(lr * w.grad.data) b.data.sub_(lr * b.grad.data) w.grad.data.zero_() b.grad.data.zero_() if epoch % 10 == 0 : print ("Epoch: {} / {}, Parameters: w is {}, b is {}, Loss: {}" .format (epoch, EPOCHS, w.item(), b.item(), losses[epoch])) print ("Epoch: {} / {}, Parameters: w is {}, b is {}, Loss: {}" .format (EPOCHS, EPOCHS, w.item(), b.item(), losses[-1 ]))