CenterNet代码分析

CenterNet源码分析

侯德柱
CenterNet源代码

概述

  1. src

    1. src/main.py主函数

    2. src/demo.py展示测试结果图片

    3. src/test.py测试数据

  2. src/lib

    1. src/lib/opts.py解析参数
    2. src/lib/llogger.py
    3. src/lib/datasets数据集文件夹
      1. src/lib/datasets/dataset实现加载哪个数据集,其每一个子类为数据集名字
      2. src/lib/datasets/sample如何对数据采样,针对不同的网络返回什么信息,每一个子类为对应的网络名字
    4. src/lib/detectors检测方法集文件夹,包括如何训练,如何检测参数
    5. src/lib/external实现nms?
    6. src/lib/models基本网络结构,如dla、resnet、hourglass等
    7. src/lib/trains实现trainer,网络结构构建
  3. src/tools数据集预处理 ## 命令行参数

    1
    opt.num_stacks = 2 if opt.arch == 'hourglass' else 1
    ## 简称

  • res : resolution
  • crop裁剪
  • reg_loss : regression loss

自定义数据集

  1. 复制一个新建一个子类,继承data.Dataset
  2. 计算所有的图片的均值和标准差,并修改class.mean、class.std
  3. class.num_class改为自己数据集的类别数
  4. default_resolution为默认的分辨率
  5. 修改读取json文件的路径
  6. 修改self.class_name和self._valid_ids
  7. 将数据集加入src/lib/datasets/dataset_factory
  8. 在/src/lib/opts.py文件中修改self.parser.add_argument('--dataset'
  9. 修改src/lib/utils/debugger.py
    1
    2
    3
    4
    5
    #(1)第45行下方加入两行:
    elif num_classes == 6 or dataset == 'food':
    self.names = food_class_name
    # (2)第460行下方加入自己所定义的类别,不包含背景:
    food_class_name = ['aaa', 'bbb', 'ccc', 'ddd', 'eee', 'fff']

自定义网络

src/lib/dataset/sample

重写__getItem__

自定义基础网络结构src/lib/model

继承nn.Module,实现forward。不包括loss

src/lib/models

model.py

  • create_model

    1
    def create_model(arch: str, heads: Dict[str, int], head_conv: int) -> nn.Module:

    所有基本网络结构的get_model参数与本函数一致

    • 参数

      • arch网络架构 命令行参数--arch确定

      • heads opt.heads确定,key为特征名字,value为值

        1
        2
        3
        4
        5
        6
        elif opt.task == 'ctdet':
        # assert opt.dataset in ['pascal', 'coco']
        opt.heads = {'hm': opt.num_classes,
        'wh': 2 if not opt.cat_spec_wh else 2 * opt.num_classes}
        if opt.reg_offset:
        opt.heads.update({'reg': 2})
      • head_conv 命令行参数--head_conv 确定 output head的卷积层通道数 0 没有卷积层,-1默认:resnet64,dla256

    • 返回值

      • 基本网络架构

networks

forward返回值

resnet_dcn.py

PoseResNet
1
class PoseResNet(nn.Module):
  • __init__

    1
    def __init__(self, block: Type[BasicBlock], layers: List[int, int, int, int], heads: Dict[str, int], head_conv: int):
    • 成员变量

      • self.head是字典,key为分支名称(hm,wh,reg),value为最后输出的通道数=数据集类别数

      • 构造函数最后根据head的key,用__setattr__把key设成属性名,成员变量的值是卷积层的model,输入256通道,输出数据集类别数通道

  • forward

    1
    def forward(self, x: torch.Tensor) -> List[Dict[str, torch.Tensor]]
    • 返回值中每一个元素是head不同key的全连接层对应前向传播的结果,list的长度为1
      1. hm:热力图size=(1,2,128,128)为 (batch size ,num of classes, H/R,W/R)
      2. wh:宽高size=(1,2,128,128)为 (batch size ,2, H/R,W/R)
      3. reg:回归误差size=(1,2,128,128)为 (batch size ,2, H/R,W/R)

losses

RegL1Loss

1
class RegL1Loss(nn.Module):
  • __init__
  • forward
    • 中间变量
      • pred size=(1,128,2)
      • mask size=(1,128,2)

src/lib/trains

ModelWithLoss

1
2
3
4
5
6
7
8
9
10
class ModelWithLoss(torch.nn.Module):
def __init__(self, model: torch.nn.Module, loss: torch.nn.Module):
super(ModelWithLoss, self).__init__()
self.model = model
self.loss = loss

def forward(self, batch: dict) -> (Dict[str, torch.Tensor], torch.Tensor, Dict[str, torch.Tensor]):
outputs = self.model(batch['input'])
loss, loss_stats = self.loss(outputs, batch)
return outputs[-1], loss, loss_stats

outputs为network的输出,List[Dict[str, torch.Tensor]]

model为网络前向传播的输出,不包括loss

返回值

  • loss Tensor标量
  • loss_stats

BaseTrainer

  • __init__

    1
    def __init__(self, opt: Namespace, model: torch.nn.Module, optimizer: Optional[torch.optim.Adam] = None):
    • opt为命令行参数
    • model为src/lib/model中定义的model
    • optimizer为优化器,main.py中用的adam
    • self.loss_stats, self.loss = self._get_losses(opt)
    • self.model_with_loss = ModelWithLoss
  • run_epoch

    1
    def run_epoch(self, phase: str, epoch: int, data_loader: torch.utils.data.DataLoader)
    • 主要功能:
      1. 对model_with_loss正向传播和反向传播,使用optimizer优化
      2. 计时
      3. 每5轮验证一次
    • 参数
      • phase为'train'或'val'
    • 中间变量
      • batch是dataloader返回的字典
    • 需要用到的成员函数
      • opt
        • task
        • exp_id
      • model_with_loss
  • debug

    1
    def debug(self, batch: dict, output: List[Dict[str, torch.Tensor]], iter_id: int):

    需自己实现

  • save_result

    1
    def save_result(self, output: Dict[str, torch.Tensor], batch: dict, results: dict):

    需自己实现

    run_epoch中results传入的是空,results作为返回值

  • _get_loss

    1
    def _get_losses(self, opt: Namespace) -> (List[str], torch.nn.Module):

    需自己实现

    • 返回值
      • List[str]是损失分量的名字
  • train & val

    分别调用run_epoch("train") run_epoch("val")

src/lib/detectors

该类为检测类,使用训练好的模型检测

  • pre_process

    1
    pre_process(self, image: np.ndarray, scale: float, meta: Optional[list] = None) -> (np.ndarray, dict)
    预处理,默认是实现图片放缩,meta在除ddd之外的检测器没用

  • process

    1
    process(self, images: torch.Tensor, return_time=False)-> Tuple[dict, torch.Tensor, float]:
    必须自己实现

    第一个dict为检测结果output,第二个Tensor为检测结果的bbox和置信度等数据,需要自己实现[网络名]_decode函数在src/lib/models/decode.py中,最后一个参数为检测开始时间,如果return_time=False则为0

  • post_process

    1
    post_process(self, dets: torch.Tensor, meta: dict, scale: float = 1) -> Dict[int, Iterable]
    必须自己实现

    反缩放,dets经[网络名]_post_process函数处理后返回检测框的信息(bbox,概率,类别)?

  • merge_outputs

    1
    merge_outputs(self, detections: List[Dict[int, Iterable]]) -> Dict[int,np.ndarray]
    必须自己实现

  • debug

    1
    debug(self, debugger: Debugger, images: np.ndarray, dets: torch.Tensor, output, scale: float = 1)
    dets为detections的简写

  • show_results

    1
    show_results(self, debugger: Debugger, image: np.ndarray, results: Dict[int, np.ndarray])
    显示测试结果的图片

  • 以上这些函数会在基类中run函数调用

    1
    run(self, image_or_path_or_tensor: Union[str, np.ndarray, Dict[str, List[torch.Tensor]]], meta:Optional[list]=None):

src/lib/utils

src/lib/utils/debugger.py

用于展示图片,

src/lib/utils/post_process.py

自定义检测完处理函数,整理检测结果

  • ctdet_post_process

    1
    ctdet_post_process(dets: np.ndarray, c: Union[list, np.ndarray], s: Union[list, np.ndarray],h: int, w: int, num_classes: int) -> List[Dict[int, Iterable]]:

dets: batch x max_dets x dim 批大小×最多检测出多少个×维度。在ctdet中,dim的大小为6,前4个参数为x1,y1,x2,y2;第5个参数为概率;第6个参数为类别

返回中的dict的key为类别,Iterable为np.nparray

src/lib/utils/image.py

图像预处理

  • transform_preds

    1
    2
    transform_preds(coords: np.ndarray, center: np.ndarray,
    scale: Union[np.ndarray, list, int], output_size: Sequence)

    对点进行仿射变换

    coords为坐标点的矩阵[[x0,y0],[x1,y1],...],center中心,scale放大倍数,rot旋转,output_size输出尺寸为长度2的向量

  • get_affine_transform

    1
    2
    3
    4
    5
    6
    get_affine_transform(center: np.ndarray,
    scale: Union[np.ndarray, list, int],
    rot: float,
    output_size: Sequence,
    shift=np.array([0, 0], dtype=np.float32),
    inv=0) -> np.ndarray:

    得到仿射变换矩阵,center中心,scale放大倍数,rot旋转,output_size输出尺寸为长度2的向量,

  • affine_transform

    1
    affine_transform(pt: np.ndarray, t: np.ndarray)

    pt长度为2的向量,t仿射变换矩阵

  • draw_msra_gaussian draw_umich_gaussian

    1
    draw_umich_gaussian(heatmap, center, radius, k=1)

    TODO:

  • color_aug

    1
    color_aug(data_rng:int, image:np.ndarray, eig_val:np.ndarray, eig_vec:np.ndarray) -> None

    图片增强,直接在原图像上修改

    • data_rng为随机数种子,默认np.random.RandomState(123)
    • image归一化的图片
    • eig_val default = np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32)
    • eig_vec default = np.array([ [-0.58752847, -0.69563484, 0.41340352], [-0.5832747, 0.00994535, -0.81221408], [-0.56089297, 0.71832671, 0.41158938] ], dtype=np.float32)
  • draw_dense_reg

    1
    draw_dense_reg(regmap, heatmap, center, value, radius, is_offset=False)

COCOAPI

  • COCO.loadImgs

    1
    loadImgs(self, ids:Union[int,Sequence]=[])->List[dict]

    返回图片文件名等信息

    e.g.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    [
    {
    "license": 1,
    "file_name": "000000324158.jpg",
    "coco_url": "http://images.cocodataset.org/val2017/000000324158.jpg",
    "height": 334,
    "width": 500,
    "date_captured": "2013-11-19 23:54:06",
    "flickr_url": "http://farm1.staticflickr.com/169/417836491_5bf8762150_z.jpg",
    "id": 324158
    },
    {
    ...
    }
    ]
  • COCO.getAnnIds

    1
    getAnnIds(self, imgIds:List[int]=[], catIds:List[int]=[], areaRng:List[float]=[], iscrowd:bool=None)->List[int]

    e.g.

    1
    [10673, 345846, 349287, 351168, 353731, 359808, 638724, 1192678, 1194317, 1341598, 1346450, 1348862, 2038341, 2044744, 2162813]
  • COCO.loadAnns

    1
    loadAnns(self, ids=[])->List[dict]

    加载标记,

    e.g.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    [
    {
    'segmentation': [[
    44.2, 167.74, 48.39, 162.71, 66.87, 161.19, 72.75, 162.71, 80.98, 164.22, 80.14, 168.92, 78.12, 168.58, 72.24, 170.94, 71.74, 173.79, 54.27, 174.97, 52.59, 170.1, 50.75, 167.91, 45.71, 168.92, 45.2, 167.58
    ]],
    'area': 331.9790999999998,
    'iscrowd': 0,
    'image_id': 324158,
    'bbox': [44.2, 161.19, 36.78, 13.78],
    'category_id': 3,
    'id': 345846
    },
    {
    ...
    }
    ]

ctdet

sample/ctdet.py

CTDetDataset.__getitem__

设原图片维度为W*H*3

  • CTDetDataset.__getitem__返回值

    keyvalue备注
    'input'np.ndarray输入图片shape=[3x512x512]
    'hm'np.ndarrayheatmap size (num_classes, output_h, output_w), (80x128x128)(dtype=np.float32 正态分布,中心为bbox中心,w和h为bbox的w和h
    'reg_mask'np.ndarraydtype=uint8 shape=[128] [1 1 1 1 1 1 0 0 ...]
    'ind'dtype=int64 shape=[128] mask为1的地方有值
    'wh'dtyte=float32 shape=[128x2] mask为1的地方有值
    'dense_wh'
    'dense_wh_mask'
    'cat_spec_wh'
    'cat_spec_mask'
    'reg'dtyte=float32 shape=[128x2] mask为1的地方有值,其他为0
    'meta'dict见下方,非训练或debug时才有
    • meta

      keyvalue type备注
      'c'np.ndarray仿射变换的center,(x,y)长度为2的向量,dtype=np.float32,默认为图片中心点
      's'float仿射变换的scale
      'gt_det'
      'img_id'
  • 变量

    变量名type备注
    inpnp.ndarrayinput输入图片,80行transpose(2,0,1)将其变为3*W*H
    num_objsint图像中标记的矩形框数量和self.max_objs最小值
    opt.no_color_augbooldefault=False 进行色彩增强,默认增强
    opt.keep_resbooldefault=False保持分辨率
    opt.flipfloat图片翻转概率
    opt.down_ratiointdefault=4 下采样率
    opt.dense_whbooldefault=False 中心点附件加权回归或中心点回归
    • input image(inp)

      • 39行cv2.imread
      • 68行opt.flip按概率翻转
      • 73行仿射变换
      • 76行归一化(÷255)
      • 79行变正态分布
      • 80行transpose(2,0,1)将其维度变为3*W*H
      • 最后作为res['inp']返回
    • bbox为目标矩形框

      • 102行转为x1,y1,x2,y2
      • 105行根据图像翻转
      • 106行仿射变换
      • 108行裁剪使其不超过图片的边界
      • bbox用来生成heatmap,返回值中没有bbox
    • heatmap

      hm维度为num_classes×output_h×output_w

      对于coco,num_classes=80。output_h×output_w 默认为输入图片的高宽分别除以4

      dtype=np.float32

      如果bbox中没有对应的类别,则该类的heatmap为0

      由image.py中gaussian_radius和

其他函数

  • CTDetDataset._coco_box_to_bbox

    1
    _coco_box_to_bbox(self, box:Sequence)->np.ndarray

    将(x,y.w,h)变为(x1,y1,x2,y2),左上和右下的坐标

trains/ctdet.py

CtdetLoss

1
class CtdetLoss(torch.nn.Module):
  • __init__

    1
    def __init__(self, opt):
    • 成员变量
      • self.crit是一个heatamp损失函数,返回值为torch.Tensor标量
      • self.crit_hw是长宽损失函数
      • self.crit_reg是修正误差损失函数
  • forward

    1
    def forward(self, batch: dict) -> (object, torch.Tensor, Dict[str, torch.Tensor]):
    • 参数
      • batch是dataloader __getitem__的返回值
    • 返回值
      • outputs是model/network中前向传播的返回值
      • loss是标量torch.Tensor
      • loss_stats是字典,损失分量的名字为key和其数值的Tensor为value

CtdetTrainer:

1
class CtdetTrainer(BaseTrainer):

test.py

输出

Bar.suffix

[当前轮数/总轮数]|Tot:总时间|ETA:剩余时间|tot总时间|load载入时间|pre预处理时间|net|dec|post|merge|

0%