动态导入模块,加载预训练模型,nn.Sequential函数里面必须是a Module subclass,不能是一个列表或者是其他的迭代器、生成器,虽然这里面包含了Module的子类
2018-10-30 12:29
2276 查看
[code]class RES(nn.Module): def __init__(self): super(RES, self).__init__() self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) self.bn1=nn.BatchNorm2d(64) self.relu=nn.ReLU(inplace=True) self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1) self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False) self.bn2=nn.BatchNorm2d(128) def forward(self,x): x=self.conv1(x) x=self.bn1(x) x=self.relu(x) x=self.maxpool(x) x=self.conv2(x) x=self.bn2(x) return x model=RES() glb = nn.Sequential(*list(model.children())[:4])
有两点数据的说明:这个类继承了Module一定要用super函数
nn.Sequential函数里面的参数一定是Module的子类,而list:list is not a Module subclass。所以不能当做参数,当然model.children()也是一样:Module.children is not a Module subclass。这里的*就起了作用,将list或者children的内容迭代的一个一个的传进去,效果如下:
当然,我们还可以像最上面的那样,选取里面的几个Module,例如[:4]也就是第0个到第3个.
动态导入模块,使用importlib.import_module函数实际上是import了一个叫做resnet的文件,下面的语句相当于 import xxx as resnet
当然这里的xxx是该文件的实际路径
[code]import importlib resnet = importlib.import_module("torchvision.models.resnet") resnet18=resnet.resnet18() resnet34=resnet.resnet34() resnet50=resnet.resnet50() resnet101=resnet.resnet101() resnet152=resnet.resnet152()
其他的模块有:
[code]""" alexnet文件 """ alexnet=importlib.import_module("torchvision.models.alexnet") alexnet=alexnet.alexnet() nn.Sequential(*alexnet.children()) """ vgg文件 """ vgg=importlib.import_module("torchvision.models.vgg") vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他们的bn形式 # vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn() nn.Sequential(*vgg16.children()) """ densenet文件 """ densenet=importlib.import_module("torchvision.models.densenet") densenet121=densenet.densenet121() # densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161() nn.Sequential(*densenet121.children()) """ inception文件 """ inception=importlib.import_module("torchvision.models.inception") inception_v3=inception.inception_v3() nn.Sequential(*inception_v3.children()) """ squeezenet文件 """ squeezenet=importlib.import_module("torchvision.models.squeezenet") squeezenet1_0=inception.squeezenet1_0() # squeezenet1_0=inception.squeezenet1_1() nn.Sequential(*squeezenet1_0.children())
还有一种导入方式,是比较常用的,推荐的:
[code]import torchvision.models as models models.squeezenet1_0() """ models后面直接接的是网络 models的__init__文件如下 """ from .alexnet import * from .resnet import * from .vgg import * from .squeezenet import * from .inception import * from .densenet import * """ 可以看出来,导入的是这5个文件里面的函数(类) *代表想对应文件的__all__,下面是各个文件的该属性以及训练好的权重 """ # alexnet __all__ = ['AlexNet', 'alexnet'] model_urls = { 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', } # resnet __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', } # vgg __all__ = [ 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',] model_urls = { 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', } # squeezenet __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] model_urls = { 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', } # inception __all__ = ['Inception3', 'inception_v3'] model_urls = { # Inception v3 ported from TensorFlow 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', } # densenet __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] model_urls = { 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', }
所有的模型默认都是不加载预训练模型参数的,怎么加载预训练模型参数呢?很简单,就在括号里面的pretrained设置成True,如果仅仅是需要该结构而不需要预训练模型参数作为初始化,那么pretrained=False。
[code]resnet50 = models.resnet50(pretrained=True)
推荐!这里有一篇比较综合https://blog.csdn.net/weixin_41278720/article/details/80759933
其中可以补充一点就是将参数进行下载,相比加载模型来说更加的节省资源
[code] import torch.utils.model_zoo as model_zoo def _load_pretrained_model(self): pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', '/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights') model_dict = {} state_dict = self.state_dict() for k, v in pretrain_dict.items(): if k in state_dict: model_dict[k] = v state_dict.update(model_dict) self.load_state_dict(state_dict)
阅读更多
相关文章推荐
- 【c语言】为下面的函数原型编写函数定义,这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。
- java动态加载指定的类或者jar包反射调用其方法-涉及其他jar中的类就报ClassNotFound问题分析及解决思路
- 【C语言】为下面的函数原型编写函数定义: int ascii_to_integer(char *str); 这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。
- 编写一个函数,从标准输入读取一列整数, 把这些值存储在一个动态分配的数组中并返回这个数组。 函数通过观察EOF判断输入列表是否结束。 数组的第一个数是数组包含的值的个数, 它的后面就是这些整数值。
- C 这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。如果字符串参数包含了任何非数字字符,函数就返回零。
- 为下面的函数原型编写函数定义: int ascii_to_integer(char *str); 这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。如果字符串参数
- 编写一个函数,从标准输入读取一列整数,把这些值存储于一个动态分配的数组中并返回这个数组。函数通过观察EOF判断输入列表是否结束。数组的第一个数是数组包含的值的个数,他的后面就是这些整数值。
- 字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。如果字符串参数包含了任何非数字字符,函数就返回零
- . 有一个一维数组,里面存储整形数据,请写一个函数,将他们按从大到小的顺序排列,要求执行效率高,并说明如何改善执行效率(该函数必须自己实现,不能使用php函数)。
- eclipse导入一个项目后,不能加载到tomcat里面
- 写一个字符串函数,这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。
- int ascii_to_integer(char *str); 这个字符串参数必须包含一个或者多个数字,函数应该把这些数字字符转换为整数并返回这个整数。
- java动态加载指定的类或者jar包反射调用其方法-涉及其他jar中的类就报ClassNotFound问题分析及解决思路
- 创建一个模块calculator.py,完成任意两个数的加(add)、减(sub)、乘(mult)、除(div)运算;导入该模块,分别调用其中的函数,完成如下操作: 1、25+56 2、86-68 3
- 使用git submodule管理一个需要多个分立开发或者第三方repo的项目
- 利用SubclassDlgItem函数动态连接控件和控件对象
- 使用git submodule管理一个需要多个分立开发或者第三方repo的项目
- PHP 5.0不能加载动态模块的解决方法
- IsapiModule或CgiModule必须在模块列表中
- Python 生成器函数,生成器表达式,迭代器,列表解析