torch.cat与torch.stack的不同在哪,使用是怎样的
Admin 2022-06-29 群英技术资讯 937 次浏览
今天这篇给大家分享的知识是“torch.cat与torch.stack的不同在哪,使用是怎样的”,小编觉得挺不错的,对大家学习或是工作可能会有所帮助,对此分享发大家做个参考,希望这篇“torch.cat与torch.stack的不同在哪,使用是怎样的”文章能帮助大家解决问题。torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。

图1 torch.cat()
torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

图2 torch.stack()
补充:torch.stack()的官方解释,详解以及例子
在pytorch中,常见的拼接函数主要是两个,分别是:
1、stack()
2、cat()
实际使用中,这两个函数互相辅助:关于cat()参考torch.cat(),但是本文主要说stack()。
函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=?) → Tensor
参数
inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。
dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。
函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等
----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape
dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
不懂的看例子,再回过头看就懂了。
1.准备2个tensor数据,每个的shape都是[3,3]
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
2.测试stack函数
print(torch.stack((T1,T2),dim=0).shape) print(torch.stack((T1,T2),dim=1).shape) print(torch.stack((T1,T2),dim=2).shape) print(torch.stack((T1,T2),dim=3).shape) # outputs: torch.Size([2, 3, 3]) torch.Size([3, 2, 3]) torch.Size([3, 3, 2]) '选择的dim>len(outputs),所以报错' IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
可以复制代码运行试试:拼接后的tensor形状,会根据不同的dim发生变化。
| dim | shape |
|---|---|
| 0 | [2, 3, 3] |
| 1 | [3, 2, 3] |
| 2 | [3, 3, 2] |
| 3 | 溢出报错 |
1、函数作用:
函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
2、存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留-C[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。
函数存在意义?
手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留-C[1.seq_len这个时间步]和-C[2.张量属性[batch_size, output_size]]。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
断点(break point)是指在代码中指定位置,当程序运行到此位置时便中断下来,并让开发者可查看此时各变量的值。
如果要考察某公司的牛奶产品质量,可以从100袋牛奶中抽取30袋,在随机数表中选中一数,并用向上、下、左、右不同的读法组成30个数,并按牛奶的标号进行检测,虽然麻烦,但很常用。在日常生活中,随机数起着很大的作用,所以很多人会专门去寻找随机数生成器。
内容介绍题目描述解题思路/算法分析/问题及解决实验代码运行结果题目描述本次实验为连接数据库的实验,并对数据库进行一些简单的操作,要实现的基本功能如下所示,要能连接并展现数据库里的数据,能够实现插入功能
这篇文章主要介绍了Python 中面向接口编程详情,Python 中的接口与大多数其它语言的处理方式不同,它们的设计复杂性也不同,关于Python 接口编程的介绍,需要的小伙伴可以参考下面文章内容
本篇文章给大家带来了关于Python的相关知识,KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法,下面一起来看一下,希望对大家有帮助。
成为群英会员,开启智能安全云计算之旅
立即注册关注或联系群英网络
7x24小时售前:400-678-4567
7x24小时售后:0668-2555666
24小时QQ客服
群英微信公众号
CNNIC域名投诉举报处理平台
服务电话:010-58813000
服务邮箱:service@cnnic.cn
投诉与建议:0668-2555555
Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 ICP核准(ICP备案)粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008