您的位置:首页 > 产品设计 > UI/UE

mmdetection源码笔记(二):创建网络模型之registry.py和builder.py解读(上)

2019-08-14 16:52 696 查看
版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。 本文链接:https://blog.csdn.net/qq_41375609/article/details/99549794

引言:

在上篇文章中,讲了train.py训练文件,主要是读取命令行函数和主函数main。main主要先做了一些config,work_dir以及log等操作(这些操作都是从命令行获得的,或者从命令行带有的文件里得到的参数等。)。最主要的三个步骤就是调用build_detector()来创建模型,然后同样调用build_dataset()对数据集创建模型,然后在训练检测器train_detector()。
注:build_dataset()和build_detector()不在同一个builder.py中实现,所以以下的builder.py实现的是build_detector(),是在mmdet/model/下的py文件。
具体详情看

本篇文章主要就是讲一下,搭建模型的思路,以及

registry.py
builder.py
中各个函数块的作用。

注:builder.py是在mmdet/model文件夹下,是用来创建BACKBONES、NECKS、ROI_EXTRACTORS、SHARED_HEADS、HEADS、LOSSES、DETECTORS的模型的。而关于build_dataset()(在mmdet/datasets/builder.py中),在后面讲到数据集的时候再来讲它。

在mmdet/utils文件夹下的

registry.py
为主要的实现过程,后面详细讲解。
先来看在mmdet/models文件夹下的
registry.py
,较简单,代码如下:

# -*- coding: utf-8 -*-
from mmdet.utils import Registry

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
#类的实例化,Registry是一个类,传入的是一个字符串。该字符串为Registry类的name属性值

举个例子:

DETECTORS
为注册表Registry的实例化对象,
DETECTORS .name = 'detector'
,Registry类的定义在mmdet/utils/文件中。

所以,根据上面代码,我们就应该知道了,不止一个名为

DETECTORS
的注册表Registry,后面还会有名为NECKS、ROI_EXTRACTORS 、SHARED_HEADS 、HEADS 、LOSSES 的注册表,这些注册表下的_module_dict属性,则是用来存对应的相同类对象的,举个例子:比如DETECTORS的_module_dict下就有可能有:Faster R-CNN、Cascade R-CNN、FPN、HTC等常见的检测器。到这或许你就明白了注册表的作用咯

而在mmdet/utils/Registry.py中,有一个类

Registry
的定义和一个方法:
build_from_cfg()
的实现。
build_from_cfg()
方法的作用是从 congfig/py配置文件中获取字典数据,创建module(其实也就是一个class类),然后将这个module添加到之前创建的注册表Registry的属性_module_dict中(这是一个字典,key为类名,value为具体的类),返回值是一个实例化后的类对象。

所以,可以这样理解,从config/py配置文件中,将字典提取出来,然后为其映射成一个类,放进Registry对象的_module_dict属性中。(具体看下面的代码)

Registry.py文件

以下代码分三部分

Part one:

inspect模块是针对模块,类,方法,功能等对象提供些有用的方法。例如可以帮助我们检查类的内容,检查方法的代码,提取和格式化方法的参数等。

# -*- coding: utf-8 -*-

3ff7
import inspect
import mmcv
Part two:

通过前面第一段的代码段,我们知道

DETECTORS = Registry('detector')

detector是干什么的 ???
其实,
DETECTORS = Registry('detector')
只是注册了一个对象名为DETECTORS ,属性name为detector的对象。然后用属性_module_dict 来保存config配置文件中的对应的字典数据所对应的class类(看第三部分代码)。请看如下类Registry的定义代码:

class Registry(object):

def __init__(self, name):		#此处的self,是个对象(Object),是当前类的实例,name即为传进来的'detector'值
self._name = name
self._module_dict = dict()  #定义的属性,是一个字典

def __repr__(self):
#返回一个可以用来表示对象的可打印字符串,可以理解为java中的toString()。
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str

@property						#把方法变成属性,通过self.name 就能获得name的值。
def name(self):
return self._name
#因为没有定义它的setter方法,所以是个只读属性,不能通过 self.name = newname进行修改。
@property
def module_dict(self):
#同上,通过self.module_dict可以获取属性_module_dict,也是只读的
return self._module_dict

def get(self, key):
#普通方法,获取字典中指定key的value,_module_dict是一个字典,然后就可以通过self.get(key),获取value值
return self._module_dict.get(key, None)

def _register_module(self, module_class):
#关键的一个方法,作用就是Register a module.
#在model文件夹下的py文件中,里面的class定义上面都会出现 @DETECTORS.register_module,意思就是将类当做形参,
#将类送入了方法register_module()中执行。@的具体用法看后面解释。
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):		  #判断是否为类,是类的话,就为True,跳过判断
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__ 		  #获取类名
if module_name in self._module_dict:          #看该类是否已经登记在属性_module_dict中
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class #在module中dict新增key和value。key为类名,value为类对象

def register_module(self, cls): 				  #对上面的方法,修改了名字,添加了返回值,即返回类本身
self._register_module(cls)
return cls
note:

@的含义:
Python当解释器读到@的这样的修饰符之后,会先解析@后的内容,直接就把@下一行的函数或者类作为@后边的函数的参数,然后将返回值赋值给下一行修饰的函数对象。
在网上看到一个这样的例子:

def a(x):
if x==2:
return 4
return 6
def b(x):
if x==1:
return 2
return 3
@a
@b
def c():
return 1

python会按照自下而上的顺序把各自的函数结果作为下一个函数(上面的函数)的形参输入,也就是a(b(c()))。

Part three:

以下我们通过配置文件

cascade_rcnn_r50_fpn_1x.py
进行讲解 build 模型的过程。
在train中,最先执行Registry的是DETECTORS,传入的参数是配置文件中的model字典。

#在 train.py中
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
#在builder.py中
def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

所以,后面出现的参数cfg,指的就是配置文件中的model字典。下面是model字典的部分截图。

我们继续往下看
先看build_from_cfg()方法的参数:
Args:

  • cfg (dict): Config dict. It should at least contain the key “type”.
    这个cfg就是py配置文件中的字典。在py配置文件中,基本上dict都会有一个key为"type",当然也有不是的,不是的,这一步就不会执行,也就不会为他创建module。也就是这边创建成module的dict,都必须有key为"type"才可以创建(这里,我们主要讲的是注册表DETECTORS,所以此时cfg对应的是配置文件中的model字典,看上面截图)。
    举个例子:比如

    type='CascadeRCNN'
    ,后面我们会知道,这个value为
    "CascadeRCNN"
    的,其实就是models文件夹中某py文件中的类名,他们通过@DETECTORS.register_module,将类名当做形参,传入register_module。并保存下来。

  • registry (:obj:

    Registry
    ): The registry to search the type from.

  • default_args (dict, optional): Default initialization arguments.

def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None #两个是断言,相当于判断,否的话抛出异常。

args = cfg.copy()						#args相当于temp中间变量,是个字典。
obj_type = args.pop('type') 			#字典的pop作用:移除序列中key为‘type’的元素,并且返回该元素的值
if mmcv.is_str(obj_type):

obj_type = registry.get(obj_type
4000
)	#获取obj_type的value。
#如果obj_type已经注册到注册表registry中,即在属性_module_dict中,则obj_type 不为None
if obj_type is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif not inspect.isclass(obj_type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():#items()返回字典的键值对用于遍历
args.setdefault(name, value)
#将default_args的键值对加入到args中,将模型和训练配置进行整合,然后送入类中返回

return obj_type(**args)

obj_type(**args)
,* *args是将字典unpack得到各个元素,分别与形参匹配送入函数中;看上面model的截图,所以这边,其实就是将除了’type’的所有字段,当做形参,送入了名为CascadeRCNN()的类中(type = ’CascadeRCNN‘)。所以字典里的key就是类中的属性?继续看下面。

根据Cascade R-CNN的例子,我们在models/detectors找cascade_rcnn的py文件
参考里面的参数时,直接打开对应的cascade_rcnn配置文件,在init中,里面的参数则
对应了配置文件中的字典名。下面两个截图分别是配置文件cascade_rcnn.py和model/detectors/cascade_rcnn.py中的类定义。


注意的是,在py配置文件中,好多py文件中都有type = ‘CascadeRCNN’,所以有些参数和属性对不上很正常(毕竟已经设置为None了),因为这个参数可能是其他的cascade R-CNN里面的字典。
所以,我们在训练时,测试时,就要给出配置文件,配置文件可以不同,但相同type的
detector等文件是相同的,毕竟已经将数据和实现完全的分离了。

注意:无论训练/检测,都会build DETECTORS;

builder.py文件

builder文件较为简单,因为train.py中,只出现了

build_detector()
,所以我们先记住里面的两个方法:
build_detector
build()

  • build_detector
    :是创建一个detector,方法里调用了build()方法(所有的build_xx都是直接调用build方法,所以看懂这一个也就看懂所有了)。
  • build()
    :则是调用的Registry.py文件中的
    build_from_cfg()
    方法,这个方法我们已经在上面讲过了。

import:

# -*- coding: utf-8 -*-
from torch import nn
from mmdet.utils import build_from_cfg
#此处不会在执行registry而是直接进行sys.modules查询得到
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
LOSSES, DETECTORS)
#上面的registry是在models文件夹下,registry类的具体实现是在mmdet/utils文件夹下

只需要看一下

build()
的两个参数:
cfg, registry

build_detector()
在train.py中的调用,我们就可以知道,
cfg
是py配置文件中的字典, 以
registry
DETECTORS
为例,cfg就是model字典 (后面注册表为BACKBONES、NECKS等时,就是配置文件中的其他的字典了,不是model) 。

build()方法中,主干是一个判断结构,其实就是判断传进来的cfg是字典列表还是单独的字典,来分情况处理。(以注册表

DETECTORS
为例,是一个单独的字典)

  • 字典列表的话:挨个调用
    build_from_cfg()
    ,将其加到注册表
    ******
    _module_dict
    中,然后再返回
    return nn.Sequential(*modules)
    ,这个地方的作用,有待博主继续研究一下下???
  • 字典的话:直接调用
    build_from_cfg()
    ,将其添加到注册表
    DETECTORS
    中(以DETECTORS为例)。
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
#build_from_cfg()返回值是一个带形参的类,返回时也就完成了实例化的过程。
]
#所以modules就是一个class类的列表
return nn.Sequential(*modules)
#nn.Sequential 一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数

else:
return build_from_cfg(cfg, registry, default_args) #Config dict

def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
#DETECTORS = Registry('detector'),创建一个名为DETECTORS的注册表Registry。

def build_backbone(cfg):
return build(cfg, BACKBONES)
def build_neck(cfg):
return build(cfg, NECKS)
def build_roi_extractor(cfg):
return build(cfg, ROI_EXTRACTORS)
def build_shared_head(cfg):
return build(cfg, SHARED_HEADS)
def build_head(cfg):
return build(cfg, HEADS)
def build_loss(cfg):
return build(cfg, LOSSES)

后面的几个

build_XXXXX()
的方法也就跟build_detector()相同咯。

还是以注册表

DETECTORS
为例,配置文件为
cascade_rcnn_r50_fpn_1x.py
来讲解:在model文件夹下的cascade_rcnn.py文件中,有类Cascade_RCNN()的定义,在配置文件中,对应的key被传入类中当做属性,这些属性被初始化的时候,调用对应的
build_XXXXX()
,由此创建它们对应的注册表。
再以
NECK
为例
,调用
build_neck(cfg)
;然后执行
build(cfg, NECKS)
,这一步,形参用到
NECKS
,所以在Registry中,又多了一个名为NECKS的注册表了。然后将配置文件中,字典名为
neck
的,然后生成一个类(类名是neck字典中的type的值,该类在model/necks文件夹下),同时将该类添加到了注册表NECKS的_module_dict中。

#在model/detectors/cascade_rcnn.py中
if neck is not None:
self.neck = builder.build_neck(neck)
#再builder.py中
def build_neck(cfg):
return build(cfg, NECKS)
#在configs/cascade_rcnn_r50_fpn_1x.py中
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
4000
,

到这,NECK的注册和数据读入,相信大家已经很清楚了,其他的注册表也是类似的。

总结:

搭建模型思路:

  • 首先,创建一个名为
    DETECTORS
    的注册表
    Registry
    。这个注册表有属性
    name='detector'
    ,和属性
    _module_dict
    。_module_dict 是一个字典,专门用来存各个对象名和对应的对象。
  • 其次,读取py配置文件,py配置文件是个字典,(字典里还有字典,这里面的字典,也是后面来创建模型的,道理是一样的)。根据key为’type’的字典,创建module,对于的value为其module名,然后再model文件夹下中,已经存在了这些module的类。将字典中的其他数据,作为形参,实例化这些类。并保存这些module到属性_module_dict中。
  • 到这,配置文件的数据,里面的字典(含有type的字典)对应着一个类,type为类名,其他字段则为其属性(其他字段也可能是个字典,后面也有可能要再为它们搭建模型哦)。由此完成模型的搭建。

这是搭建模型的一个思路,虽然讲得篇幅很大,有点乱乱的感觉,但是看懂后,就会发现很简单。

mmdetection搭建模型用途:
mmdetection将配置文件中,字典名为:backbone、neck、roi_extractor、shared_head、head、loss、detector的字典,全部实例化成注册表(Registry),然后这些字典里的type,都被实例化成对应的类(module),并添加到注册表的属性_module_dict中,其他的字段,则为这个类的属性,由此完成模型的建立,实际上,就是将配置文件的字典数据保存到类(module)中,以便后面读取数据,加载数据。

接下来,请看博主的下面的文章:

  • mmdetection源码笔记(二):创建网络模型之cascade_rcnn.py的解读(中)
  • mmdetection源码笔记(二):cascade_rcnn.py搭建模型过程中各个module的forward()的代码解读(下)(待完成)
  • mmdetection源码笔记(三):创建数据集模型之datasets/builder.py的解读(待完成)
  • mmdetection源码笔记(四):训练模型之train_detector()的解读(待完成)
  • mmdetection源码笔记(五):测试之test()部分的解读(待完成)

注:因为有好几个py文件,博主也是按照自己的理解,尽量讲得通俗易懂,如果有不理解的地方,底下评论。如果有错误的地方,还请指出学习,感激不尽。

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: