fairmot代码解析(自己在看代码做笔记用)

本文详细解析了FairMOT框架中的关键代码部分,包括视频加载、模型推理(网络前向传播、目标检测与关联)、模型定义(如DLASeg网络结构)、非极大值抑制算法、以及目标跟踪器的管理机制。此外,还涉及模型训练过程中的heads设置和模型加载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.加载视频

dataloader 是在jde.py函数中定义,用于加载视频,以及进行图像增强。

import datasets.dataset.jde as datasets
dataloader = datasets.LoadVideo(opt.input_video, opt.img_size)

其中LoadVideo如下

class LoadVideo:  # for inference
    def __init__(self, path, img_size=(1088, 608)):
        self.cap = cv2.VideoCapture(path)
        self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
        self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.vh = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        self.width = img_size[0]
        self.height = img_size[1]
        self.count = 0

        self.w, self.h = 1920, 1080
        print('Lenth of the video: {:d} frames'.format(self.vn))

    def get_size(self, vw, vh, dw, dh):
        wa, ha = float(dw) / vw, float(dh) / vh
        a = min(wa, ha)
        return int(vw * a), int(vh * a)

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if self.count == len(self):
            raise StopIteration
        # Read image
        res, img0 = self.cap.read()  # BGR
        assert img0 is not None, 'Failed to load frame {:d}'.format(self.count)
        img0 = cv2.resize(img0, (self.w, self.h))

        # Padded resize
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)

        # Normalize RGB
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0

        # cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return self.count, img, img0

    def __len__(self):
        return self.vn  # number of files


链接
长方形
圆角长方形
菱形
E

2.模型推理

track.py里定义了eval_seq()函数为整个跟踪过程
在multitracker.py函数里定义了JDETracker类,其中包含了update()函数

track,py
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30, use_cuda=True):
  			  tracker = JDETracker(opt, frame_rate=frame_rate)
multitracker.py
class JDETracker(object):
	      ###
	      def update(self, im_blob, img0):

所以跟踪的核心在于update函数:
下面介绍update函数的4个步骤:

2.1 步骤一 :Network forward, get detections & embeddings

2.1.1模型的定义

模型的定义如下

class JDETracker(object):
        self.model = create_model(opt.arch, opt.heads, opt.head_conv)
        self.model = load_model(self.model, opt.load_model)
        self.model = self.model.to(opt.device)
        self.model.eval()

其中create_model和 load_mode都定义在src/lib/models/model.py中

    self.parser.add_argument('--load_model', default='/home/dazhao/code/mot/fairmot/FairMOT-master/models/fairmot_dla34.pth',
                             help='path to pretrained model')
 def update(self, im_blob, img0):
 				 with torch.no_grad():
				            output = self.model(im_blob)[-1]
				            hm = output['hm'].sigmoid_()
				            wh = output['wh']
				            id_feature = output['id']
				            id_feature = F.normalize(id_feature, dim=1)
				
				            reg = output['reg'] if self.opt.reg_offset else None
				            dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
				            id_feature = _tranpose_and_gather_feat(id_feature, inds)
				            id_feature = id_feature.squeeze(0)
				            id_feature = id_feature.cpu().numpy()
				
				        dets = self.post_process(dets, meta)
				        dets = self.merge_outputs([dets])[1]
				
				        remain_inds = dets[:, 4] > self.opt.conf_thres
				        dets = dets[remain_inds]
				        id_feature = id_feature[remain_inds]

其中 output = self.model(im_blob)[-1]是调用网络输出detection和Re-id结果:

在这里插入图片描述
具体包括中心点heatmap图分支、中心点offset分支、目标大小分支。
hm:heatmap图分支包含C个通道,每一个通道包含一个类别,同一个类别的目标在一个heatmap图上,官方代码是只检测人,所以c=1。
reg:判断是前景还是后景。
wh:目标大小分支用来预测目标矩形框的w与h偏差值。
id:则是reid形成的特征,厚度为128

2.1.1 极大值抑制以及最大筛选等手段

dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)

mot_decode函数在src/lib/models/decode.py中
其中decode.py中定义了_nms(),_topk_channel(),_topk(),mot_decode()三个函数:

(1)_nms(heat, kernel=3): 这个函数实现了非极大值抑制(Non-Maximum Suppression,NMS)操作,用于在热图上执行NMS,以去除冗余的检测框。
(2)_topk_channel(scores, K=40): 这个函数用于从得分图中获取前K个最高分数的像素位置,并返回它们的坐标。
(3)_topk(scores, K=40): 这个函数类似于 _topk_channel,但是除了返回坐标外,还返回了检测框的得分和类别。
(4)mot_decode(heat, wh, reg=None, ltrb=False, K=100): 这个函数是整个目标检测的后处理函数,它调用了上述的 _nms、_topk_channel、_topk 函数,从模型输出的热图、宽度和高度预测中解码出目标框的位置和相关信息,并返回检测结果。

2.1.2 多目标跟踪器中用于管理已跟踪的目标轨迹的部分。

unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

具体来说:

unconfirmed: 是一个列表,用于存储尚未激活(未确认)的目标轨迹。这些轨迹可能是在前几帧出现但尚未被确认的目标,或者是在丢失状态后重新激活的目标。

tracked_stracks: 是一个列表,用于存储已经被激活的目标轨迹。这些轨迹是在当前帧中被成功跟踪的目标,具有已知的标识符和状态。

在这段代码中,遍历了多目标跟踪器中已跟踪的目标轨迹列表 self.tracked_stracks。对于每个轨迹,如果它尚未被激活(即未确认),则将其添加到 unconfirmed 列表中;如果已经被激活,则将其添加到 tracked_stracks 列表中。

2.1这部分输出的是检测结果以及生命周期 detections = [] unconfirmed=[] tracked_stracks=[] ,其中detections不仅包含了目标框,还包含了其外观特征

2.2 步骤2:目标关联

3.训练

3.1输出的heads确定

    opt = opts().update_dataset_info_and_set_heads(opt, dataset)

输出:opt.heads {‘hm’: 1, ‘wh’: 4, ‘id’: 128, ‘reg’: 2}

3.2加载模型

 model = create_model(opt.arch, opt.heads, opt.head_conv) 

models文件夹model.py里定义了create_model()函数

from .networks.pose_dla_dcn import get_pose_net as get_dla_dcn

_model_factory = {
  'dlav0': get_dlav0, # default DLAup
  'dla': get_dla_dcn,
  'dlaconv': get_dla_conv,
  'resdcn': get_pose_net_dcn,
  'resfpndcn': get_pose_net_fpn_dcn,
  'hrnet': get_pose_net_hrnet,
  'yolo': get_pose_net_yolo
}

def create_model(arch, heads, head_conv):
  num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
  arch = arch[:arch.find('_')] if '_' in arch else arch
  get_model = _model_factory[arch]
  model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
  return model

其中get_pose_net()在networks文件夹的pose_dla_dcn.py文件夹里

pose_dla_dcn.py主要有下面几个模块,get_pose_net()是整的分割模块,调用了下面几个模块
self.base 是一个基础模型,通过名称 base_name 初始化。channels 是基础模型的通道数量列表。
self.dla_up 是一个 DLAUp 模块,用于上采样。
self.ida_up 是一个 IDAUp 模块,用于级联融合特征图。
self.base 是一个基础模型,通过名称 base_name 初始化。
self.heads 是一个字典,包含了模型的输出头。

函数为:

def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
  model = DLASeg('dla{}'.format(num_layers), heads,
                 pretrained=True,
                 down_ratio=down_ratio,
                 final_kernel=1,
                 last_level=5,
                 head_conv=head_conv)
  return model

其中DLASeg定义了整个网络结构:

class DLASeg(nn.Module):
    def __init__(self, base_name, heads, pretrained, down_ratio, final_kernel,
                 last_level, head_conv, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        self.first_level = int(np.log2(down_ratio))
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained)
        channels = self.base.channels
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)

        if out_channel == 0:
            out_channel = channels[self.first_level]

        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level], 
                            [2 ** i for i in range(self.last_level - self.first_level)])
        
        self.heads = heads
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
              fc = nn.Sequential(
                  nn.Conv2d(channels[self.first_level], head_conv,
                    kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(head_conv, classes, 
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))
              if 'hm' in head:
                fc[-1].bias.data.fill_(-2.19)
              else:
                fill_fc_weights(fc)
            else:
              fc = nn.Conv2d(channels[self.first_level], classes, 
                  kernel_size=final_kernel, stride=1, 
                  padding=final_kernel // 2, bias=True)
              if 'hm' in head:
                fc.bias.data.fill_(-2.19)
              else:
                fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x)

        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        self.ida_up(y, 0, len(y))

        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(y[-1])
        return [z]
    

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大泽泽的小可爱

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值