pytorch 共享参数的示例
2019-08-17 13:37
2835 查看
在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。
例子1:
class ConvNet(nn.Module): def __init__(self): super(ConvNet, self).__init__() self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5)) def forward(self, x): x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1) x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1, groups=1) return x
上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。
例子2:
class LinearNet(nn.Module): def __init__(self): super(LinearNet, self).__init__() self.linear_weight = nn.Parameter(torch.randn(3, 3)) def forward(self, x): x = nn.functional.linear(x, self.linear_weight) x = nn.functional.linear(x, self.linear_weight.t()) return x
这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。
例子3:
class LinearNet2(nn.Module): def __init__(self): super(LinearNet2, self).__init__() self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]])) def forward(self, x): x = x.mm(self.w) x = x.mm(self.w.t()) return x
这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。
以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考
您可能感兴趣的文章:
相关文章推荐
- pytorch 共享参数方法
- pytorch 共享参数方法
- 水晶报表参数编程示例代码
- 安卓存储数据和文件系列2:共享参数(sharedpreferences)方式
- Android 数据存储(二) 共享参数存储
- 共享SQL区、私有SQL区与游标 (提到参数DB_FILE_MULTIBLOCK_READ_COUNT)
- nginx编译参数和示例
- PHP使用Redis实现Session共享的实现示例
- C语言可变参数使用示例
- 进程通信系列(13)共享内存系统调用与代码示例
- SharedPreferences共享参数
- MyBatis多参数传递之注解方式示例--转
- MySql连接数据库常用参数及代码示例
- mybatis 动态传入表名,表名作为参数示例
- java 不定参数使用示例
- Android之数据存储共享参数实现系统设置操作功能(二)
- DeepLearning4j的StackVertex实现参数共享
- 共享维度和多层次的示例
- js动态添加事件并可传参数示例代码
- MySQL备份命令mysqldump参数说明与示例