您的位置:首页 > 其它

pytorch 使用

2017-05-20 11:08 573 查看

1 DataParallel

from torch.nn import DataParallel
net = DataParallel(net)


可以实现模块级别(?好处具体是啥不大懂)的并行计算,可以将一个模块forward部分分到各个gpu去计算,然后backwards时,合并gradients 到original module。

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)


2 DataLoader

其实这里trainset已经包含数据集了,dataloader只是定义输入网络的一些参数,入batch_size等等。



3 Transform

对数据集进行的操作



compose函数会将多个transforms包在一起。

参考:

http://www.jianshu.com/p/8da9b24b2fb6
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: