用pytorch怎样实现自定义网络层,要点有哪些
Admin 2022-08-25 群英技术资讯 924 次浏览
本篇内容介绍了“用pytorch怎样实现自定义网络层,要点有哪些”的有关知识,在实际项目的操作过程或是学习过程中,不少人都会遇到这样的问题,接下来就让小编带大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。
首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
输入一些数据,验证一下网络是否能正常工作:
layer = CenteredLayer() print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
输出结果如下:
tensor([-2., -1., 0., 1., 2.])
运行正常,表明网络没有问题。
现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer()) Y = net(torch.rand(4, 8)) print(Y.mean()) # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可
结果如下:
tensor(-5.5879e-09, grad_fn=<MeanBackward0>)
这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。
这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。
class MyLinear(nn.Module):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
def forward(self, X):
linear = torch.matmul(X, self.weight.data) + self.bias.data
return F.relu(linear)
接下来实例化类并访问其模型参数:
linear = MyLinear(5, 3) print(linear.weight)
结果如下:
Parameter containing:
tensor([[-0.3708, 1.2196, 1.3658],
[ 0.4914, -0.2487, -0.9602],
[ 1.8458, 0.3016, -0.3956],
[ 0.0616, -0.3942, 1.6172],
[ 0.7839, 0.6693, -0.8890]], requires_grad=True)
而后输入一些数据,查看模型输出结果:
print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
[1.3514, 0.0968, 0.6667]])
我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
[0.2567]])
我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。
在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。
层可以有局部参数,这些参数可以通过内置函数创建。
《动手学深度学习》 — 动手学深度学习 2.0.0-beta0 documentation
https://zh-v2.d2l.ai/
#创建自己的网络 import models model = models.__dict__["resnet50"](pretrained=True) for index ,(name, param) in enumerate(model.named_parameters()): print( str(index) + " " +name)
结果如下:
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
在实际中遇到一个时间处理问题,需要将 Sep 06, 2014 19:30 (UTC 时间) 和 当前时间比较早晚,知道 此 2014-09-06 19:30 格
生成器的使用在Python中,如果一个函数定义的内部使用了yield关键字,那么在执行函数的时候返回的是一个生成器,而不是常规函数的返回值。我们先来看一个常规函数的定义,下面的函数f()通过return语句返回1,那么print打印的就是数字1。deff():ret...
包管理工具是用来对一些应用程序的包进行管理的工具,比如nodejs使用npm,yarn来进行包管理,linux使用apt来进行包管理。python包管理工具或许不如他们有名(实际上pip的大名比前几位更响亮),但绝对比他们好用易用。没错,小编这里要说的就是pip,接下来的这篇文章,我们将对pip使用_来自Python3 教程,w3cschool编程狮。
Spark部署模式分为Local模式(本地单机模式)和集群模式,在Local模式下,常用于本地开发程序与测试,而集群模式又分为Standalone模式(集群单机模式)、Yarn模式和Mesos模式,关于这三种集群模式的相关介绍具体如下:
函数input()让程序暂停运行,等待用户输入一些文本。获取用户输入后,Python将其存储在一个变量中,以方便使用。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008