diff --git a/bitformat/_dtypes.py b/bitformat/_dtypes.py index 33d7685..5d80543 100644 --- a/bitformat/_dtypes.py +++ b/bitformat/_dtypes.py @@ -429,13 +429,9 @@ def __init__( def allowed_size_checked_get_fn(bs): if len(bs) not in self.allowed_sizes: if self.allowed_sizes.only_one_value(): - raise ValueError( - f"'{self.name}' dtypes must have a size of {self.allowed_sizes.values[0]}, but received a size of {len(bs)}." - ) + raise ValueError(f"'{self.name}' dtypes must have a size of {self.allowed_sizes.values[0]}, but received a size of {len(bs)}.") else: - raise ValueError( - f"'{self.name}' dtypes must have a size in {self.allowed_sizes}, but received a size of {len(bs)}." - ) + raise ValueError(f"'{self.name}' dtypes must have a size in {self.allowed_sizes}, but received a size of {len(bs)}.") return get_fn(bs) self.get_fn = ( @@ -445,9 +441,7 @@ def allowed_size_checked_get_fn(bs): self.get_fn = get_fn # Interpret everything if bits_per_character is not None: if bitlength2chars_fn is not None: - raise ValueError( - "You shouldn't specify both a bits_per_character and a bitlength2chars_fn." - ) + raise ValueError("You shouldn't specify both a bits_per_character and a bitlength2chars_fn.") def bitlength2chars_fn(x): return x // bits_per_character @@ -473,13 +467,9 @@ def sanitize(self, size: int, endianness: Endianness) -> tuple[int, Endianness]: ) if endianness != Endianness.UNSPECIFIED: if not self.endianness_variants: - raise ValueError( - f"The '{self.name}' dtype does not support endianness variants, but '{endianness.value}' was specified." - ) + raise ValueError(f"The '{self.name}' dtype does not support endianness variants, but '{endianness.value}' was specified.") if size % 8 != 0: - raise ValueError( - f"Endianness can only be specified for whole-byte dtypes, but '{self.name}' has a size of {size} bits." - ) + raise ValueError(f"Endianness can only be specified for whole-byte dtypes, but '{self.name}' has a size of {size} bits.") return size, endianness def get_single_dtype( @@ -495,7 +485,8 @@ def get_array_dtype( size, endianness = self.sanitize(size, endianness) d = Dtype._create(self, size, True, items, endianness) if size == 0: - raise ValueError(f"Array dtypes must have a size specified. Got '{d}'. Note that the number of items in the array dtype can be unknown, but the dtype of each item must have a known size.") + raise ValueError(f"Array dtypes must have a size specified. Got '{d}'. " + f"Note that the number of items in the array dtype can be unknown, but the dtype of each item must have a known size.") return d def __repr__(self) -> str: @@ -547,16 +538,12 @@ def add_dtype(cls, definition: DtypeDefinition): def fget_be(b): if len(b) % 8 != 0: - raise ValueError( - f"Cannot use endianness modifer for non whole-byte data. Got length of {len(b)} bits." - ) + raise ValueError(f"Cannot use endianness modifer for non whole-byte data. Got length of {len(b)} bits.") return definition.get_fn(b) def fget_le(b): if len(b) % 8 != 0: - raise ValueError( - f"Cannot use endianness modifer for non whole-byte data. Got length of {len(b)} bits." - ) + raise ValueError(f"Cannot use endianness modifer for non whole-byte data. Got length of {len(b)} bits.") return definition.get_fn(b.byte_swap()) fget_ne = fget_le if byteorder == "little" else fget_be @@ -598,9 +585,7 @@ def get_single_dtype( except KeyError: aliases = {"int": "i", "uint": "u", "float": "f"} extra = f"Did you mean '{aliases[name]}'? " if name in aliases else "" - raise ValueError( - f"Unknown Dtype name '{name}'. {extra}Names available: {list(cls.name_to_def.keys())}." - ) + raise ValueError(f"Unknown Dtype name '{name}'. {extra}Names available: {list(cls.name_to_def.keys())}.") else: return definition.get_single_dtype(size, endianness) @@ -616,9 +601,7 @@ def get_array_dtype( try: definition = cls.name_to_def[name] except KeyError: - raise ValueError( - f"Unknown Dtype name '{name}'. Names available: {list(cls.name_to_def.keys())}." - ) + raise ValueError(f"Unknown Dtype name '{name}'. Names available: {list(cls.name_to_def.keys())}.") else: d = definition.get_array_dtype(size, items, endianness) return d @@ -676,9 +659,7 @@ def from_string(cls, token: str, /) -> DtypeWithExpression: token = "".join(token.split()) # Remove whitespace if token.startswith("[") and token.endswith("]"): if (p := token.find(";")) == -1: - raise ValueError( - f"Array Dtype strings should be of the form '[dtype; items]'. Got '{token}'." - ) + raise ValueError(f"Array Dtype strings should be of the form '[dtype; items]'. Got '{token}'.") t = token[p + 1 : -1] try: items = int(t) if t else 0 @@ -717,56 +698,26 @@ def evaluate(self, vars_: dict[str, Any]) -> Dtype: return self.base_dtype if self.base_dtype.is_array: name = self.base_dtype.name - size = ( - self.size_expression.evaluate(vars_) - if (self.size_expression and vars_) - else self.base_dtype.size - ) - items = ( - self.items_expression.evaluate(vars_) - if (self.items_expression and vars_) - else self.base_dtype.items - ) + size = self.size_expression.evaluate(vars_) if (self.size_expression and vars_) else self.base_dtype.size + items = self.items_expression.evaluate(vars_) if (self.items_expression and vars_) else self.base_dtype.items endianness = self.base_dtype.endianness return Register().get_array_dtype(name, size, items, endianness) else: name = self.base_dtype.name - size = ( - self.size_expression.evaluate(vars_) - if (self.size_expression and vars_) - else self.base_dtype.size - ) + size = self.size_expression.evaluate(vars_) if (self.size_expression and vars_) else self.base_dtype.size endianness = self.base_dtype.endianness return Register().get_single_dtype(name, size, endianness) def __str__(self) -> str: - hide_size = Register().name_to_def[ - self.base_dtype.name - ].allowed_sizes.only_one_value() or ( - self.base_dtype.size == 0 and self.size_expression is None - ) - size_str = ( - "" - if hide_size - else ( - self.size_expression - if self.size_expression - else str(self.base_dtype.size) - ) - ) + only_one_value = Register().name_to_def[self.base_dtype.name].allowed_sizes.only_one_value() + no_value_given = self.base_dtype.size == 0 and self.size_expression is None + hide_size = only_one_value or no_value_given + size_str = "" if hide_size else (self.size_expression if self.size_expression else str(self.base_dtype.size)) if not self.base_dtype.is_array: return f"{self.base_dtype.name}{self.base_dtype.endianness.value}{size_str}" hide_items = self.base_dtype.items == 0 and self.items_expression is None - items_str = ( - "" - if hide_items - else ( - self.items_expression - if self.items_expression - else str(self.base_dtype.items) - ) - ) - return f"[{self.base_dtype.name}{self.base_dtype.endianness.value}{size_str}; {items_str}]" + items_str = "" if hide_items else (" " + self.items_expression if self.items_expression else " " + str(self.base_dtype.items)) + return f"[{self.base_dtype.name}{self.base_dtype.endianness.value}{size_str};{items_str}]" class DtypeTuple: @@ -807,9 +758,7 @@ def from_string(cls, s: str, /) -> DtypeTuple: def pack(self, values: Sequence[Any]) -> bitformat.Bits: if len(values) != len(self): raise ValueError(f"Expected {len(self)} values, but got {len(values)}.") - return bitformat.Bits.from_joined( - dtype.pack(value) for dtype, value in zip(self._dtypes, values) - ) + return bitformat.Bits.from_joined(dtype.pack(value) for dtype, value in zip(self._dtypes, values)) def unpack( self, @@ -823,9 +772,7 @@ def unpack( """ b = bitformat.Bits._from_any(b) if self.bit_length > len(b): - raise ValueError( - f"{self!r} is {self.bit_length} bits long, but only got {len(b)} bits to unpack." - ) + raise ValueError(f"{self!r} is {self.bit_length} bits long, but only got {len(b)} bits to unpack.") vals = [] pos = 0 for dtype in self: diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 6117a61..b6a1531 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -276,6 +276,7 @@ def test_str(): def test_unpacking_dtype_array_with_no_length(): d = Dtype('[bool;]') + assert str(d) == '[bool;]' assert d.unpack('0b110') == (True, True, False) assert Dtype('[u8;]').unpack('0x0001f') == (0, 1) diff --git a/tests/test_field.py b/tests/test_field.py index 53463ef..3095017 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -223,6 +223,11 @@ def test_unpack(): with pytest.raises(ValueError): _ = f.unpack() +def test_unpack_with_unknown_items(): + f = Field("[i9; ]") + assert str(f) == "[i9;]" + f.pack([5, -5, 0, 100]) + assert f.unpack() == (5, -5, 0, 100) def test_field_with_comment(): f = Field.from_params("u8", name="x", comment=" This is a comment ")