diff --git a/supervision/detection/line_counter.py b/supervision/detection/line_counter.py index ea97fcb83..1c1fe7709 100644 --- a/supervision/detection/line_counter.py +++ b/supervision/detection/line_counter.py @@ -26,6 +26,7 @@ def __init__(self, start: Point, end: Point): self.tracker_state: Dict[str, bool] = {} self.in_count: int = 0 self.out_count: int = 0 + self.crossed = [] def trigger(self, detections: Detections): """ @@ -34,8 +35,13 @@ def trigger(self, detections: Detections): Attributes: detections (Detections): The detections for which to update the counts. + Returns: + np.ndarray: A boolean array indicating + which detection has crossed the line on the either sides """ - for xyxy, _, confidence, class_id, tracker_id in detections: + self.crossed = [False] * len(detections) + + for i, (xyxy, _, confidence, class_id, tracker_id) in enumerate(detections): # handle detections with no tracker_id if tracker_id is None: continue @@ -67,8 +73,13 @@ def trigger(self, detections: Detections): self.tracker_state[tracker_id] = tracker_state if tracker_state: self.in_count += 1 + self.crossed[i] = True + else: self.out_count += 1 + self.crossed[i] = True + + return self.crossed class LineZoneAnnotator: