pytorch禁止计算局部梯度有几种方式,具体怎样做
Admin 2022-06-15 群英技术资讯 1532 次浏览
这篇文章主要讲解了“pytorch禁止计算局部梯度有几种方式,具体怎样做”,文中的讲解内容简单、清晰、详细,对大家学习或是工作可能会有一定的帮助,希望大家阅读完这篇文章能有所收获。下面就请大家跟着小编的思路一起来学习一下吧。torch.autogard.no_grad: 禁用梯度计算的上下文管理器。
当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True
将不用计算梯度的变量放在with torch.no_grad()里
>>> x = torch.tensor([1.], requires_grad=True) >>> with torch.no_grad(): ... y = x * 2 >>> y.requires_grad Out[12]:False
使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度
>>> @torch.no_grad() ... def doubler(x): ... return x * 2 >>> z = doubler(x) >>> z.requires_grad Out[13]:False
torch.autogard.enable_grad :允许计算梯度的上下文管理器
在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.
使用with torch.enable_grad()允许计算梯度
>>> x = torch.tensor([1.], requires_grad=True) >>> with torch.no_grad(): ... with torch.enable_grad(): ... y = x * 2 >>> y.requires_grad Out[14]:True >>> y.backward() # 计算梯度 >>> x.grad Out[15]: tensor([2.])
在禁止计算梯度下调用被允许计算梯度的函数,结果可以计算梯度
>>> @torch.enable_grad() ... def doubler(x): ... return x * 2 >>> with torch.no_grad(): ... z = doubler(x) >>> z.requires_grad Out[16]:True
torch.autograd.set_grad_enable()
可以作为一个函数使用:
>>> x = torch.tensor([1.], requires_grad=True) >>> is_train = False >>> with torch.set_grad_enabled(is_train): ... y = x * 2 >>> y.requires_grad Out[17]:False >>> torch.set_grad_enabled(True) >>> y = x * 2 >>> y.requires_grad Out[18]:True >>> torch.set_grad_enabled(False) >>> y = x * 2 >>> y.requires_grad Out[19]:False
单独使用这三个函数时没有什么,但是若是嵌套,遵循就近原则。
x = torch.tensor([1.], requires_grad=True)
with torch.enable_grad():
torch.set_grad_enabled(False)
y = x * 2
print(y.requires_grad)
Out[20]: False
torch.set_grad_enabled(True)
with torch.no_grad():
z = x * 2
print(z.requires_grad)
Out[21]:False
补充:pytorch局部范围内禁用梯度计算,no_grad、enable_grad、set_grad_enabled使用举例

Locally disabling gradient computation 在局部区域内关闭(禁用)梯度的计算. The context managers torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() are helpful for locally disabling and enabling gradient computation. See Locally disabling gradient computation for more details on their usage. These context managers are thread local, so they won't work if you send work to another thread using the threading module, etc. 上下文管理器torch.no_grad()、torch.enable_grad()和 torch.set_grad_enabled()可以用来在局部范围内启用或禁用梯度计算. 在Locally disabling gradient computation章节中详细介绍了 局部禁用梯度计算的使用方式.这些上下文管理器具有线程局部性, 因此,如果你使用threading模块来将工作负载发送到另一个线程, 这些上下文管理器将不会起作用. no_grad Context-manager that disabled gradient calculation. no_grad 用于禁用梯度计算的上下文管理器. enable_grad Context-manager that enables gradient calculation. enable_grad 用于启用梯度计算的上下文管理器. set_grad_enabled Context-manager that sets gradient calculation to on or off. set_grad_enabled 用于设置梯度计算打开或关闭状态的上下文管理器.
Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001A2E55A8870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[-0.1806, 2.0937, 1.0406, -1.7651],
[ 1.1216, 0.8440, 0.1783, 0.6859]], requires_grad=True)
>>> b = a * 2
>>> b
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]], grad_fn=<MulBackward0>)
>>> b.requires_grad
True
>>> b.grad
__main__:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
>>> print(b.grad)
None
>>> a.requires_grad
True
>>> a.grad
>>> print(a.grad)
None
>>>
>>> with torch.no_grad():
... c = a * 2
...
>>> c
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]])
>>> c.requires_grad
False
>>> print(c.grad)
None
>>> a.grad
>>>
>>> print(a.grad)
None
>>> c.sum()
tensor(6.1559)
>>>
>>> c.sum().backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>>
>>> b.sum()
tensor(6.1559, grad_fn=<SumBackward0>)
>>> b.sum().backward()
>>>
>>>
>>> a.grad
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
>>> a.requires_grad
True
>>>
>>>
Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000002109ABC8870>
>>>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[-0.1806, 2.0937, 1.0406, -1.7651],
[ 1.1216, 0.8440, 0.1783, 0.6859]], requires_grad=True)
>>> a.requires_grad
True
>>>
>>> with torch.set_grad_enabled(False):
... b = a * 2
...
>>> b
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]])
>>> b.requires_grad
False
>>>
>>> with torch.set_grad_enabled(True):
... c = a * 3
...
>>> c
tensor([[ 0.8472, -1.1145, 2.7263, -5.2804],
[-0.5418, 6.2810, 3.1219, -5.2954],
[ 3.3649, 2.5319, 0.5350, 2.0576]], grad_fn=<MulBackward0>)
>>> c.requires_grad
True
>>>
>>> d = a * 4
>>> d.requires_grad
True
>>>
>>> torch.set_grad_enabled(True) # this can also be used as a function
<torch.autograd.grad_mode.set_grad_enabled object at 0x00000210983982C8>
>>>
>>> # 以函数调用的方式来使用
>>>
>>> e = a * 5
>>> e
tensor([[ 1.4119, -1.8574, 4.5439, -8.8006],
[-0.9030, 10.4684, 5.2031, -8.8257],
[ 5.6082, 4.2198, 0.8917, 3.4294]], grad_fn=<MulBackward0>)
>>> e.requires_grad
True
>>>
>>> d
tensor([[ 1.1296, -1.4859, 3.6351, -7.0405],
[-0.7224, 8.3747, 4.1625, -7.0606],
[ 4.4866, 3.3759, 0.7133, 2.7435]], grad_fn=<MulBackward0>)
>>>
>>> torch.set_grad_enabled(False) # 以函数调用的方式来使用
<torch.autograd.grad_mode.set_grad_enabled object at 0x0000021098394C48>
>>>
>>> f = a * 6
>>> f
tensor([[ 1.6943, -2.2289, 5.4527, -10.5607],
[ -1.0836, 12.5621, 6.2437, -10.5908],
[ 6.7298, 5.0638, 1.0700, 4.1153]])
>>> f.requires_grad
False
>>>
>>>
>>>
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要给大家分享的是关于Python文件的操作,下面有详细介绍Python文件读写原理、常用文件打开模式、文件对象的常用方法、目录的相关操作,对Python新手学习具有一定的借鉴价值,感兴趣的朋友就跟随小编一起来了解一下吧。
工作中偶尔会收到一大堆文件,名称各不相同,分析文件的时候发现有不少重复的文件,导致工作效率低下,那么,这里就写了一个python脚本实现文件去重功能,感兴趣的就一起来了解一下
线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,在线性回归分析中,只包括一个自变量和一个因变量,且二者的关系
这篇文章主要介绍了Python日志采集,在实际使用python做自动化测试过程中两种解决思路都可以使用,且都挺方便,其中对于思路1,还可以将代码进行更进一步的封装,需要的朋友可以参考下
数据库是存储和管理数据的仓库,但数据库并不能直接存储数据,数据是存储在表中的,在存储数据的过程中一定会用到数据库服务器,所谓的数据库服务器就是指在计算机上安装一个数据库管理程序,如MySQL。数据库、表、数据库服务器之间的关系,如图所示。
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008