Skip to content

Commit

Permalink
Moving endianness back out of base class.
Browse files Browse the repository at this point in the history
Nice refactor to DtypeArray so it now is composed of a DtypeSingle with an items.
  • Loading branch information
scott-griffiths committed Feb 11, 2025
1 parent 5c4f97c commit f4d35b8
Showing 1 changed file with 34 additions and 61 deletions.
95 changes: 34 additions & 61 deletions bitformat/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class Dtype(abc.ABC):
"""

_name: DtypeName
_endianness: Endianness

def __new__(cls, token: str | None = None, /) -> Self:
if token is None:
Expand Down Expand Up @@ -136,12 +135,6 @@ def unpack(self, b: BitsType, /):
"""
...


@property
def endianness(self) -> Endianness:
"""The endianness of the data type."""
return self._endianness

@property
@abc.abstractmethod
def bit_length(self) -> int:
Expand Down Expand Up @@ -178,13 +171,19 @@ class DtypeSingle(Dtype):
_size: int
_bit_length: int
_definition: DtypeDefinition
_endianness: Endianness

@property
@override
@final
def name(self) -> DtypeName:
return self._definition.name

@property
def endianness(self) -> Endianness:
"""The endianness of the data type."""
return self._endianness

@property
def return_type(self) -> Any:
"""The type of the value returned by the parse method, such as ``int``, ``float`` or ``str``."""
Expand All @@ -196,12 +195,13 @@ def _create(cls, definition: DtypeDefinition, size: int,
endianness: Endianness = Endianness.UNSPECIFIED) -> Self:
x = DtypeSingle.__new__(DtypeSingle)
x._definition = definition
x._bit_length = x._size = size
if definition.bits_per_character is not None:
x._bit_length *= definition.bits_per_character
little_endian: bool = endianness == Endianness.LITTLE or (
endianness == Endianness.NATIVE and bitformat.byteorder == "little"
)
x._size = size
if definition.bits_per_character is None:
x._bit_length = size
else:
x._bit_length = size * definition.bits_per_character
little_endian = (endianness == Endianness.LITTLE or
(endianness == Endianness.NATIVE and bitformat.byteorder == "little"))
x._endianness = endianness
x._get_fn = (
(lambda b: definition.get_fn(b.byte_swap()))
Expand Down Expand Up @@ -310,53 +310,26 @@ def size(self) -> int:

class DtypeArray(Dtype):

_size: int
_dtype_single: DtypeSingle
_items: int | None
_bits_per_item: int
_definition: DtypeDefinition

@property
@override
@final
def name(self) -> DtypeName:
return self._definition.name
return self._dtype_single.name

@property
def endianness(self) -> Endianness:
"""The endianness of the data type."""
return self._dtype_single.endianness

@classmethod
@functools.lru_cache(CACHE_SIZE)
def _create(cls, definition: DtypeDefinition, size: int, items: int = 1,
endianness: Endianness = Endianness.UNSPECIFIED,) -> Self:
x = super().__new__(cls)
x._definition = definition
x._dtype_single = DtypeSingle._create(definition, size, endianness)
x._items = items
x._bits_per_item = x._size = size
if definition.bits_per_character is not None:
x._bits_per_item *= definition.bits_per_character
little_endian: bool = endianness == Endianness.LITTLE or (
endianness == Endianness.NATIVE and bitformat.byteorder == "little"
)
x._endianness = endianness
x._get_fn = (
(lambda b: definition.get_fn(b.byte_swap()))
if little_endian
else definition.get_fn
)
if "length" in inspect.signature(definition.set_fn).parameters:
set_fn = functools.partial(definition.set_fn, length=x._bits_per_item)
else:
set_fn = definition.set_fn

def create_bits(v):
b = bitformat.Bits()
# The set_fn will do the length check for big endian too.
set_fn(b, v)
return b

def create_bits_le(v):
b = bitformat.Bits()
set_fn(b, v)
return b.byte_swap()

x._create_fn = create_bits_le if little_endian else create_bits
return x

@classmethod
Expand All @@ -381,7 +354,7 @@ def pack(self, value: Any, /) -> bitformat.Bits:
return value
if len(value) != self._items and self._items != 0:
raise ValueError(f"Expected {self._items} items, but got {len(value)}.")
return bitformat.Bits.from_joined(self._create_fn(v) for v in value)
return bitformat.Bits.from_joined(self._dtype_single._create_fn(v) for v in value)

@override
@final
Expand All @@ -392,20 +365,20 @@ def unpack(self, b: BitsType, /) -> Any | tuple[Any]:
items = self.items
if items == 0:
# For array dtypes with no items (e.g. '[u8;]') unpack as much as possible.
items = len(b) // self._bits_per_item
items = len(b) // self._dtype_single.bit_length
return tuple(
self._get_fn(b[i * self._bits_per_item : (i + 1) * self._bits_per_item])
self._dtype_single._get_fn(b[i * self._dtype_single.bit_length : (i + 1) * self._dtype_single.bit_length])
for i in range(items)
)

@override
@final
def __str__(self) -> str:
hide_length = self.size == 0 or self._definition.allowed_sizes.only_one_value()
hide_length = self.size == 0 or self._dtype_single._definition.allowed_sizes.only_one_value()
size_str = "" if hide_length else str(self.size)
endianness = "" if self._endianness == Endianness.UNSPECIFIED else "_" + self._endianness.value
endianness = "" if self.endianness == Endianness.UNSPECIFIED else "_" + self.endianness.value
items_str = "" if self._items == 0 else f" {self._items}"
return f"[{self._definition.name}{endianness}{size_str};{items_str}]"
return f"[{self.name}{endianness}{size_str};{items_str}]"

@override
@final
Expand All @@ -414,23 +387,23 @@ def __eq__(self, other: Any) -> bool:
other = Dtype.from_string(other)
if isinstance(other, Dtype):
return (
self._definition.name == other._definition.name
and self._size == other._size
and self._items == other._items
and self._endianness == other._endianness
self.name == other.name
and self.size == other.size
and self.items == other.items
and self.endianness == other.endianness
)
return False

@override
@final
def __hash__(self) -> int:
return hash((self._definition.name.value, self._size, self._items))
return hash((self._dtype_single, self._items))

@override
@final
@property
def bit_length(self) -> int:
return self._bits_per_item * self._items
return self._dtype_single.bit_length * self._items

@property
def size(self) -> int:
Expand All @@ -443,7 +416,7 @@ def size(self) -> int:
See also :attr:`bit_length`.
"""
return self._size
return self._dtype_single.size

@property
def items(self) -> int:
Expand Down

0 comments on commit f4d35b8

Please # to comment.