PyTorch中张量的操作:拼接、切分、比较、索引和变换
2020-03-05 16:50
225 查看
张量的拼接
torch.cat()
torch.cat(tensors, dim=0, out=None)
功能: 将张量按维度dim进行拼接
- tensors: 张量序列
- dim : 要拼接的维度
t = torch.ones((2, 3)) q = torch.zeros((2, 3)) t0 = torch.cat([t, q], 0) t1 = torch.cat((t, q), dim=1) print(t0, t0.shape) print(t1, t1.shape)
tensor([[1., 1., 1.], [1., 1., 1.], [0., 0., 0.], [0., 0., 0.]]) torch.Size([4, 3]) tensor([[1., 1., 1., 0., 0., 0.], [1., 1., 1., 0., 0., 0.]]) torch.Size([2, 6])
torch.stack()
torch.stack(tensors, dim=0, out=None)
功能: 在新创建的维度dim上进行拼接
- tensors:张量序列
- dim :要拼接的维度
t = torch.ones((3, 4)) q = torch.zeros((3, 4)) t0 = torch.stack([t, q], dim=0) t1 = torch.stack([t, q], dim=1) t2 = torch.stack([t, q], dim=2) print(t0, t0.shape) print(t1, t1.shape) print(t2, t2.shape)
tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], [[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]) torch.Size([2, 3, 4]) tensor([[[1., 1., 1., 1.], [0., 0., 0., 0.]], [[1., 1., 1., 1.], [0., 0., 0., 0.]], [[1., 1., 1., 1.], [0., 0., 0., 0.]]]) torch.Size([3, 2, 4]) tensor([[[1., 0.], [1., 0.], [1., 0.], [1., 0.]], [[1., 0.], [1., 0.], [1., 0.], [1., 0.]], [[1., 0.], [1., 0.], [1., 0.], [1., 0.]]]) torch.Size([3, 4, 2])
张量的切分
torch.chunk()
torch.chunk(input, chunks, dim=0)
功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除,最后一份张量小于其他张量
- input: 要切分的张量
- chunks : 要切分的份数
- dim : 要切分的维度
t = torch.ones((2, 7)) list_t = torch.chunk(t, chunks=3, dim=1) for i, ten in enumerate(list_t): print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量: tensor([[1., 1., 1.], [1., 1., 1.]]) torch.Size([2, 3]) 第2个张量: tensor([[1., 1., 1.], [1., 1., 1.]]) torch.Size([2, 3]) 第3个张量: tensor([[1.], [1.]]) torch.Size([2, 1])
torch.split()
torch.split(tensor, split_size_or_sections, dim=0)
功能: 将张量按维度dim进行切分
返回值: 张量列表
- tensor: 要切分的张量
- split_size_or_sections : 为int时,表示每一份的长度;为list时,按list元素切分
- dim : 要切分的维度
t = torch.ones((2, 7)) list_t = torch.split(t, 3, dim=1) for i, ten in enumerate(list_t): print("第{}个张量:\n{}".format(i+1, ten), ten.shape) print("\n") list_t = torch.split(t, [3, 4], dim=1) for i, ten in enumerate(list_t): print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量: tensor([[1., 1., 1.], [1., 1., 1.]]) torch.Size([2, 3]) 第2个张量: tensor([[1., 1., 1.], [1., 1., 1.]]) torch.Size([2, 3]) 第3个张量: tensor([[1.], [1.]]) torch.Size([2, 1]) 第1个张量: tensor([[1., 1., 1.], [1., 1., 1.]]) torch.Size([2, 3]) 第2个张量: tensor([[1., 1., 1., 1.], [1., 1., 1., 1.]]) torch.Size([2, 4])
张量的比较
torch.ge(),torch.gt(),torch.le(),torch.lt()
torch.ge(input, other, out=None)
功能: input中逐元素与other进行比较,满足:ge >=; gt >; le <=; lt <时,返回True
返回值: 与input同形状的布尔类型张量
- input:被比较的张量
- other:可以是张量,数值,布尔,input中逐元素与其进行比较
t = torch.randint(0, 10, [3, 3]) m = t.ge(5) print(t) print(m)
tensor([[1, 6, 5], [6, 5, 4], [0, 4, 4]]) tensor([[False, True, True], [ True, True, False], [False, False, False]])
张量的索引
torch.index_select()
torch.index_select(input, dim, index, out=None)
功能: 在维度dim上,按index索引数据
返回值: 索引得到的数据拼接的张量
- input: 要索引的张量
- dim: 要索引的维度
- index : 要索引数据的序号组成的张量,dtype须为torch.long
t = torch.randint(0, 10, [3, 3]) idx = torch.tensor([0, 2], dtype=torch.long) sel = torch.index_select(t, 0, idx) print(t) print(sel)
tensor([[0, 7, 0], [8, 3, 1], [2, 7, 9]]) tensor([[0, 7, 0], [2, 7, 9]])
torch.masked_select()
torch.masked_select(input, mask, out=None)
功能: 按mask中的True进行索引
返回值: 一维张量
- input: 要索引的张量
- mask: 与input同形状的布尔类型张量
t = torch.randint(0, 10, [3, 3]) mask = t.ge(5) sel = torch.masked_select(t, mask) print(t) print(mask) print(sel)
tensor([[1, 6, 5], [6, 5, 4], [0, 4, 4]]) tensor([[False, True, True], [ True, True, False], [False, False, False]])tensor([6, 5, 6, 5])
张量的变换
torch.reshape()
torch.reshape(input, shape)
功能: 变换张量形状
注意事项: 当张量在内存中是连续的时,新张量与input共享数据内存。这种共享与out不同,out是整个tensor都共享内存,相当于别名;reshape是仅data共享内存。改变一个张量的数据,另一个张量会跟着改变
- input: 要变换的张量
- shape: 新张量的形状
t = torch.randperm(8) re1 = torch.reshape(t, (2, 4)) re2 = torch.reshape(t, (-1, 4)) print(t) print(re1) print(re2) t[0] = 100 re2[1, 1] = 100 print(id(t.data), id(re1.data), id(re2.data)) print(re1)
tensor([0, 7, 2, 6, 3, 5, 4, 1]) tensor([[0, 7, 2, 6], [3, 5, 4, 1]]) tensor([[0, 7, 2, 6], [3, 5, 4, 1]]) 3039469824264 3039469824264 3039469824264 tensor([[100, 7, 2, 6], [ 3, 100, 4, 1]])
torch.transpose()
torch.transpose(input, dim0, dim1)
功能: 交换张量的两个维度。在图像的预处理中常用,有时读取的图像数据是(c, h, w),但是我们常用的是(h, w, c),就需要用此方法把channel和width变换,再把width和height变换
- input: 要变换的张量
- dim0: 要交换的维度
- dim1: 要交换的维度
t = torch.rand((2, 3, 4)) tr = torch.transpose(t, 1, 0) print(t, t.shape) print(tr, tr.shape)
tensor([[[0.4973, 0.0644, 0.7269, 0.9305], [0.4711, 0.1117, 0.1751, 0.4904], [0.9865, 0.7374, 0.9201, 0.5733]], [[0.4911, 0.4571, 0.9985, 0.7298], [0.5078, 0.0928, 0.1655, 0.8740], [0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([2, 3, 4]) tensor([[[0.4973, 0.0644, 0.7269, 0.9305], [0.4911, 0.4571, 0.9985, 0.7298]], [[0.4711, 0.1117, 0.1751, 0.4904], [0.5078, 0.0928, 0.1655, 0.8740]], [[0.9865, 0.7374, 0.9201, 0.5733], [0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([3, 2, 4])
torch.t()
torch.t(input)
功能: 2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)
torch.squeeze()
torch.squeeze(input, dim=None, out=None)
功能: 压缩长度为1的维度(轴)
- dim: 若为None,移除所有长度为1的轴; 若指定维度,当且仅当该轴长度为1时,可以被移除
t = torch.rand((1, 2, 3, 1)) sq = torch.squeeze(t) sq0 = torch.squeeze(t, 0) sq1 = torch.squeeze(t, 1) print(t.shape) print(sq.shape) print(sq0.shape) print(sq1.shape)
torch.Size([1, 2, 3, 1]) torch.Size([2, 3]) torch.Size([2, 3, 1]) torch.Size([1, 2, 3, 1])
torch.unsqueeze()
torch.usqueeze(input, dim, out=None)
功能:依据dim扩展维度
- dim: 扩展的维度
t = torch.rand((2, 3)) sq = torch.unsqueeze(t, 0) print(t.shape) print(sq.shape)
torch.Size([2, 3]) torch.Size([1, 2, 3])
- 点赞
- 收藏
- 分享
- 文章举报
相关文章推荐
- JavaScript 字符串操作(给索引查字符/给字符查索引/uri 编码和解码/字符串拼接/字符串截取/去掉空白/替换/变为数组/查找字符串中所有匹配项)
- 字符串操作(拷贝,比较,拼接等函数)
- 比较memcpy与指针操作和索引操作复制时的效率
- python字符串操作,截取,拼接,替换,删除,比较,查找
- pytorch 张量基本操作
- pytorch中的部分张量操作:输出满足条件的张量与合并多维张量
- Pytorch-2:张量关于shape的操作
- pytorch的比较操作
- 表、视图、索引的创建、修改、删除操作等
- string之间的拼接比较
- phoenix索引操作
- MySQL(四)索引的操作
- Oracle的like和substr对于索引的操作
- 用java语言,操作给定的二叉树,将其变换为源二叉树的镜像(递归和循环两种方法)
- Python 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转换、分割等)
- 四种操作xml的方式: SAX, DOM, JDOM , DOM4J的比较
- 矩阵堆栈的操作、组合变换
- java操作xml方式比较与详解(DOM、SAX、JDOM、DOM4J)
- 参数化操作数据库,不用拼接字符串
- RACSignal 所有变换操作底层实现分析(1)