Skip to content

Commit

Permalink
Refactor OSITrace further and add optional caching
Browse files Browse the repository at this point in the history
Also renames the module file to fit PEP-8 conventions.

Signed-off-by: Pierre R. Mai <pmai@pmsf.de>
  • Loading branch information
pmai committed Mar 6, 2024
1 parent b424ef9 commit 5efd7e7
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 137 deletions.
131 changes: 0 additions & 131 deletions osi3trace/OSITrace.py

This file was deleted.

2 changes: 1 addition & 1 deletion osi3trace/osi2read.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
python3 osi2read.py -d trace.osi -o myreadableosifile
"""

from osi3trace.OSITrace import OSITrace
from osi3trace.osi_trace import OSITrace
import argparse
import pathlib

Expand Down
168 changes: 168 additions & 0 deletions osi3trace/osi_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Module to handle and manage OSI trace files.
"""
import lzma
import struct

from osi3.osi_sensorview_pb2 import SensorView
from osi3.osi_groundtruth_pb2 import GroundTruth
from osi3.osi_sensordata_pb2 import SensorData


MESSAGES_TYPE = {
"SensorView": SensorView,
"GroundTruth": GroundTruth,
"SensorData": SensorData,
}


class OSITrace:
"""This class can import and decode OSI trace files."""

@staticmethod
def map_message_type(type_name):
"""Map the type name to the protobuf message type."""
return MESSAGES_TYPE[type_name]

def __init__(self, path=None, type_name="SensorView", cache_messages=False):
self.type = self.map_message_type(type_name)
self.file = None
self.current_index = None
self.message_offsets = None
self.read_complete = False
self.message_cache = {} if cache_messages else None
self._header_length = 4
if path:
self.from_file(path, type_name, cache_messages)

def from_file(self, path, type_name="SensorView", cache_messages=False):
"""Import a trace from a file"""
self.type = self.map_message_type(type_name)

if path.lower().endswith((".lzma", ".xz")):
self.file = lzma.open(path, "rb")
else:
self.file = open(path, "rb")

self.read_complete = False
self.current_index = 0
self.message_offsets = [0]
self.message_cache = {} if cache_messages else None

def retrieve_offsets(self, limit=None):
"""Retrieve the offsets of the messages from the file."""
if not self.read_complete:
self.current_index = len(self.message_offsets) - 1
self.file.seek(self.message_offsets[-1], 0)
while (
not self.read_complete and not limit or len(self.message_offsets) <= limit
):
self.retrieve_message(skip=True)
return self.message_offsets

def retrieve_message(self, index=None, skip=False):
"""Retrieve the next message from the file at the current position or given index, or skip it if skip is true."""
if index is not None:
self.current_index = index
self.file.seek(self.message_offsets[index], 0)
if self.message_cache is not None and self.current_index in self.message_cache:
message = self.message_cache[self.current_index]
self.current_index += 1
if self.current_index == len(self.message_offsets):
self.file.seek(0, 2)
else:
self.file.seek(self.message_offsets[self.current_index], 0)
if skip:
return self.message_offsets[self.current_index]
else:
return message
start = self.file.tell()
header = self.file.read(self._header_length)
if len(header) < self._header_length:
if start == self.message_offsets[-1]:
self.message_offsets.pop()
self.read_complete = True
self.file.seek(start, 0)
return None
message_length = struct.unpack("<L", header)[0]
if skip:
new_pos = self.file.seek(message_length, 1)
if new_pos - start < message_length + self._header_length:
if start == self.message_offsets[-1]:
self.message_offsets.pop()
self.read_complete = True
self.file.seek(start, 0)
return None
self.current_index += 1
if start == self.message_offsets[-1]:
self.message_offsets.append(new_pos)
return new_pos
message_data = self.file.read(message_length)
if len(message_data) < message_length:
if start == self.message_offsets[-1]:
self.message_offsets.pop()
self.read_complete = True
self.file.seek(start, 0)
return None
self.current_index += 1
message = self.type()
message.ParseFromString(message_data)
if start == self.message_offsets[-1]:
if self.message_cache is not None:
self.message_cache[len(self.message_offsets) - 1] = message
self.message_offsets.append(self.file.tell())
return message

def restart(self, index=None):
"""Restart the reading of the file from the beginning or from a given index."""
self.current_index = index if index else 0
self.file.seek(self.message_offsets[self.current_index], 0)

def __iter__(self):
while message := self.retrieve_message():
yield message

def get_message_by_index(self, index):
"""
Get a message by its index.
"""
if index >= len(self.message_offsets):
self.retrieve_offsets(index)
if self.message_cache is not None and index in self.message_cache:
return self.message_cache[index]
return self.retrieve_message(index=index)

def get_messages(self):
"""
Yield an iterator over all messages in the file.
"""
return self.get_messages_in_index_range(0, None)

def get_messages_in_index_range(self, begin, end):
"""
Yield an iterator over messages of indexes between begin and end included.
"""
if begin >= len(self.message_offsets):
self.retrieve_offsets(begin)
self.restart(begin)
current = begin
while end is None or current < end:
if self.message_cache is not None and current in self.message_cache:
yield self.message_cache[current]
else:
message = self.retrieve_message()
if message is None:
break
yield message
current += 1

def close(self):
if self.file:
self.file.close()
self.file = None
self.current_index = None
self.message_cache = None
self.message_offsets = None
self.read_complete = False
self.read_limit = None
self.type = None
11 changes: 6 additions & 5 deletions tests/test_osi_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile
import unittest

from format.OSITrace import OSITrace
from osi3trace.osi_trace import OSITrace
from osi3.osi_sensorview_pb2 import SensorView
import struct

Expand All @@ -14,10 +14,11 @@ def test_osi_trace(self):
path_input = os.path.join(tmpdirname, "input.osi")
create_sample(path_input)

trace = OSITrace()
trace.from_file(path=path_input)
trace.make_readable(path_output, index=1)
trace.scenario_file.close()
trace = OSITrace(path_input)
with open(path_output, "wt") as f:
for message in trace:
f.write(str(message))
trace.close()

self.assertTrue(os.path.exists(path_output))

Expand Down

0 comments on commit 5efd7e7

Please # to comment.