[PyTorch] torch.squeee 和 torch.unsqueeze()
2020-01-17 13:50
295 查看
torch.squeeze
torch.squeeze(input, dim=None, out=None) → Tensor
分为两种情况: 不指定维度 或 指定维度
不指定维度
input: (A, B, 1, C, 1, D) output: (A, B, C, D)
Example
>>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x) >>> y.size() torch.Size([2, 2, 2])
指定维度
input: (A, 1, B)
&torch.squeeze(input, 0)
output: (A, 1, B)
input: (A, 1, B)
&torch.squeeze(input, 1)
output: (A, B)
Example
>>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 0) >>> y.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 1) >>> y.size() torch.Size([2, 2, 1, 2])
Note:
The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
也就是说, squeeze 所返回的 tensor 与 输入 tensor 共享内存, 所以如果改变其中一项的值另一项也会随着改变.
torch.unsqueeze
torch.unsqueeze(input, dim, out=None) → Tensor
Note: 这里与 squeeze 不同的是 unsqueeze 必须指定维度.
同时, unsqueeze 所返回的 tensor 与 输入的 tensor 也是共享内存的.
>>> import torch >>> a = torch.zeros([2, 2]) >>> a.shape torch.Size([2, 2]) >>> a tensor([[0., 0.], [0., 0.]]) >>> b = torch.unsqueeze(a, dim=0) >>> b.shape torch.Size([1, 2, 2]) >>> b tensor([[[0., 0.], [0., 0.]]]) >>> b[0, 0, 1] = 1 >>> b tensor([[[0., 1.], [0., 0.]]]) >>> a tensor([[0., 1.], [0., 0.]]) >>> b = torch.unsqueeze(a, dim=1) >>> b.shape torch.Size([2, 1, 2]) >>> b tensor([[[0., 1.]], [[0., 0.]]])
相关文章推荐
- 深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解
- pytorch: PyTorch中 使用 Tensorboard
- PyTorch 中的数据类型 torch.utils.data.DataLoader
- 【DeepLearning】【PyTorch ()】Pytorch Loss functions
- PyTorch宣布推出PyTorch Hub,以提高机器学习研究的重现性
- Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解)
- PyTorch 学习(1) 什么是 PyTorch
- torch.squeeze()、torch.unsqueeze()、torch.randn()、torch.nn.Linear()、torch.Tensor.expand_as
- 解决了PyTorch 使用torch.nn.DataParallel 进行多GPU训练的一个BUG:模型(参数)和数据不在相同设备上
- Pytorch入门(一):anaconda+pycharm+pytorch环境搭建
- pytorch-torchvision transforms
- PyTorch 深度学习【一】Ubuntu16.04 下安装 PyTorch
- pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
- Windows下pip安装PyTorch后出现“from torch._C import * ImportError: DLL load failed: 找不到指定的模块”错误的解决办法
- Pytorch:torch.optim
- pytorch1.2做CIFAR-10数据集分类详解,pytorch入门程序
- 使用pip install pytorch安装pytorch报错解决方法
- PyTorch深度学习:60分钟入门(Translation) PyTorch深度学习:60分钟入门(Translation)
- PyTorch的concat也就是torch.cat实例