pytorch之torch.gather方法
2018-03-24 11:41
507 查看
首先,先给出torch.gather函数的函数定义:
torch.gather(input, dim, index, out=None) → Tensor
官方给出的解释是这样的: 沿给定轴dim,将输入索引张量index指定位置的值进行聚合。 对一个3维张量,输出可以定义为: out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0 out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1 out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3刚开始看上去有点难以理解,但经过研究之后发现原来这个想表述的很简单,先给出几个代码例子让大家自行体会一下。>>> import torch >>> a = torch.Tensor([[1,2],[3,4]]) >>> a 1 2 3 4 [torch.FloatTensor of size 2x2] >>> b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])) >>> b 1 1 4 3 [torch.FloatTensor of size 2x2] >>> b = torch.gather(a,1,torch.LongTensor([[1,0],[1,0]])) >>> b 2 1 4 3 [torch.FloatTensor of size 2x2] >>> b = torch.gather(a,1,torch.LongTensor([[1,1],[1,0]])) >>> b 2 2 4 3 [torch.FloatTensor of size 2x2]
很容易就会发现 torch.gather(input, dim, index, out=None)中的dim表示的就是第几维度,在这个二维例子中,如果dim=0,那么它表示的就是你接下来的操作是对于第一维度进行的,也就是行;如果dim=1,那么它表示的就是你接下来的操作是对于第二维度进行的,也就是列。index的大小和input的大小是一样的,他表示的是你所选择的维度上的操作,比如这个例子中
a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])) 其中, dim=1,表示的是在第二维度上操作。
index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,也就是对应的是[1,1]; [1,0]就是第二行对应元素的下标,也就是对应的是[4,3]。!!!特别注意一下,index的类型必须是LongTensor类型的。
相关文章推荐
- pytorch + visdom CNN处理自建图片数据集的方法
- PyTorch快速搭建神经网络及其保存提取方法详解
- PyTorch参数初始化方法
- pytorch构建网络模型的4种方法
- Pytorch入门——安装快速安装方法
- python文件中的__name__=='__main__'的使用及调用其他py文件中的函数方法
- Python引用(import)文件夹下的py文件的方法
- AIchallenger scene classfication baseline implemented by PyTorch
- 使用pyside+designer将.ui文件转化为.py文件的两种方法
- pytorch cnn 识别手写的字实现自建图片数据
- python的构建工具setup.py的方法使用示例
- pytorch安装----CPU版的
- PyTorch: Softmax多分类实战
- c++中调用python脚本提示 error LNK2001: 无法解析的外部符号 __imp_Py_Initialize等错误的解决方法
- Python pygraphviz 安装方法
- pytorch loss function 总结
- 浅析Python的web.py框架中url的设定方法
- File "/Volumes/android/.repo/repo/main.py", line 531, in <module> _Main(sys.argv[1:]) 解决方法
- Pytorch 中triplet loss的写法
- PyTorch 深度学习:60分钟快速入门