From 5efd7e792e14cbc2fae4759ceef5d7e0d8dc0d9e Mon Sep 17 00:00:00 2001 From: "Pierre R. Mai" Date: Thu, 29 Feb 2024 12:31:55 +0100 Subject: [PATCH] Refactor OSITrace further and add optional caching Also renames the module file to fit PEP-8 conventions. Signed-off-by: Pierre R. Mai --- osi3trace/OSITrace.py | 131 ------------------------------- osi3trace/osi2read.py | 2 +- osi3trace/osi_trace.py | 168 ++++++++++++++++++++++++++++++++++++++++ tests/test_osi_trace.py | 11 +-- 4 files changed, 175 insertions(+), 137 deletions(-) delete mode 100644 osi3trace/OSITrace.py create mode 100644 osi3trace/osi_trace.py diff --git a/osi3trace/OSITrace.py b/osi3trace/OSITrace.py deleted file mode 100644 index dd0baa91c..000000000 --- a/osi3trace/OSITrace.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Module to handle and manage OSI trace files. -""" -import lzma -import struct - -from osi3.osi_sensorview_pb2 import SensorView -from osi3.osi_groundtruth_pb2 import GroundTruth -from osi3.osi_sensordata_pb2 import SensorData - - -MESSAGES_TYPE = { - "SensorView": SensorView, - "GroundTruth": GroundTruth, - "SensorData": SensorData, -} - - -class OSITrace: - """This class can import and decode OSI trace files.""" - - @staticmethod - def map_message_type(type_name): - """Map the type name to the protobuf message type.""" - return MESSAGES_TYPE[type_name] - - def __init__(self, path=None, type_name="SensorView"): - self.type = self.map_message_type(type_name) - self.file = None - self.message_offsets = None - self.read_complete = False - self.read_limit = None - self._header_length = 4 - if path: - self.from_file(path, type_name) - - def from_file(self, path, type_name="SensorView"): - """Import a trace from a file""" - self.type = self.map_message_type(type_name) - - if path.lower().endswith((".lzma", ".xz")): - self.file = lzma.open(path, "rb") - else: - self.file = open(path, "rb") - - self.read_complete = False - self.read_limit = 0 - self.message_offsets = [0] - - def retrieve_offsets(self, limit=None): - """Retrieve the offsets of the messages from the file.""" - if not self.read_complete: - self.file.seek(self.read_limit, 0) - while not self.read_complete and not limit or len(self.message_offsets) < limit: - self.retrieve_message(skip=True) - return self.message_offsets - - def read_message(self, offset=None, skip=False): - """Read a message from the file at the given offset.""" - if offset: - self.file.seek(offset, 0) - message = self.type() - header = self.file.read(self._header_length) - if len(header) < self._header_length: - return None - message_length = struct.unpack(" len(self.message_offsets): - self.retrieve_offsets(index) - return self.read_message(self.message_offsets[index]) - - def get_messages(self): - return self.get_messages_in_index_range(0, None) - - def get_messages_in_index_range(self, begin, end): - """ - Yield an iterator over messages of indexes between begin and end included. - """ - if begin > len(self.message_offsets): - self.retrieve_offsets(begin) - self.file.seek(self.message_offsets[begin], 0) - current = begin - while end is None or current < end: - message = ( - self.retrieve_message() - if current >= len(self.message_offsets) - else self.read_message() - ) - if message is None: - break - yield message - current += 1 - - def close(self): - if self.file: - self.file.close() - self.file = None - self.message_offsets = None - self.read_complete = False - self.read_limit = None - self.type = None diff --git a/osi3trace/osi2read.py b/osi3trace/osi2read.py index 52b3becd5..c6bee861e 100644 --- a/osi3trace/osi2read.py +++ b/osi3trace/osi2read.py @@ -5,7 +5,7 @@ python3 osi2read.py -d trace.osi -o myreadableosifile """ -from osi3trace.OSITrace import OSITrace +from osi3trace.osi_trace import OSITrace import argparse import pathlib diff --git a/osi3trace/osi_trace.py b/osi3trace/osi_trace.py new file mode 100644 index 000000000..49453c660 --- /dev/null +++ b/osi3trace/osi_trace.py @@ -0,0 +1,168 @@ +""" +Module to handle and manage OSI trace files. +""" +import lzma +import struct + +from osi3.osi_sensorview_pb2 import SensorView +from osi3.osi_groundtruth_pb2 import GroundTruth +from osi3.osi_sensordata_pb2 import SensorData + + +MESSAGES_TYPE = { + "SensorView": SensorView, + "GroundTruth": GroundTruth, + "SensorData": SensorData, +} + + +class OSITrace: + """This class can import and decode OSI trace files.""" + + @staticmethod + def map_message_type(type_name): + """Map the type name to the protobuf message type.""" + return MESSAGES_TYPE[type_name] + + def __init__(self, path=None, type_name="SensorView", cache_messages=False): + self.type = self.map_message_type(type_name) + self.file = None + self.current_index = None + self.message_offsets = None + self.read_complete = False + self.message_cache = {} if cache_messages else None + self._header_length = 4 + if path: + self.from_file(path, type_name, cache_messages) + + def from_file(self, path, type_name="SensorView", cache_messages=False): + """Import a trace from a file""" + self.type = self.map_message_type(type_name) + + if path.lower().endswith((".lzma", ".xz")): + self.file = lzma.open(path, "rb") + else: + self.file = open(path, "rb") + + self.read_complete = False + self.current_index = 0 + self.message_offsets = [0] + self.message_cache = {} if cache_messages else None + + def retrieve_offsets(self, limit=None): + """Retrieve the offsets of the messages from the file.""" + if not self.read_complete: + self.current_index = len(self.message_offsets) - 1 + self.file.seek(self.message_offsets[-1], 0) + while ( + not self.read_complete and not limit or len(self.message_offsets) <= limit + ): + self.retrieve_message(skip=True) + return self.message_offsets + + def retrieve_message(self, index=None, skip=False): + """Retrieve the next message from the file at the current position or given index, or skip it if skip is true.""" + if index is not None: + self.current_index = index + self.file.seek(self.message_offsets[index], 0) + if self.message_cache is not None and self.current_index in self.message_cache: + message = self.message_cache[self.current_index] + self.current_index += 1 + if self.current_index == len(self.message_offsets): + self.file.seek(0, 2) + else: + self.file.seek(self.message_offsets[self.current_index], 0) + if skip: + return self.message_offsets[self.current_index] + else: + return message + start = self.file.tell() + header = self.file.read(self._header_length) + if len(header) < self._header_length: + if start == self.message_offsets[-1]: + self.message_offsets.pop() + self.read_complete = True + self.file.seek(start, 0) + return None + message_length = struct.unpack("= len(self.message_offsets): + self.retrieve_offsets(index) + if self.message_cache is not None and index in self.message_cache: + return self.message_cache[index] + return self.retrieve_message(index=index) + + def get_messages(self): + """ + Yield an iterator over all messages in the file. + """ + return self.get_messages_in_index_range(0, None) + + def get_messages_in_index_range(self, begin, end): + """ + Yield an iterator over messages of indexes between begin and end included. + """ + if begin >= len(self.message_offsets): + self.retrieve_offsets(begin) + self.restart(begin) + current = begin + while end is None or current < end: + if self.message_cache is not None and current in self.message_cache: + yield self.message_cache[current] + else: + message = self.retrieve_message() + if message is None: + break + yield message + current += 1 + + def close(self): + if self.file: + self.file.close() + self.file = None + self.current_index = None + self.message_cache = None + self.message_offsets = None + self.read_complete = False + self.read_limit = None + self.type = None diff --git a/tests/test_osi_trace.py b/tests/test_osi_trace.py index c832b63d7..8a9a430ec 100644 --- a/tests/test_osi_trace.py +++ b/tests/test_osi_trace.py @@ -2,7 +2,7 @@ import tempfile import unittest -from format.OSITrace import OSITrace +from osi3trace.osi_trace import OSITrace from osi3.osi_sensorview_pb2 import SensorView import struct @@ -14,10 +14,11 @@ def test_osi_trace(self): path_input = os.path.join(tmpdirname, "input.osi") create_sample(path_input) - trace = OSITrace() - trace.from_file(path=path_input) - trace.make_readable(path_output, index=1) - trace.scenario_file.close() + trace = OSITrace(path_input) + with open(path_output, "wt") as f: + for message in trace: + f.write(str(message)) + trace.close() self.assertTrue(os.path.exists(path_output))