-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharithmetic.py
79 lines (55 loc) · 2.01 KB
/
arithmetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from collections import defaultdict
from decimal import Decimal, getcontext
# Use 50 significant digits for Decimal arithmetic in the current context.
# The interval narrowing below multiplies ranges together once per symbol,
# so the available precision bounds how long a message can round-trip.
getcontext().prec = 50
def calculate_frequencies(text: str) -> dict[str, int]:
    """Count how many times each character occurs in *text*.

    Args:
        text: The message whose symbols should be counted.

    Returns:
        A plain ``dict`` mapping each distinct character to its number of
        occurrences, keyed in order of first appearance. An empty *text*
        yields an empty dict.
    """
    # A plain dict (rather than a defaultdict) is returned on purpose: a
    # defaultdict would silently insert a zero-count entry whenever a caller
    # looked up a symbol that never occurred, polluting the frequency table.
    frequencies: dict[str, int] = {}
    for char in text:
        frequencies[char] = frequencies.get(char, 0) + 1
    return frequencies
def build_intervals(frequencies: dict[str, int]) -> dict[str, tuple[Decimal, Decimal]]:
    """Assign each symbol a half-open sub-interval of [0, 1).

    Symbols are laid out back to back in the iteration order of
    *frequencies*; each symbol's interval width is its share of the total
    count.

    Args:
        frequencies: Mapping of symbol to occurrence count.

    Returns:
        Mapping of symbol to ``(low, high)`` bounds. The first interval
        starts at 0, each interval starts where the previous one ends, and
        the last interval ends at exactly 1. An empty mapping yields an
        empty result.
    """
    denominator = Decimal(sum(frequencies.values()))
    intervals: dict[str, tuple[Decimal, Decimal]] = {}
    low = Decimal(0)
    cumulative = 0
    for char, freq in frequencies.items():
        # Derive each upper bound from the exact integer running total
        # instead of summing already-rounded quotients; this avoids
        # accumulated rounding drift and makes the final bound exactly 1.
        cumulative += freq
        high = Decimal(cumulative) / denominator
        intervals[char] = (low, high)
        low = high
    return intervals
def arithmetic_encode(
    text: str, intervals: dict[str, tuple[Decimal, Decimal]]
) -> Decimal:
    """Encode *text* as a single Decimal using the given symbol intervals.

    Starting from the bounds 0 and 1, the working interval is narrowed once
    per character to the sub-range that *intervals* assigns to it. The
    midpoint of the final working interval is returned.

    Args:
        text: The message to encode.
        intervals: Mapping of symbol to ``(low, high)`` bounds.

    Returns:
        The midpoint of the final interval; ``0.5`` for an empty *text*.

    Raises:
        KeyError: If *text* contains a symbol missing from *intervals*.
    """
    lower, upper = Decimal(0), Decimal(1)
    for symbol in text:
        sym_low, sym_high = intervals[symbol]
        width = upper - lower
        # Both new bounds are computed from the old lower bound, so they
        # are updated together in one assignment.
        lower, upper = lower + width * sym_low, lower + width * sym_high
    return (lower + upper) / 2
def arithmetic_decode(
    encoded: Decimal, length: int, intervals: dict[str, tuple[Decimal, Decimal]]
) -> str:
    """Recover a message of *length* symbols from an encoded Decimal.

    For every output position the encoded value is rescaled into the
    current working interval, the symbol whose ``[low, high)`` bounds
    contain the rescaled value is emitted, and the working interval is
    narrowed to that symbol's sub-range.

    Args:
        encoded: The value to decode.
        length: Number of symbols to decode.
        intervals: Mapping of symbol to ``(low, high)`` bounds; it must be
            the same table that was used for encoding.

    Returns:
        The decoded string, always exactly *length* characters long
        (empty when *length* is 0).

    Raises:
        ValueError: If the rescaled value falls outside every symbol
            interval, e.g. because *encoded* is not in [0, 1) or does not
            belong to this interval table.
    """
    low = Decimal(0)
    high = Decimal(1)
    # Collect symbols in a list and join once; repeated string
    # concatenation would copy the partial result on every step.
    decoded_chars: list[str] = []
    for _ in range(length):
        range_ = high - low
        value = (encoded - low) / range_
        for char, (char_low, char_high) in intervals.items():
            if char_low <= value < char_high:
                decoded_chars.append(char)
                high = low + range_ * char_high
                low = low + range_ * char_low
                break
        else:
            # No interval matched: fail loudly instead of silently
            # returning a string shorter than the requested length.
            raise ValueError(
                "Encoded value does not fall inside any symbol interval "
                f"at position {len(decoded_chars)}"
            )
    return "".join(decoded_chars)
if __name__ == "__main__":
    # Round-trip a sample message through the encoder and decoder and fail
    # loudly if the decoded text differs from the input.
    original = "HELLO"
    print(f"Original: {original}")

    # The interval table is derived from the message's own symbol counts.
    intervals = build_intervals(calculate_frequencies(original))

    encoded = arithmetic_encode(original, intervals)
    print(f"Encoded: {encoded}")

    decoded = arithmetic_decode(encoded, len(original), intervals)
    print(f"Decoded: {decoded}")

    if decoded != original:
        raise ValueError("Decoded data does not match original data")