Pytorch中gather的用法是什么?
Admin 2021-08-24 群英技术资讯 1044 次浏览
这篇文章给大家分享的是有关Pytorch中gather的用法的内容,很多新手对于gather的用法不是很了解,因此分享一些给大家做个参考,希望大家阅读完这篇能有收获,接下来一起跟随小编看看吧。
gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor
一般来说有三个参数:输入的变量input、指定在某一维上聚合的dim、聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型)。
#参数介绍: input (Tensor) -C The source tensor dim (int) -C The axis along which to index index (LongTensor) -C The indices of elements to gather out (Tensor, optional) -C Destination tensor #当输入为三维时的计算过程: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 #样例: t = torch.Tensor([[1,2],[3,4]]) torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) # 1 1 # 4 3 #[torch.FloatTensor of size 2x2]
用下面的代码在二维上做测试,以便更好地理解
t = torch.Tensor([[1,2,3],[4,5,6]]) index_a = torch.LongTensor([[0,0],[0,1]]) index_b = torch.LongTensor([[0,1,1],[1,0,0]]) print(t) print(torch.gather(t,dim=1,index=index_a)) print(torch.gather(t,dim=0,index=index_b))
输出为:
>>tensor([[1., 2., 3.],
[4., 5., 6.]])
>>tensor([[1., 1.],
[4., 5.]])
>>tensor([[1., 5., 6.],
[4., 2., 3.]])
由于官网给的计算过程不太直观,下面给出较为直观的解释:
对于index_a,dim为1表示在第二个维度上进行聚合,索引为列号,[[0,0],[0,1]]表示结果的第一行取原数组第一行列号为[0,0]的数,也就是[1,1],结果的第二行取原数组第二行列号为[0,1]的数,也就是[4,5],这样就得到了输出的结果[[1,1],[4,5]]。
对于index_b,dim为0表示在第一个维度上进行聚合,索引为行号,[[0,1,1],[1,0,0]]表示结果的第一行第d(d=0,1,2)列取原数组第d列行号为[0,1,1]的数,也就是[1,5,6],类似的,结果的第二行第d列取原数组第d列行号为[1,0,0]的数,也就是[4,2,3],这样就得到了输出的结果[[1,5,6],[4,2,3]]
接下来以index_a为例直接用官网的式子计算一遍加深理解:
output[0,0] = input[0,index[0,0]] #1 = input[0,0] output[0,1] = input[0,index[0,1]] #1 = input[0,0] output[1,0] = input[1,index[1,0]] #4 = input[1,0] output[1,1] = input[1,index[1,1]] #5 = input[1,1]
注
以下两种写法得到的结果是一样的:
r1 = torch.gather(t,dim=1,index=index_a)
r2 = t.gather(1,index_a)
补充:Pytorch中的torch.gather函数的个人理解
在pytorch中,gather()函数的作用是将数据从input中按index提出,我们看gather函数的的官方文档说明如下:
torch.gather(input, dim, index, out=None) → Tensor Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor) -C The source tensor dim (int) -C The axis along which to index index (LongTensor) -C The indices of elements to gather out (Tensor, optional) -C Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2]
可以看出,在gather函数中我们用到的主要有三个参数:
1)input:输入
2)dim:维度,常用的为0和1
3)index:索引位置
a=t.arange(0,16).view(4,4) print(a) index_1=t.LongTensor([[3,2,1,0]]) b=a.gather(0,index_1) print(b) index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t() c=a.gather(1,index_2) print(c)
输出如下:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[12, 9, 6, 3]])tensor([[ 0],
[ 5],
[10],
[15]])
在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。
1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];
2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T;
gather函数在提取数据时主要靠dim和index这两个参数,dim=1时将input看为n×1阶矩阵,index看为k×1阶矩阵,取index每行元素对input中每行进行列索引(如:index某行为[1,3,0],对应的input行元素为[9,8,7,6],提取后的结果为[8,6,9]);
同理,dim=0时将input看为1×n阶矩阵,index看为1×k阶矩阵,取index每列元素对input中每列进行行索引。
gather函数提取后的矩阵阶数和对应的index阶数相同。
关于Pytorch中gather的用法就介绍到这,上述实例仅供参考,感兴趣的朋友可以参考学习,希望能对大家有帮助,想要了解更多gather的用法,大家可以关注其他文章。
文本转载自脚本之家
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了Python机器学习三大件之一numpy,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好地帮助哟.需要的朋友可以参考下
这篇文章主要介绍了Python在不同场景合并多个Excel的方法,文章围绕主题总共分享了三种方法,具有一定的参考价值,需要的小伙伴可以参考一下
python APScheduler定时任务执行,下文有实例供大家参考,对大家了解操作过程或相关知识有一定的帮助,而且实用性强,希望这篇文章能帮助大家,下面我们一起来了解看看吧。
自己写 Python 也有四五年了,一直是用自己的“强迫症”在维持自己代码的质量。都有去看 Google 的 Python 代码规范,对这几年的工作经验,做个简单的笔记,如果你也在学 Python,准备要学习 Python,希望这篇文章对你有用。
在写Python的时候经常会遇到时间格式的问题,首先就是最近用到的时间戳(timestamp)和时间字符串之间的转换。所谓时间戳,就是从 1970年1
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008