diff --git a/qualtran/bloqs/arithmetic/subtraction.py b/qualtran/bloqs/arithmetic/subtraction.py index ef5cb5199..f06255708 100644 --- a/qualtran/bloqs/arithmetic/subtraction.py +++ b/qualtran/bloqs/arithmetic/subtraction.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union import numpy as np @@ -40,6 +39,7 @@ from qualtran.bloqs.bookkeeping import Allocate, Cast, Free from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT from qualtran.drawing import Text +from qualtran.simulation.classical_sim import add_ints if TYPE_CHECKING: from qualtran.drawing import WireSymbol @@ -270,10 +270,15 @@ def signature(self): def on_classical_vals( self, a: 'ClassicalValT', b: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: - unsigned = isinstance(self.dtype, (QUInt, QMontgomeryUInt)) - bitsize = self.dtype.bitsize - N = 2**bitsize if unsigned else 2 ** (bitsize - 1) - return {'a': a, 'b': int(math.fmod(b - a, N))} + return { + 'a': a, + 'b': add_ints( + int(b), + -int(a), + num_bits=int(self.dtype.bitsize), + is_signed=isinstance(self.dtype, QInt), + ), + } def wire_symbol( self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() diff --git a/qualtran/bloqs/arithmetic/subtraction_test.py b/qualtran/bloqs/arithmetic/subtraction_test.py index 40e2bc6fd..fe9a5cb48 100644 --- a/qualtran/bloqs/arithmetic/subtraction_test.py +++ b/qualtran/bloqs/arithmetic/subtraction_test.py @@ -160,3 +160,12 @@ def test_subtract_from_bloq_decomposition(): want[(a << 4) | c][a_b] = 1 got = gate.tensor_contract() np.testing.assert_allclose(got, want) + + +@pytest.mark.parametrize('bitsize', range(2, 5)) +def test_subtractfrom_classical_action(bitsize): + dtype = QInt(bitsize) + blq = SubtractFrom(dtype) + qlt_testing.assert_consistent_classical_action( + blq, a=tuple(dtype.get_classical_domain()), b=tuple(dtype.get_classical_domain()) + ) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.py b/qualtran/bloqs/mod_arithmetic/mod_addition.py index cb7ab6396..db4fe7dba 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.py @@ -87,7 +87,17 @@ def signature(self) -> 'Signature': def on_classical_vals( self, x: 'ClassicalValT', y: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: - return {'x': x, 'y': (x + y) % self.mod} + if not (0 <= x < self.mod): + raise ValueError( + f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + ) + if not (0 <= y < self.mod): + raise ValueError( + f'{y=} is outside the valid interval for modular addition [0, {self.mod})' + ) + + y = (x + y) % self.mod + return {'x': x, 'y': y} def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): @@ -307,6 +317,12 @@ def on_classical_vals( return {'ctrl': 0, 'x': x} assert ctrl == 1, 'Bad ctrl value.' + + if not (0 <= x < self.mod): + raise ValueError( + f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + ) + x = (x + self.k) % self.mod return {'ctrl': ctrl, 'x': x} @@ -492,7 +508,17 @@ def on_classical_vals( if ctrl != self.cv: return {'ctrl': ctrl, 'x': x, 'y': y} - return {'ctrl': ctrl, 'x': x, 'y': (x + y) % self.mod} + if not (0 <= x < self.mod): + raise ValueError( + f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + ) + if not (0 <= y < self.mod): + raise ValueError( + f'{y=} is outside the valid interval for modular addition [0, {self.mod})' + ) + + y = (x + y) % self.mod + return {'ctrl': ctrl, 'x': x, 'y': y} def build_composite_bloq( self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py index 455670bb5..bd5a11f4f 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py @@ -208,3 +208,15 @@ def test_cmod_add_complexity_vs_ref(): def test_mod_add_classical_action(bitsize, prime): b = ModAdd(bitsize, prime) assert_consistent_classical_action(b, x=range(prime), y=range(prime)) + + +def test_cmodadd_tensor(): + blq = CModAddK(bitsize=4, mod=7, k=1) + want = np.zeros((7, 7)) + for i in range(7): + j = (i + 1) % 7 + want[j, i] = 1 + + tn = blq.tensor_contract() + np.testing.assert_allclose(tn[:7, :7], np.eye(7)) # ctrl = 0 + np.testing.assert_allclose(tn[16 : 16 + 7, 16 : 16 + 7], want) # ctrl = 1