您的位置:首页 > 其它

pytorch之torch.gather方法

2018-03-24 11:41 507 查看

首先,先给出torch.gather函数的函数定义:

        torch.gather(input, dim, index, out=None) → Tensor

官方给出的解释是这样的:                    沿给定轴dim,将输入索引张量index指定位置的值进行聚合。                    对一个3维张量,输出可以定义为:                out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0                out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1                out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3刚开始看上去有点难以理解,但经过研究之后发现原来这个想表述的很简单,先给出几个代码例子让大家自行体会一下。
>>> import torch
>>> a = torch.Tensor([[1,2],[3,4]])
>>> a
1  2
3  4
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
>>> b
1  1
4  3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,0],[1,0]]))
>>> b
2  1
4  3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,1],[1,0]]))
>>> b
2  2
4  3
[torch.FloatTensor of size 2x2]

很容易就会发现 torch.gather(input, dim, index, out=None)中的dim表示的就是第几维度,在这个二维例子中,如果dim=0,那么它表示的就是你接下来的操作是对于第一维度进行的,也就是行;如果dim=1,那么它表示的就是你接下来的操作是对于第二维度进行的,也就是列。index的大小和input的大小是一样的,他表示的是你所选择的维度上的操作,比如这个例子中
a = torch.Tensor([[1,2],[3,4]])  b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])) 其中, dim=1,表示的是在第二维度上操作。
index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,也就是对应的是[1,1]; [1,0]就是第二行对应元素的下标,也就是对应的是[4,3]。
!!!特别注意一下,index的类型必须是LongTensor类型的。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: