ByteTrack论文精读

论文名称:ByteTrack: Multi-Object Tracking by Associating Every Detection Box

MOT排行榜: MOT17,MOT20 private赛道No1. Online

论文代码: https://github.com/ifzhang/ByteTrack

发布时间:2021


这篇论文原作者在知乎有文章
ByteTrack: Multi-Object Tracking by Associating Every Detection Box

然后就是另一篇概述得很好的文章
目标跟踪之 MOT 经典算法:ByteTrack 算法原理以及多类别跟踪

这篇文章主要就是让低置信度的检测框也参与匹配;它与DeepSORT一样,都是为了减少missing track的情况。
DeepSORT主要是将长时间未匹配到的跟踪预测框与检测框匹配的权重降低,优先匹配missing track时间短的跟踪预测框。
ByteTrack主要是让低置信度的检测框也与轨迹进行匹配,重点是让被遮挡的目标仍能保持匹配。

DeepSORT的级联匹配与ByteTrack的匹配

DeepSORT

优先匹配time_since_update小的tracks

Matching Cascade

完整代码见:https://github1s.com/nwojke/deep_sort/blob/HEAD/deep_sort/linear_assignment.py

下面贴一下关键代码部分:

Tip:匹配函数的参数见下方折叠代码

点击查看代码
def matching_cascade(
        distance_metric, max_distance, cascade_depth, tracks, detections,
        track_indices=None, detection_indices=None):
    """Run matching cascade.

    Parameters
    ----------
    distance_metric : Callable[[List[Track], List[Detection], List[int], List[int]], ndarray]
        The distance metric is given a list of tracks and detections as well as
        a list of N track indices and M detection indices. The metric should
        return the NxM dimensional cost matrix, where element (i, j) is the
        association cost between the i-th track in the given track indices and
        the j-th detection in the given detection indices.
    max_distance : float
        Gating threshold. Associations with cost larger than this value are
        disregarded.
    cascade_depth : int
        The cascade depth, should be set to the maximum track age.
    tracks : List[track.Track]
        A list of predicted tracks at the current time step.
    detections : List[detection.Detection]
        A list of detections at the current time step.
    track_indices : Optional[List[int]]
        List of track indices that maps rows in `cost_matrix` to tracks in
        `tracks` (see description above). Defaults to all tracks.
    detection_indices : Optional[List[int]]
        List of detection indices that maps columns in `cost_matrix` to
        detections in `detections` (see description above). Defaults to all
        detections.

    Returns
    -------
    (List[(int, int)], List[int], List[int])
        Returns a tuple with the following three entries:
        * A list of matched track and detection indices.
        * A list of unmatched track indices.
        * A list of unmatched detection indices.

    """
    # ... (body omitted in this listing)
def matching_cascade(distance_metric, max_distance, cascade_depth,
                     tracks, detections, track_indices=None, detection_indices=None):
    """Match tracks to detections level by level, freshest tracks first.

    At level ``l`` (0-based) only the tracks whose ``time_since_update``
    equals ``l + 1`` are offered to ``min_cost_matching``, and they only see
    the detections that earlier levels left unmatched.

    Parameters
    ----------
    distance_metric, max_distance
        Forwarded unchanged to ``min_cost_matching``.
    cascade_depth : int
        Number of cascade levels to run (the maximum track age).
    tracks, detections
        Sequences addressed by ``track_indices`` / ``detection_indices``.
        Every track must expose a ``time_since_update`` attribute.
    track_indices, detection_indices : list of int, optional
        Subsets to consider; each defaults to all indices of its sequence.

    Returns
    -------
    tuple
        ``(matches, unmatched_tracks, unmatched_detections)`` where
        ``matches`` is the concatenation of the per-level match lists and the
        other two entries are lists of indices that were never matched.
    """
    if track_indices is None:
        track_indices = list(range(len(tracks)))
    if detection_indices is None:
        detection_indices = list(range(len(detections)))

    # Start with every candidate detection; each level shrinks this list.
    unmatched_detections = detection_indices
    matches = []
    # cascade_depth is the maximum age a track is kept for.
    for level in range(cascade_depth):
        if len(unmatched_detections) == 0:  # No detections left
            break

        # time_since_update grows while a track goes without an update, so
        # low levels handle the tracks that were updated most recently.
        track_indices_l = [
            k for k in track_indices
            if tracks[k].time_since_update == 1 + level
        ]
        if len(track_indices_l) == 0:  # Nothing to match at this level
            continue

        matches_l, _, unmatched_detections = min_cost_matching(
            distance_metric, max_distance, tracks, detections,
            track_indices_l, unmatched_detections)
        # Must sit at loop-body level: accumulate this level's matches.
        matches += matches_l

    unmatched_tracks = list(set(track_indices) - {k for k, _ in matches})
    return matches, unmatched_tracks, unmatched_detections

ByteTrack

Step 1: Add newly detected tracklets to tracked_stracks
Step 2: First association, with high score detection boxes ☆☆☆☆☆
Step 3: Second association, with low score detection boxes ☆☆☆☆☆
Step 4: Init new stracks
Step 5: Update state

完整代码见:https://github1s.com/ifzhang/ByteTrack/blob/HEAD/yolox/tracker/byte_tracker.py

def update(self, output_results, img_info, img_size):
    """Advance the tracker by one frame using the given detections.

    Detections are split by score into a high-score group
    (``> self.args.track_thresh``) and a low-score group (between ``0.1`` and
    ``self.args.track_thresh``). Existing tracks are first associated with
    the high-score boxes, then the still-unmatched tracked ones are
    associated with the low-score boxes.

    Parameters
    ----------
    output_results
        2-D detection array. With exactly 5 columns, column 4 is used as the
        score; otherwise ``.cpu().numpy()`` is called on it and the score is
        the product of columns 4 and 5. Columns 0-3 are the box (x1y1x2y2).
    img_info
        Indexable; items 0 and 1 are read as image height and width.
    img_size
        Indexable; items 0 and 1 are compared against height and width to
        derive the scale that the boxes are divided by.

    Returns
    -------
    ``output_stracks`` -- NOTE(review): this name is not assigned anywhere in
    the visible part of the function; it must come from the elided section.
    """
    self.frame_id += 1
    # Per-frame result buckets (consumed by the elided state update below).
    activated_starcks = []
    refind_stracks = []
    lost_stracks = []
    removed_stracks = []

    if output_results.shape[1] == 5:
        # 5 columns: the score is already a single value in column 4.
        scores = output_results[:, 4]
        bboxes = output_results[:, :4]
    else:
        # More columns: score = column 4 * column 5
        # (presumably objectness * class confidence -- confirm).
        output_results = output_results.cpu().numpy()
        scores = output_results[:, 4] * output_results[:, 5]
        bboxes = output_results[:, :4]  # x1y1x2y2
    img_h, img_w = img_info[0], img_info[1]
    # Ratio between img_size and the image's own size; dividing by it
    # presumably maps boxes back to original image coordinates -- confirm.
    scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))
    # NOTE(review): bboxes is a slice of output_results, so this in-place
    # division likely rescales the underlying array as well -- confirm.
    bboxes /= scale

    # A score exactly equal to track_thresh lands in neither group.
    remain_inds = scores > self.args.track_thresh
    inds_low = scores > 0.1
    inds_high = scores < self.args.track_thresh

    inds_second = np.logical_and(inds_low, inds_high)
    dets_second = bboxes[inds_second]                   # key step: second batch, low-score detection boxes
    dets = bboxes[remain_inds]                          # key step: first batch, high-score detection boxes
    scores_keep = scores[remain_inds]
    scores_second = scores[inds_second]

    if len(dets) > 0:
        '''Detections'''
        # Wrap each high-score box as an STrack candidate.
        detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
                      (tlbr, s) in zip(dets, scores_keep)]
    else:
        detections = []

    ''' Add newly detected tracklets to tracked_stracks'''
    # Split the current tracks by their is_activated flag.
    unconfirmed = []
    tracked_stracks = []  # type: list[STrack]
    for track in self.tracked_stracks:
        if not track.is_activated:
            unconfirmed.append(track)
        else:
            tracked_stracks.append(track)

    ''' Step 2: First association, with high score detection boxes'''
    # Candidate pool: activated tracks plus the tracker's lost tracks.
    strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
    # Predict the current location with KF
    STrack.multi_predict(strack_pool)
    dists = matching.iou_distance(strack_pool, detections)
    # Score fusion is skipped when the mot20 option is set.
    if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
    matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)

    for itracked, idet in matches:
        track = strack_pool[itracked]
        det = detections[idet]
        if track.state == TrackState.Tracked:
            track.update(detections[idet], self.frame_id)
            activated_starcks.append(track)
        else:
            # Matched track was not in Tracked state: revive it, keeping its id.
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    ''' Step 3: Second association, with low score detection boxes'''
    # association the untrack to the low score detections
    if len(dets_second) > 0:
        '''Detections'''
        detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
                      (tlbr, s) in zip(dets_second, scores_second)]
    else:
        detections_second = []
    # Only tracks left unmatched above AND still in Tracked state take part.
    r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
    # Plain IoU distance with a fixed 0.5 threshold; no score fusion here.
    dists = matching.iou_distance(r_tracked_stracks, detections_second)
    matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
    for itracked, idet in matches:
        track = r_tracked_stracks[itracked]
        det = detections_second[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_starcks.append(track)
        else:
            # NOTE(review): r_tracked_stracks was filtered to Tracked state
            # above, so this branch appears unreachable -- confirm.
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    # Tracks that failed both associations are marked lost for this frame.
    for it in u_track:
        track = r_tracked_stracks[it]
        if not track.state == TrackState.Lost:
            track.mark_lost()
            lost_stracks.append(track)

    '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
    # Narrow to the high-score detections left over from the first
    # association; from here on u_detection indexes this narrowed list.
    detections = [detections[i] for i in u_detection]
    dists = matching.iou_distance(unconfirmed, detections)
    if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
    matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
    for itracked, idet in matches:
        unconfirmed[itracked].update(detections[idet], self.frame_id)
        activated_starcks.append(unconfirmed[itracked])
    # Unconfirmed tracks that found no detection are removed.
    for it in u_unconfirmed:
        track = unconfirmed[it]
        track.mark_removed()
        removed_stracks.append(track)

    """ Step 4: Init new stracks"""
    # New tracks start only from the remaining high-score detections whose
    # score reaches self.det_thresh; u_detection_second is not used again in
    # the visible code.
    for inew in u_detection:
        track = detections[inew]
        if track.score < self.det_thresh:
            continue
        track.activate(self.kalman_filter, self.frame_id)
        activated_starcks.append(track)
    """ Step 5: Update state"""
    # Remove lost tracks whose end_frame is more than max_time_lost frames old.
    for track in self.lost_stracks:
        if self.frame_id - track.end_frame > self.max_time_lost:
            track.mark_removed()
            removed_stracks.append(track)

    # print('Ramained match {} s'.format(t4-t3))

    # ... (rest of the state update is elided in this listing)

    return output_stracks
posted @ 2022-04-11 10:01  攻城狮?  阅读(1202)  评论(0)  编辑  收藏  举报