您的位置:首页 > 产品设计 > UI/UE

动态导入模块,加载预训练模型,nn.Sequential函数里面必须是a Module subclass,不能是一个列表或者是其他的迭代器、生成器,虽然这里面包含了Module的子类

2018-10-30 12:29 2276 查看

 

[code]class RES(nn.Module):
def __init__(self):
super(RES, self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
self.bn1=nn.BatchNorm2d(64)
self.relu=nn.ReLU(inplace=True)
self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False)
self.bn2=nn.BatchNorm2d(128)
def forward(self,x):
x=self.conv1(x)
x=self.bn1(x)
x=self.relu(x)
x=self.maxpool(x)
x=self.conv2(x)
x=self.bn2(x)
return x

model=RES()
glb = nn.Sequential(*list(model.children())[:4])

有两点数据的说明:这个类继承了Module一定要用super函数

nn.Sequential函数里面的参数一定是Module的子类,而list:list is not a Module subclass。所以不能当做参数,当然model.children()也是一样:Module.children is not a Module subclass。这里的*就起了作用,将list或者children的内容迭代的一个一个的传进去,效果如下:

 当然,我们还可以像最上面的那样,选取里面的几个Module,例如[:4]也就是第0个到第3个.

 

动态导入模块,使用importlib.import_module函数实际上是import了一个叫做resnet的文件,下面的语句相当于 import xxx as resnet

当然这里的xxx是该文件的实际路径

[code]import importlib
resnet = importlib.import_module("torchvision.models.resnet")
resnet18=resnet.resnet18()
resnet34=resnet.resnet34()
resnet50=resnet.resnet50()
resnet101=resnet.resnet101()
resnet152=resnet.resnet152()

其他的模块有:

[code]"""
alexnet文件
"""
alexnet=importlib.import_module("torchvision.models.alexnet")
alexnet=alexnet.alexnet()
nn.Sequential(*alexnet.children())

"""
vgg文件
"""
vgg=importlib.import_module("torchvision.models.vgg")
vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他们的bn形式
# vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn()
nn.Sequential(*vgg16.children())

"""
densenet文件
"""
densenet=importlib.import_module("torchvision.models.densenet")
densenet121=densenet.densenet121()
# densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161()
nn.Sequential(*densenet121.children())

"""
inception文件
"""
inception=importlib.import_module("torchvision.models.inception")
inception_v3=inception.inception_v3()
nn.Sequential(*inception_v3.children())

"""
squeezenet文件
"""
squeezenet=importlib.import_module("torchvision.models.squeezenet")
squeezenet1_0=inception.squeezenet1_0()
# squeezenet1_0=inception.squeezenet1_1()
nn.Sequential(*squeezenet1_0.children())

还有一种导入方式,是比较常用的,推荐的:

[code]import torchvision.models as models
models.squeezenet1_0()

"""
models后面直接接的是网络
models的__init__文件如下
"""
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
"""
可以看出来,导入的是这5个文件里面的函数(类)
*代表想对应文件的__all__,下面是各个文件的该属性以及训练好的权重
"""
# alexnet
__all__ = ['AlexNet', 'alexnet']
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# resnet
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# vgg
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',]
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
# squeezenet
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}
# inception
__all__ = ['Inception3', 'inception_v3']
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# densenet
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}

所有的模型默认都是不加载预训练模型参数的,怎么加载预训练模型参数呢?很简单,就在括号里面的pretrained设置成True,如果仅仅是需要该结构而不需要预训练模型参数作为初始化,那么pretrained=False。

[code]resnet50 = models.resnet50(pretrained=True)

推荐!这里有一篇比较综合https://blog.csdn.net/weixin_41278720/article/details/80759933

其中可以补充一点就是将参数进行下载,相比加载模型来说更加的节省资源

[code]    import torch.utils.model_zoo as model_zoo

def _load_pretrained_model(self):
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)

 

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐