Merge pull request #1528 from roboflow/feature/make-BaseTrack-count-instance-var

Let ByteTrack maintain track IDs per instance
LinasKo authored Oct 17, 2024
2 parents f60d89f + 4a7d9ce commit 251b952
Showing 5 changed files with 114 additions and 61 deletions.
21 changes: 6 additions & 15 deletions supervision/tracker/byte_tracker/basetrack.py
@@ -12,8 +12,6 @@ class TrackState(Enum):


 class BaseTrack:
-    _count = 0
-
     def __init__(self):
         self.track_id = 0
         self.is_activated = False
@@ -34,20 +32,13 @@ def __init__(self):
     def end_frame(self) -> int:
         return self.frame_id

-    @staticmethod
-    def next_id() -> int:
-        BaseTrack._count += 1
-        return BaseTrack._count
-
-    @staticmethod
-    def reset_counter():
-        BaseTrack._count = 0
-        BaseTrack.track_id = 0
-        BaseTrack.start_frame = 0
-        BaseTrack.frame_id = 0
-        BaseTrack.time_since_update = 0

LinasKo (Author, Collaborator) commented on Oct 17, 2024:

Leaving a comment for the future. All of these except _count are class variables.
Resetting these will not impact any instances.

+    def reset_counter(self):
+        self.track_id = 0
+        self.start_frame = 0
+        self.frame_id = 0
+        self.time_since_update = 0

-    def activate(self, *args):
+    def activate(self, *args, **kwargs):
         raise NotImplementedError

     def predict(self):
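
The reviewer's note above deserves a concrete illustration. Below is a minimal standalone sketch (a hypothetical `Track` class, not part of this diff) of why assigning to attributes on the class inside the old static `reset_counter` never reached live tracks: the assignment creates or rebinds a class attribute, while every instance keeps its own attribute set in `__init__`. Only `_count`, which genuinely lived on the class, was reset meaningfully.

```python
class Track:
    _count = 0  # class variable: shared state, lives on the class object

    def __init__(self):
        self.track_id = 0  # instance variable: one per Track object


t = Track()
t.track_id = 7

# This creates/rebinds a *class* attribute named track_id; it does not
# touch the instance attribute that __init__ assigned on `t`.
Track.track_id = 0

print(t.track_id)    # 7 -- the instance attribute shadows the class attribute
print(Track._count)  # 0 -- _count really lives on the class, so class-level
                     #      assignment is the only one of these that "works"
```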
112 changes: 67 additions & 45 deletions supervision/tracker/byte_tracker/core.py
@@ -9,11 +9,35 @@
 from supervision.tracker.byte_tracker.kalman_filter import KalmanFilter


+class IdCounter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self) -> None:
+        self._id = self.NO_ID
+
+    def new_id(self) -> int:
+        self._id += 1
+        return self._id
+
+    @property
+    def NO_ID(self) -> int:
+        return 0
+
+
 class STrack(BaseTrack):
     shared_kalman = KalmanFilter()
-    _external_count = 0

-    def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames):
+    def __init__(
+        self,
+        tlwh,
+        score,
+        class_ids,
+        minimum_consecutive_frames,
+        internal_id_counter: IdCounter,
+        external_id_counter: IdCounter,
+    ):
         super().__init__()
         # wait activate
         self._tlwh = np.asarray(tlwh, dtype=np.float32)
         self.kalman_filter = None
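
A rough usage sketch of the new `IdCounter` (assuming it is importable from the internal module shown above); it is not part of the diff, but shows the semantics the tracker now relies on: `NO_ID` (0) is a sentinel for "not yet assigned", IDs start at 1, and each counter instance is independent.

```python
from supervision.tracker.byte_tracker.core import IdCounter

counter_a = IdCounter()
counter_b = IdCounter()

assert counter_a.NO_ID == 0      # sentinel: "no ID assigned yet"
assert counter_a.new_id() == 1   # IDs start at 1
assert counter_a.new_id() == 2
assert counter_b.new_id() == 1   # separate instance, separate sequence

counter_a.reset()
assert counter_a.new_id() == 1   # resetting one counter leaves others untouched
```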
@@ -24,10 +48,13 @@ def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames):
         self.class_ids = class_ids
         self.tracklet_len = 0

-        self.external_track_id = -1
-
         self.minimum_consecutive_frames = minimum_consecutive_frames

+        self.internal_id_counter = internal_id_counter
+        self.external_id_counter = external_id_counter
+        self.internal_track_id = self.internal_id_counter.NO_ID
+        self.external_track_id = self.external_id_counter.NO_ID
+
     def predict(self):
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
@@ -57,7 +84,7 @@ def multi_predict(stracks):
     def activate(self, kalman_filter, frame_id):
         """Start a new tracklet"""
         self.kalman_filter = kalman_filter
-        self.internal_track_id = self.next_id()
+        self.internal_track_id = self.internal_id_counter.new_id()
         self.mean, self.covariance = self.kalman_filter.initiate(
             self.tlwh_to_xyah(self._tlwh)
         )
@@ -68,21 +95,19 @@ def activate(self, kalman_filter, frame_id):
             self.is_activated = True

         if self.minimum_consecutive_frames == 1:
-            self.external_track_id = self.next_external_id()
+            self.external_track_id = self.external_id_counter.new_id()

         self.frame_id = frame_id
         self.start_frame = frame_id

-    def re_activate(self, new_track, frame_id, new_id=False):
+    def re_activate(self, new_track, frame_id):
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
         )
         self.tracklet_len = 0
         self.state = TrackState.Tracked

         self.frame_id = frame_id
-        if new_id:
-            self.internal_track_id = self.next_id()
         self.score = new_track.score

     def update(self, new_track, frame_id):
@@ -103,8 +128,8 @@ def update(self, new_track, frame_id):
         self.state = TrackState.Tracked
         if self.tracklet_len == self.minimum_consecutive_frames:
             self.is_activated = True
-            if self.external_track_id == -1:
-                self.external_track_id = self.next_external_id()
+            if self.external_track_id == self.external_id_counter.NO_ID:
+                self.external_track_id = self.external_id_counter.new_id()

         self.score = new_track.score

@@ -142,15 +167,6 @@ def tlwh_to_xyah(tlwh):
     def to_xyah(self):
         return self.tlwh_to_xyah(self.tlwh)

-    @staticmethod
-    def next_external_id():
-        STrack._external_count += 1
-        return STrack._external_count
-
-    @staticmethod
-    def reset_external_counter():
-        STrack._external_count = 0
-
     @staticmethod
     def tlbr_to_tlwh(tlbr):
         ret = np.asarray(tlbr).copy()
@@ -169,24 +185,6 @@ def __repr__(self):
         )


-def detections2boxes(detections: Detections) -> np.ndarray:
-    """
-    Convert Supervision Detections to numpy tensors for further computation.
-    Args:
-        detections (Detections): Detections/Targets in the format of sv.Detections.
-    Returns:
-        (np.ndarray): Detections as numpy tensors as in
-            `(x_min, y_min, x_max, y_max, confidence, class_id)` order.
-    """
-    return np.hstack(
-        (
-            detections.xyxy,
-            detections.confidence[:, np.newaxis],
-            detections.class_id[:, np.newaxis],
-        )
-    )
-
-
 class ByteTrack:
     """
     Initialize the ByteTrack object.
@@ -235,6 +233,9 @@ def __init__(
         self.lost_tracks: List[STrack] = []
         self.removed_tracks: List[STrack] = []

+        self.internal_id_counter = IdCounter()
+        self.external_id_counter = IdCounter()
+
     def update_with_detections(self, detections: Detections) -> Detections:
         """
         Updates the tracker with the provided detections and returns the updated
@@ -275,7 +276,13 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
             ```
         """

-        tensors = detections2boxes(detections=detections)
+        tensors = np.hstack(
+            (
+                detections.xyxy,
+                detections.confidence[:, np.newaxis],
+                detections.class_id[:, np.newaxis],
+            )
+        )
         tracks = self.update_with_tensors(tensors=tensors)

         if len(tracks) > 0:
@@ -311,11 +318,12 @@ def reset(self):
         ensuring the tracker starts with a clean state for each new video.
         """
         self.frame_id = 0
-        BaseTrack.reset_counter()
+        self.internal_id_counter.reset()
+        self.external_id_counter.reset()
         self.tracked_tracks: List[STrack] = []
         self.lost_tracks: List[STrack] = []
         self.removed_tracks: List[STrack] = []
-        BaseTrack.reset_counter()
-        STrack.reset_external_counter()

     def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
         """
@@ -353,7 +361,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
         if len(dets) > 0:
             """Detections"""
             detections = [
-                STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+                STrack(
+                    STrack.tlbr_to_tlwh(tlbr),
+                    s,
+                    c,
+                    self.minimum_consecutive_frames,
+                    self.internal_id_counter,
+                    self.external_id_counter,
+                )
                 for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep)
             ]
         else:
@@ -387,15 +402,22 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
                 track.update(detections[idet], self.frame_id)
                 activated_starcks.append(track)
             else:
-                track.re_activate(det, self.frame_id, new_id=False)
+                track.re_activate(det, self.frame_id)
                 refind_stracks.append(track)

         """ Step 3: Second association, with low score detection boxes"""
         # association the untrack to the low score detections
         if len(dets_second) > 0:
             """Detections"""
             detections_second = [
-                STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+                STrack(
+                    STrack.tlbr_to_tlwh(tlbr),
+                    s,
+                    c,
+                    self.minimum_consecutive_frames,
+                    self.internal_id_counter,
+                    self.external_id_counter,
+                )
                 for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second)
             ]
         else:
@@ -416,7 +438,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
                 track.update(det, self.frame_id)
                 activated_starcks.append(track)
             else:
-                track.re_activate(det, self.frame_id, new_id=False)
+                track.re_activate(det, self.frame_id)
                 refind_stracks.append(track)

         for it in u_track:
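
The `reset()` docstring above recommends calling it between videos; with this PR that call re-seeds the tracker's own per-instance ID counters instead of module-wide class state. A minimal usage sketch, not part of the diff, with synthetic detections and the expected IDs hedged in comments:

```python
import numpy as np
import supervision as sv

frame_detections = sv.Detections(
    xyxy=np.array([[10, 10, 20, 20]]),
    class_id=np.array([1]),
    confidence=np.array([0.9]),
)

tracker = sv.ByteTrack()
first_video = tracker.update_with_detections(frame_detections)

# Starting a new video: reset clears the frame counter, the track lists,
# and both per-instance ID counters.
tracker.reset()
second_video = tracker.update_with_detections(frame_detections)

# Both runs should assign tracker_id [1], since numbering restarts after reset().
print(first_video.tracker_id, second_video.tracker_id)
```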
2 changes: 1 addition & 1 deletion supervision/tracker/byte_tracker/matching.py
@@ -20,7 +20,7 @@ def indices_to_matches(

 def linear_assignment(
     cost_matrix: np.ndarray, thresh: float
-) -> [np.ndarray, Tuple[int], Tuple[int, int]]:
+) -> Tuple[np.ndarray, Tuple[int], Tuple[int, int]]:
     if cost_matrix.size == 0:
         return (
             np.empty((0, 2), dtype=int),
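
For context on the matching.py fix: the old return annotation `[np.ndarray, Tuple[int], Tuple[int, int]]` is a list expression rather than a type, so static type checkers flag it; `Tuple[...]` from `typing` expresses the intended three-element return value.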
Empty file added test/tracker/__init__.py
40 changes: 40 additions & 0 deletions test/tracker/test_byte_tracker.py
@@ -0,0 +1,40 @@
from typing import List

import numpy as np
import pytest

import supervision as sv


@pytest.mark.parametrize(
    "detections, expected_results",
    [
        (
            [
                sv.Detections(
                    xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
                    class_id=np.array([1, 1]),
                    confidence=np.array([1, 1]),
                ),
                sv.Detections(
                    xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
                    class_id=np.array([1, 1]),
                    confidence=np.array([1, 1]),
                ),
            ],
            sv.Detections(
                xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
                class_id=np.array([1, 1]),
                confidence=np.array([1, 1]),
                tracker_id=np.array([1, 2]),
            ),
        ),
    ],
)
def test_byte_tracker(
    detections: List[sv.Detections],
    expected_results: sv.Detections,
) -> None:
    byte_tracker = sv.ByteTrack()
    tracked_detections = [byte_tracker.update_with_detections(d) for d in detections]
    assert tracked_detections[-1] == expected_results
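
The new test exercises a single tracker instance. The behavior that motivated the PR is easiest to see with two trackers side by side; the sketch below is not part of the test suite, and the expected values are hedged in comments. Before this change, the second tracker would have continued the first one's ID sequence through the shared class-level counters; now each instance numbers its tracks from 1.

```python
import numpy as np
import supervision as sv

detections = sv.Detections(
    xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
    class_id=np.array([1, 1]),
    confidence=np.array([1, 1]),
)

tracker_a = sv.ByteTrack()
tracker_b = sv.ByteTrack()

ids_a = tracker_a.update_with_detections(detections).tracker_id
ids_b = tracker_b.update_with_detections(detections).tracker_id

# With per-instance counters, both trackers should report [1, 2].
# With the old shared class counters, tracker_b would have continued
# the sequence started by tracker_a (e.g. [3, 4]).
print(ids_a, ids_b)
```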
