-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdata_classes.py
207 lines (177 loc) · 8.28 KB
/
data_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, NamedTuple, Optional
from mmengine.logging import print_log
from t4_devkit.dataclass import Box3D
class DatasetSplitName(NamedTuple):
    """Immutable (dataset version, split name) pair."""

    # Version string of the dataset this split belongs to.
    dataset_version: str
    # Name of the split within that dataset version.
    split_name: str
@dataclass(eq=True, frozen=True)
class Detection3DBox:
    """One detected 3D box paired with the attribute names attached to it."""

    # The underlying 3D bounding box object.
    box: Box3D
    # Attribute names for this box.
    attrs: List[str]
@dataclass(frozen=True)
class SampleData:
    """Dataclass to save data for a single sample (lidar frame): its token and its 3D detection boxes."""

    sample_token: str
    detection_3d_boxes: List[Detection3DBox]

    @staticmethod
    def _resolve_category_name(
        detection_3d_box: Detection3DBox,
        remapping_classes: Optional[Dict[str, str]],
    ) -> str:
        """Return the box's semantic label name, remapped through ``remapping_classes`` when given."""
        box_category_name = detection_3d_box.box.semantic_label.name
        if remapping_classes is None:
            return box_category_name
        # If no category found from the remapping, then it uses the original category name
        return remapping_classes.get(box_category_name, box_category_name)

    def get_category_attr_counts(
        self,
        category_name: str,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, int]:
        """
        Get total counts of every attribute for the selected category in this sample.
        :param category_name: Selected category name (compared after remapping, if any).
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict (a defaultdict(int)) of {attribute name: total counts}.
        """
        category_attr_counts: Dict[str, int] = defaultdict(int)
        for detection_3d_box in self.detection_3d_boxes:
            if self._resolve_category_name(detection_3d_box, remapping_classes) != category_name:
                continue
            for attr_name in detection_3d_box.attrs:
                category_attr_counts[attr_name] += 1
        return category_attr_counts

    def get_category_counts(
        self,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, int]:
        """
        Get total counts of every category over the boxes of this sample.
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict (a defaultdict(int)) of {category name: total counts}.
        """
        category_counts: Dict[str, int] = defaultdict(int)
        for detection_3d_box in self.detection_3d_boxes:
            category_counts[self._resolve_category_name(detection_3d_box, remapping_classes)] += 1
        return category_counts

    @classmethod
    def create_sample_data(
        cls,
        sample_token: str,
        boxes: List[Box3D],
    ) -> SampleData:
        """
        Create a SampleData (or subclass instance) given the params.
        Each box is wrapped in a Detection3DBox whose attrs is the box's
        ``semantic_label.attributes`` object (the same object, not a copy).
        :param sample_token: Sample token to represent a sample (lidar frame).
        :param boxes: List of 3D bounding boxes for the given sample token.
        :return: A new instance of ``cls`` holding the wrapped boxes.
        """
        detection_3d_boxes = [Detection3DBox(box=box, attrs=box.semantic_label.attributes) for box in boxes]
        # Use ``cls`` so subclasses get instances of their own type back.
        return cls(sample_token=sample_token, detection_3d_boxes=detection_3d_boxes)
@dataclass
class ScenarioData:
    """Data class to save data for a scenario: the SampleData of its samples, keyed by sample token."""

    scene_token: str
    # {sample token: SampleData}
    sample_data: Dict[str, SampleData] = field(default_factory=dict)

    def add_sample_data(self, sample_data: SampleData) -> None:
        """
        Add a SampleData to ScenarioData.
        An existing entry with the same sample token is replaced, and the replacement is logged.
        :param sample_data: SampleData containing data describing a sample/lidar frame.
        """
        if sample_data.sample_token in self.sample_data:
            print_log(f"Found {sample_data.sample_token} in the data, replacing it...")
        self.sample_data[sample_data.sample_token] = sample_data

    def get_scenario_category_counts(
        self,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Dict[str, int]]:
        """
        Get total counts of every category for every sample in this scenario.
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict of {sample token: {category name: total counts}}.
        """
        return {
            sample_token: sample_data.get_category_counts(remapping_classes=remapping_classes)
            for sample_token, sample_data in self.sample_data.items()
        }

    def get_scenario_category_attr_counts(
        self,
        category_name: str,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Dict[str, int]]:
        """
        Get, for every sample in this scenario, total counts of every attribute for the selected category.
        :param category_name: Selected category name.
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict of {sample token: {attribute name: total counts}}.
        """
        return {
            sample_token: sample_data.get_category_attr_counts(
                category_name=category_name, remapping_classes=remapping_classes
            )
            for sample_token, sample_data in self.sample_data.items()
        }
@dataclass
class AnalysisData:
    """Data class to save data for an analysis: dataset location/version plus ScenarioData keyed by scene token."""

    data_root_path: str
    dataset_version: str
    # {scene token: ScenarioData}
    scenario_data: Dict[str, ScenarioData] = field(default_factory=dict)

    def add_scenario_data(self, scenario_data: ScenarioData) -> None:
        """
        Add a ScenarioData to AnalysisData.
        An existing entry with the same scene token is replaced, and the replacement is logged.
        :param scenario_data: ScenarioData containing data describing a scenario (more than a
            sample/lidar frame).
        """
        if scenario_data.scene_token in self.scenario_data:
            print_log(f"Found {scenario_data.scene_token} in the data, replacing it...")
        self.scenario_data[scenario_data.scene_token] = scenario_data

    @staticmethod
    def _accumulate_counts(
        total_counts: Dict[str, int],
        per_sample_counts: Dict[str, Dict[str, int]],
    ) -> None:
        """Add every per-sample {name: counts} dict of ``per_sample_counts`` into ``total_counts`` in place."""
        for counts_by_name in per_sample_counts.values():
            for name, counts in counts_by_name.items():
                total_counts[name] += counts

    def aggregate_category_counts(
        self,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, int]:
        """
        Get total counts of every category in this AnalysisData.
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict (a defaultdict(int)) of {category name: total counts}.
        """
        # {category name: counts}, summed over every sample of every scenario
        total_category_counts: Dict[str, int] = defaultdict(int)
        for scenario_data in self.scenario_data.values():
            self._accumulate_counts(
                total_category_counts,
                scenario_data.get_scenario_category_counts(remapping_classes=remapping_classes),
            )
        return total_category_counts

    def aggregate_category_attr_counts(
        self,
        category_name: str,
        remapping_classes: Optional[Dict[str, str]] = None,
    ) -> Dict[str, int]:
        """
        Get total counts of every attribute for the selected category in this AnalysisData.
        :param category_name: Selected category name.
        :param remapping_classes: Set if we want to aggregate the total counts after remapping
            categories.
        :return: A dict (a defaultdict(int)) of {attribute name: total counts}.
        """
        # {attribute name: counts} for boxes of ``category_name``, summed over all samples
        total_attr_counts: Dict[str, int] = defaultdict(int)
        for scenario_data in self.scenario_data.values():
            self._accumulate_counts(
                total_attr_counts,
                scenario_data.get_scenario_category_attr_counts(
                    category_name=category_name, remapping_classes=remapping_classes
                ),
            )
        return total_attr_counts