您的位置:首页 > 编程语言 > Python开发

pytorch合并与拆分

2019-05-13 02:40 435 查看

import torch

import numpy as np

 

#自动扩张,或者叫做广播,不需要拷贝数据.

#如果前面没有维度,则在前面插入一个新的维度;

#对齐时默认从后面开始对齐.

#broading cast 可以简化运算并且减少内存拷贝.

a = torch.randn(4,3)

b = torch.rand(4,3)

c = a + b

c = torch.randn(1,3)

d = a+ c

 

#拼接1:

a = torch.randn(5,32,48)

b = torch.randn(4,32,48)

c = torch.cat([a,b],dim=0)#在第0维度合并

print(c.shape)#torch.Size([9, 32, 48])

print(torch.cat([a,b]).shape)#默认拼接按照0维

 

#拼接2 stack会创建新的维度,注意其形状必须匹配

a = torch.randn(32,8)

b = torch.randn(32,8)

d = torch.stack([a,b],dim=2)#torch.Size([32, 8, 2])

print(d.shape)

 

#拆分split 长度拆分,如[1,2,3,4,5,6]指定拆分长度2,则拆分为三个单元

a = torch.randn(2,32,8)

b = a.split(1,dim=0)#第0个维度拆分

print(type(b))

 

#按数量区分

print(a.chunk(2,dim=0))

(adsbygoogle = window.adsbygoogle || []).push({});
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  NumPy PyTorch