您的位置:首页 > 产品设计 > UI/UE

[PyTorch] torch.squeee 和 torch.unsqueeze()

2020-01-17 13:50 295 查看

torch.squeeze

torch.squeeze(input, dim=None, out=None) → Tensor

分为两种情况: 不指定维度 或 指定维度

  1. 不指定维度

    input: (A, B, 1, C, 1, D) output: (A, B, C, D)

    Example

    >>> x = torch.zeros(2, 1, 2, 1, 2)
    >>> x.size()
    torch.Size([2, 1, 2, 1, 2])
    >>> y = torch.squeeze(x)
    >>> y.size()
    torch.Size([2, 2, 2])
  2. 指定维度

    input: (A, 1, B)
    &
    torch.squeeze(input, 0)
    output: (A, 1, B)

    input: (A, 1, B)
    &
    torch.squeeze(input, 1)
    output: (A, B)

    Example

    >>> x = torch.zeros(2, 1, 2, 1, 2)
    >>> x.size()
    torch.Size([2, 1, 2, 1, 2])
    >>> y = torch.squeeze(x, 0)
    >>> y.size()
    torch.Size([2, 1, 2, 1, 2])
    >>> y = torch.squeeze(x, 1)
    >>> y.size()
    torch.Size([2, 2, 1, 2])

Note:

The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.

也就是说, squeeze 所返回的 tensor 与 输入 tensor 共享内存, 所以如果改变其中一项的值另一项也会随着改变.

torch.unsqueeze

torch.unsqueeze(input, dim, out=None) → Tensor

Note: 这里与 squeeze 不同的是 unsqueeze 必须指定维度.

同时, unsqueeze 所返回的 tensor 与 输入的 tensor 也是共享内存的.

>>> import torch
>>> a = torch.zeros([2, 2])
>>> a.shape
torch.Size([2, 2])
>>> a
tensor([[0., 0.],
[0., 0.]])
>>> b = torch.unsqueeze(a, dim=0)
>>> b.shape
torch.Size([1, 2, 2])
>>> b
tensor([[[0., 0.],
[0., 0.]]])
>>> b[0, 0, 1] = 1
>>> b
tensor([[[0., 1.],
[0., 0.]]])
>>> a
tensor([[0., 1.],
[0., 0.]])
>>> b = torch.unsqueeze(a, dim=1)
>>> b.shape
torch.Size([2, 1, 2])
>>> b
tensor([[[0., 1.]],
[[0., 0.]]])
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: