您的位置:首页 > 其它

PyTorch中张量的操作:拼接、切分、比较、索引和变换

2020-03-05 16:50 225 查看

张量的拼接

torch.cat()

torch.cat(tensors,
dim=0,
out=None)

功能: 将张量按维度dim进行拼接

  • tensors: 张量序列
  • dim : 要拼接的维度
t = torch.ones((2, 3))
q = torch.zeros((2, 3))
t0 = torch.cat([t, q], 0)
t1 = torch.cat((t, q), dim=1)
print(t0, t0.shape)
print(t1, t1.shape)
tensor([[1., 1., 1.],
[1., 1., 1.],
[0., 0., 0.],
[0., 0., 0.]]) torch.Size([4, 3])
tensor([[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.]]) torch.Size([2, 6])

torch.stack()

torch.stack(tensors,
dim=0,
out=None)

功能: 在新创建的维度dim上进行拼接

  • tensors:张量序列
  • dim :要拼接的维度
t = torch.ones((3, 4))
q = torch.zeros((3, 4))
t0 = torch.stack([t, q], dim=0)
t1 = torch.stack([t, q], dim=1)
t2 = torch.stack([t, q], dim=2)
print(t0, t0.shape)
print(t1, t1.shape)
print(t2, t2.shape)
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],

[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]]) torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
[0., 0., 0., 0.]],

[[1., 1., 1., 1.],
[0., 0., 0., 0.]],

[[1., 1., 1., 1.],
[0., 0., 0., 0.]]]) torch.Size([3, 2, 4])
tensor([[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],

[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],

[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]]]) torch.Size([3, 4, 2])

张量的切分

torch.chunk()

torch.chunk(input,
chunks,
dim=0)

功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除,最后一份张量小于其他张量

  • input: 要切分的张量
  • chunks : 要切分的份数
  • dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.chunk(t, chunks=3, dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
[1.]]) torch.Size([2, 1])

torch.split()

torch.split(tensor,
split_size_or_sections,
dim=0)

功能: 将张量按维度dim进行切分
返回值: 张量列表

  • tensor: 要切分的张量
  • split_size_or_sections : 为int时,表示每一份的长度;为list时,按list元素切分
  • dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.split(t, 3, dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)

print("\n")

list_t = torch.split(t, [3, 4], dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
[1.]]) torch.Size([2, 1])
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]]) torch.Size([2, 4])

张量的比较

torch.ge(),torch.gt(),torch.le(),torch.lt()

torch.ge(input,
other,
out=None)

功能: input中逐元素与other进行比较,满足:ge >=; gt >; le <=; lt <时,返回True
返回值: 与input同形状的布尔类型张量

  • input:被比较的张量
  • other:可以是张量,数值,布尔,input中逐元素与其进行比较
t = torch.randint(0, 10, [3, 3])
m = t.ge(5)
print(t)
print(m)
tensor([[1, 6, 5],
[6, 5, 4],
[0, 4, 4]])
tensor([[False,  True,  True],
[ True,  True, False],
[False, False, False]])

张量的索引

torch.index_select()

torch.index_select(input,
dim,
index,
out=None)

功能: 在维度dim上,按index索引数据
返回值: 索引得到的数据拼接的张量

  • input: 要索引的张量
  • dim: 要索引的维度
  • index : 要索引数据的序号组成的张量,dtype须为torch.long
t = torch.randint(0, 10, [3, 3])
idx = torch.tensor([0, 2], dtype=torch.long)
sel = torch.index_select(t, 0, idx)
print(t)
print(sel)
tensor([[0, 7, 0],
[8, 3, 1],
[2, 7, 9]])
tensor([[0, 7, 0],
[2, 7, 9]])

torch.masked_select()

torch.masked_select(input,
mask,
out=None)

功能: 按mask中的True进行索引
返回值: 一维张量

  • input: 要索引的张量
  • mask: 与input同形状的布尔类型张量
t = torch.randint(0, 10, [3, 3])
mask = t.ge(5)
sel = torch.masked_select(t, mask)
print(t)
print(mask)
print(sel)
tensor([[1, 6, 5],
[6, 5, 4],
[0, 4, 4]])
tensor([[False,  True,  True],
[ True,  True, False],
[False, False, False]])tensor([6, 5, 6, 5])

张量的变换

torch.reshape()

torch.reshape(input,
shape)

功能: 变换张量形状
注意事项: 当张量在内存中是连续的时,新张量与input共享数据内存。这种共享与out不同,out是整个tensor都共享内存,相当于别名;reshape是仅data共享内存。改变一个张量的数据,另一个张量会跟着改变

  • input: 要变换的张量
  • shape: 新张量的形状
t = torch.randperm(8)
re1 = torch.reshape(t, (2, 4))
re2 = torch.reshape(t, (-1, 4))
print(t)
print(re1)
print(re2)
t[0] = 100
re2[1, 1] = 100
print(id(t.data), id(re1.data), id(re2.data))
print(re1)
tensor([0, 7, 2, 6, 3, 5, 4, 1])
tensor([[0, 7, 2, 6],
[3, 5, 4, 1]])
tensor([[0, 7, 2, 6],
[3, 5, 4, 1]])
3039469824264 3039469824264 3039469824264
tensor([[100,   7,   2,   6],
[  3, 100,   4,   1]])

torch.transpose()

torch.transpose(input,
dim0,
dim1)

功能: 交换张量的两个维度。在图像的预处理中常用,有时读取的图像数据是(c, h, w),但是我们常用的是(h, w, c),就需要用此方法把channel和width变换,再把width和height变换

  • input: 要变换的张量
  • dim0: 要交换的维度
  • dim1: 要交换的维度
t = torch.rand((2, 3, 4))
tr = torch.transpose(t, 1, 0)
print(t, t.shape)
print(tr, tr.shape)
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
[0.4711, 0.1117, 0.1751, 0.4904],
[0.9865, 0.7374, 0.9201, 0.5733]],

[[0.4911, 0.4571, 0.9985, 0.7298],
[0.5078, 0.0928, 0.1655, 0.8740],
[0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([2, 3, 4])
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
[0.4911, 0.4571, 0.9985, 0.7298]],

[[0.4711, 0.1117, 0.1751, 0.4904],
[0.5078, 0.0928, 0.1655, 0.8740]],

[[0.9865, 0.7374, 0.9201, 0.5733],
[0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([3, 2, 4])

torch.t()

torch.t(input)

功能: 2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)

torch.squeeze()

torch.squeeze(input,
dim=None,
out=None)

功能: 压缩长度为1的维度(轴)

  • dim: 若为None,移除所有长度为1的轴; 若指定维度,当且仅当该轴长度为1时,可以被移除
t = torch.rand((1, 2, 3, 1))
sq = torch.squeeze(t)
sq0 = torch.squeeze(t, 0)
sq1 = torch.squeeze(t, 1)
print(t.shape)
print(sq.shape)
print(sq0.shape)
print(sq1.shape)
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])

torch.unsqueeze()

torch.usqueeze(input,
dim,
out=None)

功能:依据dim扩展维度

  • dim: 扩展的维度
t = torch.rand((2, 3))
sq = torch.unsqueeze(t, 0)
print(t.shape)
print(sq.shape)
torch.Size([2, 3])
torch.Size([1, 2, 3])
  • 点赞
  • 收藏
  • 分享
  • 文章举报
Sakura樱_子于 发布了11 篇原创文章 · 获赞 0 · 访问量 2084 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: