diff --git a/README.md b/README.md index 2d3f427..e995b78 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/) -`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. +`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector and density matrix simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. ## Installation diff --git a/docs/index.md b/docs/index.md index 87c805e..1ddbe5c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ # Welcome to horqrux -**horqrux** is a state vector simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/). +**horqrux** is a state vector and density matrix simulator designed for quantum machine learning written in [JAX](https://jax.readthedocs.io/). ## Setup @@ -110,10 +110,11 @@ from operator import add from typing import Any, Callable from uuid import uuid4 -from horqrux.circuit import QuantumCircuit, hea, expectation +from horqrux import expectation +from horqrux import Z, RX, RY, NOT, zero_state, apply_gate +from horqrux.circuit import QuantumCircuit, hea from horqrux.primitive import Primitive from horqrux.parametric import Parametric -from horqrux import Z, RX, RY, NOT, zero_state, apply_gate from horqrux.utils import DiffMode diff --git a/docs/noise.md b/docs/noise.md new file mode 100644 index 0000000..b90e7d3 --- /dev/null +++ b/docs/noise.md @@ -0,0 +1,101 @@ +## Digital Noise + +In the description of closed quantum systems, the complete quantum state is a pure quantum state represented by a state vector $|\psi \rangle $. + +However, this description is not sufficient to describe open quantum systems. When the system interacts with its surrounding environment, it transitions into a mixed state where quantum information is no longer entirely contained in a single state vector but is distributed probabilistically. + +To address these more general cases, we consider a probabilistic combination $p_i$ of possible pure states $|\psi_i \rangle$. Thus, the system is described by a density matrix $\rho$ defined as follows: + +$$ +\rho = \sum_i p_i |\psi_i\rangle \langle \psi_i| +$$ + +The transformations of the density operator of an open quantum system interacting with its noisy environment are represented by the super-operator $S: \rho \rightarrow S(\rho)$, often referred to as a quantum channel. +Quantum channels, due to the conservation of the probability distribution, must be CPTP (Completely Positive and Trace Preserving). Any CPTP super-operator can be written in the following form: + +$$ +S(\rho) = \sum_i K_i \rho K^{\dagger}_i +$$ + +Where $K_i$ are Kraus operators satisfying the closure property $\sum_i K_i K^{\dagger}_i = \mathbb{I}$. As noise is the result of system interactions with its environment, it is therefore possible to simulate noisy quantum circuit with noise type gates. + +Thus, `horqrux` implements a large selection of single qubit noise gates such as: + +- The bit flip channel defined as: $\textbf{BitFlip}(\rho) =(1-p) \rho + p X \rho X^{\dagger}$ +- The phase flip channel defined as: $\textbf{PhaseFlip}(\rho) = (1-p) \rho + p Z \rho Z^{\dagger}$ +- The depolarizing channel defined as: $\textbf{Depolarizing}(\rho) = (1-p) \rho + \frac{p}{3} (X \rho X^{\dagger} + Y \rho Y^{\dagger} + Z \rho Z^{\dagger})$ +- The pauli channel defined as: $\textbf{PauliChannel}(\rho) = (1-p_x-p_y-p_z) \rho + + p_x X \rho X^{\dagger} + + p_y Y \rho Y^{\dagger} + + p_z Z \rho Z^{\dagger}$ +- The amplitude damping channel defined as: $\textbf{AmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$ + with: + $\begin{equation*} + K_{0} \ =\begin{pmatrix} + 1 & 0\\ + 0 & \sqrt{1-\ \gamma } + \end{pmatrix} ,\ K_{1} \ =\begin{pmatrix} + 0 & \sqrt{\ \gamma }\\ + 0 & 0 + \end{pmatrix} + \end{equation*}$ +- The phase damping channel defined as: $\textbf{PhaseDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger}$ + with: + $\begin{equation*} + K_{0} \ =\begin{pmatrix} + 1 & 0\\ + 0 & \sqrt{1-\ \gamma } + \end{pmatrix}, \ K_{1} \ =\begin{pmatrix} + 0 & 0\\ + 0 & \sqrt{\ \gamma } + \end{pmatrix} + \end{equation*}$ +* The generalize amplitude damping channel is defined as: $\textbf{GeneralizedAmplitudeDamping}(\rho) = K_0 \rho K_0^{\dagger} + K_1 \rho K_1^{\dagger} + K_2 \rho K_2^{\dagger} + K_3 \rho K_3^{\dagger}$ + with: +$\begin{cases} +K_{0} \ =\sqrt{p} \ \begin{pmatrix} +1 & 0\\ +0 & \sqrt{1-\ \gamma } +\end{pmatrix} ,\ K_{1} \ =\sqrt{p} \ \begin{pmatrix} +0 & 0\\ +0 & \sqrt{\ \gamma } +\end{pmatrix} \\ +K_{2} \ =\sqrt{1\ -p} \ \begin{pmatrix} +\sqrt{1-\ \gamma } & 0\\ +0 & 1 +\end{pmatrix} ,\ K_{3} \ =\sqrt{1-p} \ \begin{pmatrix} +0 & 0\\ +\sqrt{\ \gamma } & 0 +\end{pmatrix} +\end{cases}$ + +Noise protocols can be added to gates by instantiating `DigitalNoiseInstance` providing the `DigitalNoiseType` and the `error_probability` (either float or tuple of float): + +```python exec="on" source="material-block" html="1" +from horqrux.noise import DigitalNoiseInstance, DigitalNoiseType + +noise_prob = 0.3 +AmpD = DigitalNoiseInstance(DigitalNoiseType.AMPLITUDE_DAMPING, error_probability=noise_prob) + +``` + +Then a gate can be instantiated by providing a tuple of `DigitalNoiseInstance` instances. Let’s show this through the simulation of a realistic $X$ gate. + +For instance, an $X$ gate flips the state of the qubit: $X|0\rangle = |1\rangle$. In practice, it's common for the target qubit to stay in its original state after applying $X$ due to the interactions between it and its environment. The possibility of failure can be represented by the `BitFlip` subclass of `DigitalNoiseInstance`, which flips the state again after the application of the $X$ gate, returning it to its original state with a probability `1 - gate_fidelity`. + +```python exec="on" source="material-block" +from horqrux.api import sample +from horqrux.noise import DigitalNoiseInstance, DigitalNoiseType +from horqrux.utils import density_mat, product_state +from horqrux.primitive import X + +noise = (DigitalNoiseInstance(DigitalNoiseType.BITFLIP, 0.1),) +ops = [X(0)] +noisy_ops = [X(0, noise=noise)] +state = product_state("0") + +noiseless_samples = sample(state, ops) +noisy_samples = sample(density_mat(state), noisy_ops) +print(f"Noiseless samples: {noiseless_samples}") # markdown-exec: hide +print(f"Noiseless samples: {noisy_samples}") # markdown-exec: hide +``` diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 513c031..3e31d9d 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from .api import expectation +from .api import expectation, run, sample from .apply import apply_gate, apply_operator -from .circuit import QuantumCircuit, sample +from .circuit import QuantumCircuit from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index de7f925..d3dc138 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Tuple - from jax import Array, custom_vjp from horqrux.apply import apply_gate @@ -37,14 +35,14 @@ def adjoint_expectation( def adjoint_expectation_single_observable_fwd( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] -) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]: +) -> tuple[Array, tuple[Array, Array, list[Primitive], dict[str, float]]]: out_state = apply_gate(state, gates, values, OperationType.UNITARY) projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) return inner(out_state, projected_state).real, (out_state, projected_state, gates, values) def adjoint_expectation_single_observable_bwd( - res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array + res: tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array ) -> tuple: """Implementation of Algorithm 1 of https://arxiv.org/abs/2009.02823 which computes the vector-jacobian product in O(P) time using O(1) state vectors diff --git a/horqrux/api.py b/horqrux/api.py index 2eeb464..37e1e96 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import Counter +from functools import singledispatch from typing import Any, Optional import jax @@ -11,74 +12,126 @@ from horqrux.adjoint import adjoint_expectation from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive -from horqrux.shots import finite_shots_fwd -from horqrux.utils import DiffMode, ForwardMode, OperationType, inner +from horqrux.shots import finite_shots_fwd, to_matrix +from horqrux.utils import ( + DensityMatrix, + DiffMode, + ForwardMode, + OperationType, + State, + inner, + num_qubits, + probabilities, + sample_from_probs, +) def run( circuit: GateSequence, - state: Array, + state: State, values: dict[str, float] = dict(), -) -> Array: +) -> State: return apply_gate(state, circuit, values) def sample( - state: Array, + state: State, gates: GateSequence, values: dict[str, float] = dict(), n_shots: int = 1000, ) -> Counter: + """Sample from a quantum program. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + values (dict[str, float], optional): _description_. Defaults to dict(). + n_shots (int, optional): Parameter values.. Defaults to 1000. + + Raises: + ValueError: If n_shots < 1. + + Returns: + Counter: Bitstrings and frequencies. + """ if n_shots < 1: - raise ValueError("You can only call sample with n_shots>0.") - - wf = apply_gate(state, gates, values) - probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() - key = jax.random.PRNGKey(0) - n_qubits = len(state.shape) - # JAX handles pseudo random number generation by tracking an explicit state via a random key - # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html - samples = jax.vmap( - lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) - )(jax.random.split(key, n_shots)) - - return Counter( - { - format(k, "0{}b".format(n_qubits)): count.item() - for k, count in enumerate(jnp.bincount(samples)) - if count > 0 - } + raise ValueError("You can only sample with non-negative 'n_shots'.") + output_circuit = apply_gate(state, gates, values) + n_qubits = num_qubits(output_circuit) + if isinstance(output_circuit, DensityMatrix): + d = 2**n_qubits + output_circuit.array = output_circuit.array.reshape((d, d)) + + probs = probabilities(output_circuit) + return sample_from_probs(probs, n_qubits, n_shots) + + +@singledispatch +def _ad_expectation_single_observable( + state: Any, + observable: Primitive, + values: dict[str, float], +) -> Any: + raise NotImplementedError("_ad_expectation_single_observable is not implemented") + + +@_ad_expectation_single_observable.register +def _( + state: Array, + observable: Primitive, + values: dict[str, float], +) -> Array: + projected_state = apply_gate( + state, + observable, + values, + OperationType.UNITARY, ) + return inner(state, projected_state).real -def __ad_expectation_single_observable( - state: Array, gates: GateSequence, observable: Primitive, values: dict[str, float] +@_ad_expectation_single_observable.register +def _( + state: DensityMatrix, + observable: Primitive, + values: dict[str, float], ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. - """ - out_state = apply_gate(state, gates, values, OperationType.UNITARY) - projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) - return inner(out_state, projected_state).real + n_qubits = num_qubits(state) + mat_obs = to_matrix(observable, n_qubits, values) + d = 2**n_qubits + prod = jnp.matmul(mat_obs, state.array.reshape((d, d))) + return jnp.trace(prod, axis1=-2, axis2=-1).real def ad_expectation( - state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float] + state: State, + gates: GateSequence, + observables: list[Primitive], + values: dict[str, float], ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. + """Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + observables (list[Primitive]): List of observables. + values (dict[str, float]): Parameter values. + + Returns: + Array: Expectation values. """ outputs = [ - __ad_expectation_single_observable(state, gates, observable, values) + _ad_expectation_single_observable( + apply_gate(state, gates, values, OperationType.UNITARY), observable, values + ) for observable in observables ] return jnp.stack(outputs) def expectation( - state: Array, + state: State, gates: GateSequence, observables: list[Primitive], values: dict[str, float], @@ -87,13 +140,27 @@ def expectation( n_shots: Optional[int] = None, key: Any = jax.random.PRNGKey(0), ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' + """Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. + + Args: + state (State): Input state vector or density matrix. + gates (GateSequence): Sequence of gates. + observables (list[Primitive]): List of observables. + values (dict[str, float]): Parameter values. + diff_mode (DiffMode, optional): Differentiation mode. Defaults to DiffMode.AD. + forward_mode (ForwardMode, optional): Type of forward method. Defaults to ForwardMode.EXACT. + n_shots (Optional[int], optional): Number of shots. Defaults to None. + key (Any, optional): Random key. Defaults to jax.random.PRNGKey(0). + + Returns: + Array: Expectation values. """ if diff_mode == DiffMode.AD: return ad_expectation(state, gates, observables, values) elif diff_mode == DiffMode.ADJOINT: + if isinstance(state, DensityMatrix): + raise TypeError("Adjoint does not support density matrices.") return adjoint_expectation(state, gates, observables, values) elif diff_mode == DiffMode.GPSR: checkify.check( @@ -105,4 +172,11 @@ def expectation( ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore - return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key) + return finite_shots_fwd( + state=state, + gates=gates, + observables=observables, + values=values, + n_shots=n_shots, + key=key, + ) diff --git a/horqrux/apply.py b/horqrux/apply.py index f82cf9f..3a3e04d 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -1,24 +1,60 @@ from __future__ import annotations -from functools import reduce +from functools import partial, reduce, singledispatch from operator import add -from typing import Iterable, Tuple +from typing import Any, Iterable, Union +import jax import jax.numpy as jnp import numpy as np from jax import Array from horqrux.primitive import Primitive -from .utils import OperationType, State, _controlled, is_controlled +from .noise import NoiseProtocol +from .utils import ( + DensityMatrix, + OperationType, + State, + _controlled, + _dagger, + density_mat, + is_controlled, + permute_basis, +) +@singledispatch def apply_operator( - state: State, + state: Any, operator: Array, - target: Tuple[int, ...], - control: Tuple[int | None, ...], -) -> State: + target: tuple[int, ...], + control: tuple[Union[int, None], ...], +) -> Any: + """Apply an operator on a state or density matrix. + + Args: + state (Any): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. + + Raises: + NotImplementedError: If not implemented for given types. + + Returns: + Array: The output of the application of the operator. + """ + raise NotImplementedError("apply_operator is not implemented") + + +@apply_operator.register +def _( + state: Array, + operator: Array, + target: tuple[int, ...], + control: tuple[Union[int, None], ...], +) -> Array: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. In case of a controlled operation, the 'operator' array will be embedded into a controlled array. @@ -30,27 +66,166 @@ def apply_operator( dimension 'i' of 'state'. To restore the former order of dimensions, the affected dimensions are moved to their original positions and the state is returned. - Arguments: - state: State to operate on. - operator: Array to contract over 'state'. - target: Tuple of target qubits on which to apply the 'operator' to. - control: Tuple of control qubits. + Args: + state (Array): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. Returns: - State after applying 'operator'. + Array after applying 'operator'. """ - state_dims: Tuple[int, ...] = target + state_dims: tuple[int, ...] = target if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] - n_qubits = int(np.log2(operator.shape[1])) - operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits))) - op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) - state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims)) - new_state_dims = tuple(i for i in range(len(state_dims))) + n_qubits_op = int(np.log2(operator.shape[1])) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) + # Apply operator + new_state_dims = tuple(range(len(state_dims))) + state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims)) return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) +@apply_operator.register +def _( + state: DensityMatrix, + operator: Array, + target: tuple[int, ...], + control: tuple[Union[int, None], ...], +) -> DensityMatrix: + """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix + of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits. + In case of a controlled operation, the 'operator' array will be embedded into a controlled array. + + Args: + state (DensityMatrix): Array to operate on. + operator (Array): Array to contract over 'state'. + target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to. + control (tuple[int | None, ...]): tuple of control qubits. + + Returns: + Density matrix after applying 'operator'. + """ + state_dims: tuple[int, ...] = target + if is_controlled(control): + operator = _controlled(operator, len(control)) + state_dims = (*control, *target) # type: ignore[arg-type] + n_qubits_op = int(np.log2(operator.shape[1])) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) + op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int)) + new_state_dims = tuple(range(len(state_dims))) + + # Apply operator to density matrix: ρ' = O ρ O† + out_state = state.array + support_perm = state_dims + tuple(set(tuple(range(out_state.ndim // 2))) - set(state_dims)) + + out_state = permute_basis(out_state, support_perm, False) + out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, new_state_dims)) + + out_state = _dagger(out_state) + out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, op_in_dims)) + out_state = _dagger(out_state) + + out_state = permute_basis(out_state, support_perm, True) + return DensityMatrix(out_state) + + +def apply_kraus_operator( + kraus: Array, + array: Array, + target: tuple[int, ...], +) -> Array: + """Apply K \\rho K^\\dagger. + + Args: + kraus (Array): Kraus operator K. + state (Array): Input density matrix. + target (tuple[int, ...]): Target qubits. + + Returns: + Array: K \\rho K^\\dagger. + """ + state_dims: tuple[int, ...] = target + n_qubits = int(np.log2(kraus.size)) + kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits))) + op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int)) + + array = jnp.tensordot(a=kraus, b=array, axes=(op_dims, state_dims)) + new_state_dims = tuple(i for i in range(len(state_dims))) + array = jnp.moveaxis(a=array, source=new_state_dims, destination=state_dims) + + array = jnp.tensordot(a=kraus, b=_dagger(array), axes=(op_dims, state_dims)) + array = _dagger(array) + + return array + + +def apply_kraus_sum( + kraus_ops: Array, + array: Array, + target: tuple[int, ...], +) -> DensityMatrix: + """Apply the following evolution as a sum of Kraus operators: + .. math:: + S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger + + Args: + kraus_ops (Array): Stacked K_i. + state (Array): Input array. + target (tuple[int, ...]): Qubits the operator is defined on. + + Returns: + DensityMatrix: Output density matrix. + """ + + apply_one_kraus = jax.vmap( + partial( + apply_kraus_operator, + array=array, + target=target, + ) + ) + kraus_evol = apply_one_kraus(kraus_ops) + output_dm = jnp.sum(kraus_evol, 0) + return DensityMatrix(output_dm) + + +def apply_operator_with_noise( + state: DensityMatrix, + operator: Array, + target: tuple[int, ...], + control: tuple[Union[int, None], ...], + noise: NoiseProtocol, +) -> State: + """Evolves the input state and applies a noisy quantum channel + on the evolved state :math:`\rho`. + + The evolution is represented as a sum of Kraus operators: + .. math:: + S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger, + + Args: + state (State): Input state or density matrix. + operator (Array): Operator to apply. + target (tuple[int, ...]): Target qubits. + control (tuple[int | None, ...]): Control qubits. + noise (NoiseProtocol): The noise protocol. + + Returns: + Array: Output state or density matrix. + """ + state_gate = apply_operator(state, operator, target, control) + if noise is None: + return state_gate + else: + kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise)))) + output_dm = apply_kraus_sum(kraus_ops, state_gate.array, target) + return output_dm + + def group_by_index(gates: Iterable[Primitive]) -> Iterable[Primitive]: """Group gates together which are acting on the same qubit.""" sorted_gates = [] @@ -81,9 +256,9 @@ def merge_operators( operators: The arrays representing the unitaries to be merged. targets: The corresponding target qubits. controls: The corresponding control qubits. + Returns: A tuple of merged operators, targets and controls. - """ if len(operators) < 2: return operators, targets, controls @@ -105,30 +280,46 @@ def merge_operators( return merged_operators[::-1], merged_targets[::-1], merged_controls[::-1] +@singledispatch def apply_gate( - state: State, + state: Any, gate: Primitive | Iterable[Primitive], values: dict[str, float] = dict(), op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, -) -> State: - """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. - Arguments: - state: State to operate on. - gate: Gate(s) to apply. - values: A dictionary with parameter values. - op_type: The type of operation to perform: Unitary, Dagger or Jacobian. - group_gates: Group gates together which are acting on the same qubit. - merge_ops: Attempt to merge operators acting on the same qubit. +) -> Any: + raise NotImplementedError("apply_gate is not implemented") + + +def prepare_sequence_reduce( + gate: Primitive | Iterable[Primitive], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> tuple[tuple[Array, ...], tuple, tuple, list[NoiseProtocol]]: + """Prepare the tuples to be used when applying operations. + + Args: + gate (Primitive | Iterable[Primitive]): Gate(s) to apply. + values (dict[str, float], optional): A dictionary with parameter values. + Defaults to dict(). + op_type (OperationType, optional): The type of operation to perform: Unitary, Dagger or Jacobian. + Defaults to OperationType.UNITARY. + group_gates (bool, optional): Group gates together which are acting on the same qubit. + Defaults to False. Returns: - State after applying 'gate'. + tuple[tuple[Array, ...], tuple, tuple, list[NoiseProtocol]]: Operators, targets, + controls and noise. """ - operator: Tuple[Array, ...] + operator: tuple[Array, ...] + noise = list() if isinstance(gate, Primitive): operator_fn = getattr(gate, op_type) operator, target, control = (operator_fn(values),), gate.target, gate.control + noise += [gate.noise] else: if group_gates: gate = group_by_index(gate) @@ -137,8 +328,82 @@ def apply_gate( control = reduce(add, [g.control for g in gate]) if merge_ops: operator, target, control = merge_operators(operator, target, control) - return reduce( - lambda state, gate: apply_operator(state, *gate), - zip(operator, target, control), + noise = [g.noise for g in gate] + + return operator, target, control, noise + + +@apply_gate.register +def _( + state: Array, + gate: Union[Primitive, Iterable[Primitive]], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> State: + """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. + Arguments: + state: Array or DensityMatrix to operate on. + gate: Gate(s) to apply. + values: A dictionary with parameter values. + op_type: The type of operation to perform: Unitary, Dagger or Jacobian. + group_gates: Group gates together which are acting on the same qubit. + merge_ops: Attempt to merge operators acting on the same qubit. + + Returns: + Array or density matrix after applying 'gate'. + """ + operator, target, control, noise = prepare_sequence_reduce( + gate, values, op_type, group_gates, merge_ops + ) + + # faster way to check has_noise + has_noise = noise != [None] * len(noise) + if has_noise: + state = density_mat(state) + + output_state = reduce( + lambda state, gate: apply_operator_with_noise(state, *gate), + zip(operator, target, control, noise), + state, + ) + else: + output_state = reduce( + lambda state, gate: apply_operator(state, *gate), + zip(operator, target, control), + state, + ) + return output_state + + +@apply_gate.register +def _( + state: DensityMatrix, + gate: Union[Primitive, Iterable[Primitive]], + values: dict[str, float] = dict(), + op_type: OperationType = OperationType.UNITARY, + group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution + merge_ops: bool = True, +) -> DensityMatrix: + """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. + Arguments: + state: Array or DensityMatrix to operate on. + gate: Gate(s) to apply. + values: A dictionary with parameter values. + op_type: The type of operation to perform: Unitary, Dagger or Jacobian. + group_gates: Group gates together which are acting on the same qubit. + merge_ops: Attempt to merge operators acting on the same qubit. + + Returns: + Array or density matrix after applying 'gate'. + """ + operator, target, control, noise = prepare_sequence_reduce( + gate, values, op_type, group_gates, merge_ops + ) + output_state = reduce( + lambda state, gate: apply_operator_with_noise(state, *gate), + zip(operator, target, control, noise), state, ) + return output_state diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ae4bb03..a398106 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -1,20 +1,16 @@ from __future__ import annotations -from collections import Counter from dataclasses import dataclass, field from typing import Any, Callable from uuid import uuid4 -import jax -import jax.numpy as jnp from jax import Array from jax.tree_util import register_pytree_node_class -from horqrux.adjoint import ad_expectation, adjoint_expectation from horqrux.apply import apply_gate from horqrux.parametric import RX, RY, Parametric from horqrux.primitive import NOT, Primitive -from horqrux.utils import DiffMode, zero_state +from horqrux.utils import zero_state @register_pytree_node_class @@ -113,48 +109,3 @@ def hea( gates += ops return gates - - -def expectation( - state: Array, - gates: list[Primitive], - observable: list[Primitive], - values: dict[str, float], - diff_mode: DiffMode | str = DiffMode.AD, -) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. - """ - if diff_mode == DiffMode.AD: - return ad_expectation(state, gates, observable, values) - else: - return adjoint_expectation(state, gates, observable, values) - - -def sample( - state: Array, - gates: list[Primitive], - values: dict[str, float] = dict(), - n_shots: int = 1000, -) -> Counter: - if n_shots < 1: - raise ValueError("You can only call sample with n_shots>0.") - - wf = apply_gate(state, gates, values) - probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() - key = jax.random.PRNGKey(0) - n_qubits = len(state.shape) - # JAX handles pseudo random number generation by tracking an explicit state via a random key - # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html - samples = jax.vmap( - lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) - )(jax.random.split(key, n_shots)) - - return Counter( - { - format(k, "0{}b".format(n_qubits)): count.item() - for k, count in enumerate(jnp.bincount(samples)) - if count > 0 - } - ) diff --git a/horqrux/noise.py b/horqrux/noise.py new file mode 100644 index 0000000..473dd85 --- /dev/null +++ b/horqrux/noise.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Iterable, Union + +from jax import Array +from jax.tree_util import register_pytree_node_class + +from .utils import ( + StrEnum, +) +from .utils_noise import ( + AmplitudeDamping, + BitFlip, + Depolarizing, + GeneralizedAmplitudeDamping, + PauliChannel, + PhaseDamping, + PhaseFlip, +) + + +class DigitalNoiseType(StrEnum): + BITFLIP = "BitFlip" + PHASEFLIP = "PhaseFlip" + DEPOLARIZING = "Depolarizing" + PAULI_CHANNEL = "PauliChannel" + AMPLITUDE_DAMPING = "AmplitudeDamping" + PHASE_DAMPING = "PhaseDamping" + GENERALIZED_AMPLITUDE_DAMPING = "GeneralizedAmplitudeDamping" + + +PROTOCOL_TO_KRAUS_FN: dict[str, Callable] = { + "BitFlip": BitFlip, + "PhaseFlip": PhaseFlip, + "Depolarizing": Depolarizing, + "PauliChannel": PauliChannel, + "AmplitudeDamping": AmplitudeDamping, + "PhaseDamping": PhaseDamping, + "GeneralizedAmplitudeDamping": GeneralizedAmplitudeDamping, +} + + +@register_pytree_node_class +@dataclass +class DigitalNoiseInstance: + type: DigitalNoiseType + error_probability: tuple[float, ...] | float + + def __iter__(self) -> Iterable: + return iter((self.kraus, self.error_probability)) + + def tree_flatten( + self, + ) -> tuple[tuple, tuple[DigitalNoiseType, tuple[float, ...] | float]]: + children = () + aux_data = (self.type, self.error_probability) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + @property + def kraus(self) -> tuple[Array, ...]: + kraus_fn: Callable[..., tuple[Array, ...]] = PROTOCOL_TO_KRAUS_FN[self.type] + return kraus_fn(error_probability=self.error_probability) + + def __repr__(self) -> str: + return self.type + f"(p={self.error_probability})" + + +NoiseProtocol = Union[tuple[DigitalNoiseInstance, ...], None] diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 80b925b..0f37914 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, Tuple +from typing import Any, Iterable import jax.numpy as jnp from jax import Array @@ -9,6 +9,7 @@ from ._misc import default_complex_dtype from .matrices import OPERATIONS_DICT +from .noise import NoiseProtocol from .primitive import Primitive from .utils import ( ControlQubits, @@ -30,6 +31,7 @@ class Parametric(Primitive): generator_name: str target: QubitSupport control: QubitSupport + noise: NoiseProtocol = None param: str | float = "" def __post_init__(self) -> None: @@ -43,18 +45,21 @@ def parse_val(values: dict[str, float] = dict()) -> float: self.parse_values = parse_dict if isinstance(self.param, str) else parse_val - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override] + def tree_flatten( # type: ignore[override] + self, + ) -> tuple[tuple, tuple[str, tuple, tuple, NoiseProtocol, str | float]]: children = () aux_data = ( self.generator_name, self.target[0], self.control[0], + self.noise, self.param, ) return (children, aux_data) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.param)) + return iter((self.generator_name, self.target, self.control, self.noise, self.param)) @classmethod def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: @@ -75,46 +80,64 @@ def __repr__(self) -> str: return self.name + f"(target={self.target}, control={self.control}, param={self.param})" -def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RX( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: NoiseProtocol = None, +) -> Parametric: """RX gate. Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("X", target, control, param) + return Parametric("X", target, control, noise, param) -def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RY( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: NoiseProtocol = None, +) -> Parametric: """RY gate. Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("Y", target, control, param) + return Parametric("Y", target, control, noise, param) -def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def RZ( + param: float | str, + target: TargetQubits, + control: ControlQubits = (None,), + noise: NoiseProtocol = None, +) -> Parametric: """RZ gate. Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return Parametric("Z", target, control, param) + return Parametric("Z", target, control, noise, param) class _PHASE(Parametric): @@ -134,16 +157,22 @@ def name(self) -> str: return "C" + base_name if is_controlled(self.control) else base_name -def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: +def PHASE( + param: float, + target: TargetQubits, + control: ControlQubits = (None,), + noise: NoiseProtocol = None, +) -> Parametric: """Phase gate. Arguments: param: Parameter denoting the Rotational angle. - target: Tuple of target qubits denoted as ints. + target: tuple of target qubits denoted as ints. control: Optional tuple of control qubits denoted as ints. + noise: The noise instance. Defaults to None. Returns: Parametric: A Parametric gate object. """ - return _PHASE("I", target, control, param) + return _PHASE("I", target, control, noise, param) diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 7156117..5f5aedf 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,13 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, Tuple, Union +from typing import Any, Iterable, Union import numpy as np from jax import Array from jax.tree_util import register_pytree_node_class from .matrices import OPERATIONS_DICT +from .noise import NoiseProtocol from .utils import ( ControlQubits, QubitSupport, @@ -27,11 +28,12 @@ class Primitive: generator_name: str target: QubitSupport control: QubitSupport + noise: NoiseProtocol = None @staticmethod def parse_idx( - idx: Tuple, - ) -> Tuple: + idx: tuple, + ) -> tuple: if isinstance(idx, (int, np.int64)): return ((idx,),) elif isinstance(idx, tuple): @@ -47,11 +49,11 @@ def __post_init__(self) -> None: self.control = Primitive.parse_idx(self.control) def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control)) + return iter((self.generator_name, self.target, self.control, self.noise)) - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: + def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, NoiseProtocol]]: children = () - aux_data = (self.generator_name, self.target[0], self.control[0]) + aux_data = (self.generator_name, self.target[0], self.control[0], self.noise) return (children, aux_data) @classmethod @@ -68,6 +70,13 @@ def dagger(self, values: dict[str, float] = dict()) -> Array: def name(self) -> str: return "C" + self.generator_name if is_controlled(self.control) else self.generator_name + @property + def n_qubits(self) -> int: + n_qubits = len(self.target) + if self.control[0] is not None: + n_qubits += len(self.control) + return n_qubits + def __repr__(self) -> str: return self.name + f"(target={self.target}, control={self.control})" @@ -75,43 +84,53 @@ def __repr__(self) -> str: GateSequence = Union[Primitive, Iterable[Primitive]] -def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def I( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. Example usage: I(1) represents the instruction to apply I to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("I", target, control) + return Primitive("I", target, control, noise) + +def X( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: + """The definition for the X gate. -def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: - """X gate. This function returns an instance of 'Primitive' and does *not* apply the gate. + This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. Example usage: X(1) represents the instruction to apply X to qubit 1. Example usage controlled: X(1, 0) represents the instruction to apply CX / CNOT to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("X", target, control) + return Primitive("X", target, control, noise) NOT = X -def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def Y( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """Y gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -119,16 +138,19 @@ def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Example usage controlled: Y(1, 0) represents the instruction to apply CY to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("Y", target, control) + return Primitive("Y", target, control, noise) -def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def Z( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """Z gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -136,67 +158,79 @@ def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Example usage controlled: Z(1, 0) represents the instruction to apply CZ to qubit 1 with controlled qubit 0. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("Z", target, control) + return Primitive("Z", target, control, noise) -def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def H( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """H/ Hadamard gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. Example usage: H(1) represents the instruction to apply Hadamard to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("H", target, control) + return Primitive("H", target, control, noise) -def S(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def S( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """S gate or constant phase gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. Example usage: S(1) represents the instruction to apply S to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("S", target, control) + return Primitive("S", target, control, noise) -def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def T( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """T gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. Example usage: T(1) represents the instruction to apply Hadamard to qubit 1. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints indicating the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("T", target, control) + return Primitive("T", target, control, noise) # Multi (target) qubit gates -def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: +def SWAP( + target: TargetQubits, control: ControlQubits = (None,), noise: NoiseProtocol = None +) -> Primitive: """SWAP gate. By providing a control, it turns into a controlled gate (Fredkin gate), use None for no control qubits. @@ -204,14 +238,15 @@ def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: Example usage controlled: SWAP(((0, 1), ), ((2, ))) swaps qubits 0 and 1 with controlled bit 2. Args: - target: Tuple of ints describing the qubits to apply to. + target: tuple of ints describing the qubits to apply to. control: Optional tuple of ints or None describing the control qubits. Defaults to (None,). + noise: The noise instance. Defaults to None. Returns: A Primitive instance. """ - return Primitive("SWAP", target, control) + return Primitive("SWAP", target, control, noise) def SQSWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: diff --git a/horqrux/shots.py b/horqrux/shots.py index 4383100..38a7c14 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import partial, reduce +from functools import partial, reduce, singledispatch from typing import Any import jax @@ -10,10 +10,14 @@ from horqrux.apply import apply_gate from horqrux.primitive import GateSequence, Primitive -from horqrux.utils import none_like +from horqrux.utils import DensityMatrix, State, none_like, num_qubits -def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: +def to_matrix( + observable: Primitive, + n_qubits: int, + values: dict[str, float], +) -> Array: """For finite shot sampling we need to calculate the eigenvalues/vectors of an observable. This helper function takes an observable and system size (n_qubits) and returns the overall action of the observable on the whole @@ -25,7 +29,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: observable.control == observable.parse_idx(none_like(observable.target)), "Controlled gates cannot be promoted from observables to operations on the whole state vector", ) - unitary = observable.unitary() + unitary = observable.unitary(values=values) target = observable.target[0][0] identity = jnp.eye(2, dtype=unitary.dtype) ops = [identity for _ in range(n_qubits)] @@ -33,9 +37,89 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) +@singledispatch +def eigen_probabilities(state: Any, eigvecs: Array) -> Array: + """Obtain the probabilities using an input state and the eigenvectors decomposition + of an observable. + + Args: + state (Any): Input. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + raise NotImplementedError( + f"eigen_probabilities is not implemented for the state type {type(state)}." + ) + + +@eigen_probabilities.register +def _(state: Array, eigvecs: Array) -> Array: + """Obtain the probabilities using an input quantum state vector + and the eigenvectors decomposition + of an observable. + + Args: + state (Array): Input array. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + return jnp.abs(inner_prod) ** 2 + + +@eigen_probabilities.register +def _(state: DensityMatrix, eigvecs: Array) -> Array: + """Obtain the probabilities using an input quantum density matrix + and the eigenvectors decomposition + of an observable. + + Args: + state (DensityMatrix): Input density matrix. + eigvecs (Array): Eigenvectors of the observables. + + Returns: + Array: The probabilities. + """ + mat_prob = jnp.conjugate(eigvecs.T) @ state.array @ eigvecs + return mat_prob.diagonal().real + + +def eigen_sample( + state: State, + observables: list[Primitive], + values: dict[str, float], + n_qubits: int, + n_shots: int, + key: Any = jax.random.PRNGKey(0), +) -> Array: + """Sample eigenvalues of observable given the probability distribution + defined by applying the eigenvectors to the state. + + Args: + state (State): Input state or density matrix. + observables (list[Primitive]): list of observables. + values (dict[str, float]): Parameter values. + n_qubits (int): Number of qubits + n_shots (int): Number of samples + key (Any, optional): Random seed key. Defaults to jax.random.PRNGKey(0). + + Returns: + Array: Sampled eigenvalues. + """ + mat_obs = [to_matrix(observable, n_qubits, values) for observable in observables] + eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] + eigvecs, eigvals = align_eigenvectors(eigs) + probs = eigen_probabilities(state, eigvecs) + return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + + @partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) def finite_shots_fwd( - state: Array, + state: State, gates: GateSequence, observables: list[Primitive], values: dict[str, float], @@ -46,14 +130,12 @@ def finite_shots_fwd( Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - state = apply_gate(state, gates, values) - n_qubits = len(state.shape) - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] - eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] - eigvecs, eigvals = align_eigenvectors(eigs) - inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) - probs = jnp.abs(inner_prod) ** 2 - return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + output_gates = apply_gate(state, gates, values) + n_qubits = num_qubits(output_gates) + if isinstance(state, DensityMatrix): + d = 2**n_qubits + output_gates.array = output_gates.array.reshape((d, d)) + return eigen_sample(output_gates, observables, values, n_qubits, n_shots, key) def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: diff --git a/horqrux/utils.py b/horqrux/utils.py index 2cc82f4..1b04833 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -1,12 +1,16 @@ from __future__ import annotations +from collections import Counter +from dataclasses import dataclass from enum import Enum -from typing import Any, Iterable, Tuple, Union +from functools import singledispatch +from typing import Any, Iterable, Union import jax import jax.numpy as jnp import numpy as np from jax import Array +from jax.tree_util import register_pytree_node_class from jax.typing import ArrayLike from numpy import log2 @@ -14,13 +18,73 @@ default_dtype = default_complex_dtype() -State = ArrayLike -QubitSupport = Tuple[Any, ...] -ControlQubits = Tuple[Union[None, Tuple[int, ...]], ...] -TargetQubits = Tuple[Tuple[int, ...], ...] +QubitSupport = tuple[Any, ...] +ControlQubits = tuple[Union[None, tuple[int, ...]], ...] +TargetQubits = tuple[tuple[int, ...], ...] ATOL = 1e-014 +@register_pytree_node_class +@dataclass +class DensityMatrix: + """Dataclass to identify density matrices from states.""" + + array: Array + + def tree_flatten(self) -> tuple[tuple, tuple[Array]]: + children = () + aux_data = (self.array,) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + +State = Union[ArrayLike, DensityMatrix] + + +def density_mat(state: ArrayLike) -> DensityMatrix: + """Convert state to density matrix + + Args: + state (ArrayLike): Input state. + + Returns: + DensityMatrix: Density matrix representation. + """ + # Expand dimensions to enable broadcasting + if isinstance(state, DensityMatrix): + return state + ket = jnp.expand_dims(state, axis=tuple(range(state.ndim, 2 * state.ndim))) + bra = jnp.conj(jnp.expand_dims(state, axis=tuple(range(state.ndim)))) + return DensityMatrix(ket * bra) + + +def permute_basis(operator: Array, qubit_support: tuple, inv: bool = False) -> Array: + """Takes an operator tensor and permutes the rows and + columns according to the order of the qubit support. + + Args: + operator (Tensor): Operator to permute over. + qubit_support (tuple): Qubit support. + inv (bool): Applies the inverse permutation instead. + + Returns: + Tensor: Permuted operator. + """ + ordered_support = np.argsort(qubit_support) + ranked_support = np.argsort(ordered_support) + n_qubits = len(qubit_support) + if all(a == b for a, b in zip(ranked_support, tuple(range(n_qubits)))): + return operator + + perm = tuple(ranked_support) + tuple(ranked_support + n_qubits) + if inv: + perm = np.argsort(perm) + return jax.lax.transpose(operator, perm) + + class StrEnum(str, Enum): def __str__(self) -> str: """Used when dumping enum fields in a schema.""" @@ -57,7 +121,18 @@ class ForwardMode(StrEnum): def _dagger(operator: Array) -> Array: - return jnp.conjugate(operator.T) + # If the operator is a tensor with repeated 2D axes + if operator.ndim > 2: + # Conjugate and swap the last two axes + conjugated = operator.conj() + + # Create the transpose axes: swap pairs of indices + half = operator.ndim // 2 + axes = tuple(range(half, operator.ndim)) + tuple(range(half)) + return jnp.transpose(conjugated, axes) + else: + # For standard matrices, use conjugate transpose + return jnp.conjugate(operator.T) def _unitary(generator: Array, theta: float) -> Array: @@ -101,14 +176,14 @@ def zero_state(n_qubits: int) -> Array: return product_state("0" * n_qubits) -def none_like(x: Iterable) -> Tuple[None, ...]: +def none_like(x: Iterable) -> tuple[None, ...]: """Generates a tuple of Nones with equal length to x. Useful for gates with multiple targets but no control. Args: x (Iterable): Iterable to be mimicked. Returns: - Tuple[None, ...]: Tuple of Nones of length x. + tuple[None, ...]: tuple of Nones of length x. """ return tuple(map(lambda _: None, x)) @@ -156,7 +231,7 @@ def uniform_state( return state.reshape([2] * n_qubits) -def is_controlled(qubit_support: Tuple[int | None, ...] | int | None) -> bool: +def is_controlled(qubit_support: Union[tuple[Union[int, None], ...], int, None]) -> bool: if isinstance(qubit_support, int): return True elif isinstance(qubit_support, tuple): @@ -180,3 +255,70 @@ def _normalize(wf: Array) -> Array: def is_normalized(state: Array) -> bool: return equivalent_state(state, state) + + +def sample_from_probs(probs: Array, n_qubits: int, n_shots: int) -> Counter: + key = jax.random.PRNGKey(0) + + # JAX handles pseudo random number generation by tracking an explicit state via a random key + # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html + samples = jax.vmap( + lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) + )(jax.random.split(key, n_shots)) + + return Counter( + { + format(k, f"0{n_qubits}b"): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) + + +@singledispatch +def probabilities(state: Any) -> Array: + """Extract probabilities from state or density matrix. + + Args: + state (Array): Input array. + + Raises: + NotImplementedError: If not implemented for given types. + + Returns: + Array: Vector of probabilities. + """ + raise NotImplementedError(f"Probabilities is not implemented for the input type {type(state)}.") + + +@probabilities.register +def _(state: Array) -> Array: + return jnp.abs(jnp.float_power(state, 2.0)).ravel() + + +@probabilities.register +def _(state: DensityMatrix) -> Array: + return jnp.diagonal(state.array).real + + +@singledispatch +def num_qubits(state: Any) -> int: + """Returns the number of qubits of a state. + + Args: + state (Any): state. + + Returns: + int: Number of qubits. + """ + raise NotImplementedError(f"num_qubits is not implemented for the state type {type(state)}.") + + +@num_qubits.register +def _(state: Array) -> int: + return len(state.shape) + + +@num_qubits.register +def _(state: DensityMatrix) -> int: + return len(state.array.shape) // 2 diff --git a/horqrux/utils_noise.py b/horqrux/utils_noise.py new file mode 100644 index 0000000..ce10ff2 --- /dev/null +++ b/horqrux/utils_noise.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import jax.numpy as jnp +from jax import Array + +from .matrices import OPERATIONS_DICT + + +def BitFlip(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the BitFlip gate. + + The bit flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} + + Args: + error_probability (float): The probability of a bit flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + if (error_probability > 1.0) or (error_probability < 0.0): + raise ValueError(f"The 'error_probability' value is incorrect. Got {error_probability}.") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["X"] + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def PhaseFlip(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the PhaseFlip gate + + The phase flip channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + p Z \\rho Z^{\\dagger} + + Args: + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if (error_probability > 1.0) or (error_probability < 0.0): + raise ValueError("The error_probability value is not a correct probability") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def Depolarizing(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the Depolarizing gate. + + The depolarizing channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-p) \\rho + + p/3 X \\rho X^{\\dagger} + + p/3 Y \\rho Y^{\\dagger} + + p/3 Z \\rho Z^{\\dagger} + + Args: + error_probability (float): The probability of phase flip error. + + Raises: + ValueError: If the error_probability value is not a float. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if (error_probability > 1.0) or (error_probability < 0.0): + raise ValueError(f"The 'error_probability' value is incorrect. Got {error_probability}.") + K0: Array = jnp.sqrt(1.0 - error_probability) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(error_probability / 3.0) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus + + +def PauliChannel(error_probability: tuple[float, ...]) -> tuple[Array, ...]: + """ + Initialize the PauliChannel gate. + + The pauli channel is defined as: + + .. math:: + \\rho \\Rightarrow (1-px-py-pz) \\rho + + px X \\rho X^{\\dagger} + + py Y \\rho Y^{\\dagger} + + pz Z \\rho Z^{\\dagger} + + Args: + error_probability (tuple[float, ...] | float): tuple containing probabilities + of X, Y, and Z errors. + + Raises: + ValueError: If the probabilities values do not sum up to 1. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + sum_prob = sum(error_probability) + if sum_prob > 1.0: + raise ValueError("The sum of probabilities can't be greater than 1.0") + if any([probability > 1.0 or probability < 0.0 for probability in error_probability]): + raise ValueError(f"The 'error_probability' values are incorrect. Got {error_probability}.") + px, py, pz = ( + error_probability[0], + error_probability[1], + error_probability[2], + ) + + K0: Array = jnp.sqrt(1.0 - (px + py + pz)) * OPERATIONS_DICT["I"] + K1: Array = jnp.sqrt(px) * OPERATIONS_DICT["X"] + K2: Array = jnp.sqrt(py) * OPERATIONS_DICT["Y"] + K3: Array = jnp.sqrt(pz) * OPERATIONS_DICT["Z"] + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus + + +def AmplitudeDamping(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the AmplitudeDamping gate. + + The amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - error_probability)]] + K1 = [[0, sqrt(error_probability)], [0, 0]] + + Args: + error_probability (float): The damping rate, indicating the probability of amplitude loss. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if (error_probability > 1.0) or (error_probability < 0.0): + raise ValueError(f"The 'error_probability' value is incorrect. Got {error_probability}.") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - error_probability)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, jnp.sqrt(error_probability)], [0, 0]], dtype=jnp.complex128) + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def PhaseDamping(error_probability: float) -> tuple[Array, ...]: + """ + Initialize the PhaseDamping gate. + + The phase damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + with: + + .. code-block:: python + + K0 = [[1, 0], [0, sqrt(1 - error_probability)]] + K1 = [[0, 0], [0, sqrt(error_probability)]] + + Args: + error_probability (float): The damping rate, indicating the probability of phase damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + if (error_probability > 1.0) or (error_probability < 0.0): + raise ValueError(f"The 'error_probability' value is incorrect. Got {error_probability}.") + K0: Array = jnp.array([[1, 0], [0, jnp.sqrt(1 - error_probability)]], dtype=jnp.complex128) + K1: Array = jnp.array([[0, 0], [0, jnp.sqrt(error_probability)]], dtype=jnp.complex128) + kraus: tuple[Array, ...] = (K0, K1) + return kraus + + +def GeneralizedAmplitudeDamping(error_probability: tuple[float, ...]) -> tuple[Array, ...]: + """ + Initialize the GeneralizeAmplitudeDamping gate. + + The generalize amplitude damping channel is defined as: + + .. math:: + \\rho \\Rightarrow K_0 \\rho K_0^{\\dagger} + K_1 \\rho K_1^{\\dagger} + + K_2 \\rho K_2^{\\dagger} + K_3 \\rho K_3^{\\dagger} + + with: + + .. code-block:: python + + K0 = sqrt(p) * [[1, 0], [0, sqrt(1 - rate)]] + K1 = sqrt(p) * [[0, sqrt(rate)], [0, 0]] + K2 = sqrt(1-p) * [[sqrt(1 - rate), 0], [0, 1]] + K3 = sqrt(1-p) * [[0, 0], [sqrt(rate), 0]] + + Args: + error_probability (tuple[float, ...] | float): The first float must be the probability + of amplitude damping error, and the second float is the damping rate, indicating + the probability of generalized amplitude damping. + + Raises: + ValueError: If the damping rate is not a correct probability. + + Returns: + tuple[Array, ...]: Kraus operators for this protocol. + """ + + probability = error_probability[0] + rate = error_probability[1] + if (probability > 1.0) or (probability < 0.0): + raise ValueError( + f"The first value of 'error_probability' value is incorrect. Got {probability}." + ) + if (rate > 1.0) or (rate < 0.0): + raise ValueError(f"The second value of 'error_probability' value is incorrect. Got {rate}.") + + K0: Array = jnp.sqrt(probability) * jnp.array( + [[1, 0], [0, jnp.sqrt(1 - rate)]], dtype=jnp.complex128 + ) + K1: Array = jnp.sqrt(probability) * jnp.array( + [[0, jnp.sqrt(rate)], [0, 0]], dtype=jnp.complex128 + ) + K2: Array = jnp.sqrt(1.0 - probability) * jnp.array( + [[jnp.sqrt(1.0 - rate), 0], [0, 1]], dtype=jnp.complex128 + ) + K3: Array = jnp.sqrt(1.0 - probability) * jnp.array( + [[0, 0], [jnp.sqrt(rate), 0]], dtype=jnp.complex128 + ) + kraus: tuple[Array, ...] = (K0, K1, K2, K3) + return kraus diff --git a/mkdocs.yml b/mkdocs.yml index e3ac86a..0f214f4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -6,6 +6,8 @@ nav: - horqrux in a nutshell: index.md - Contribute: CONTRIBUTING.md - Code of Conduct: CODE_OF_CONDUCT.md + - Advanced Features: + - Noisy simulation: noise.md theme: name: material diff --git a/pyproject.toml b/pyproject.toml index 13dcd50..16819bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "horqrux" -description = "Jax-based quantum state vector simulator." +description = "Jax-based quantum state vector and noisy simulator." authors = [ { name = "Gert-Jan Both" , email = "gert-jan.both@pasqal.com" }, { name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" }, @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.8,<3.13" license = {text = "Apache 2.0"} -version = "0.6.2" +version = "0.7.0" classifiers=[ "License :: Other/Proprietary License", diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 27f0164..2df8bdd 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -4,8 +4,7 @@ import numpy as np from jax import Array, grad -from horqrux import random_state -from horqrux.circuit import expectation +from horqrux import expectation, random_state from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z from horqrux.utils import DiffMode @@ -27,7 +26,7 @@ def test_gradcheck() -> None: state = random_state(MAX_QUBITS) def exp_fn(values: dict, diff_mode: DiffMode = "ad") -> Array: - return expectation(state, ops, observable, values, diff_mode) + return expectation(state, ops, observable, values, diff_mode).item() grads_adjoint = grad(exp_fn)(values, "adjoint") grad_ad = grad(exp_fn)(values) diff --git a/tests/test_gates.py b/tests/test_gates.py index 4c44cd3..afb737d 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -10,7 +10,7 @@ from horqrux.apply import apply_gate, apply_operator from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z -from horqrux.utils import equivalent_state, product_state, random_state +from horqrux.utils import density_mat, equivalent_state, product_state, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -22,11 +22,21 @@ def test_primitive(gate_fn: Callable) -> None: target = np.random.randint(0, MAX_QUBITS) gate = gate_fn(target) orig_state = random_state(MAX_QUBITS) + assert len(orig_state) == 2 state = apply_gate(orig_state, gate) assert jnp.allclose( apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(), + gate.target[0], + gate.control[0], + ) + assert jnp.allclose(dm.array, density_mat(state).array) + @pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) def test_controlled_primitive(gate_fn: Callable) -> None: @@ -41,6 +51,15 @@ def test_controlled_primitive(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(), + gate.target[0], + gate.control[0], + ) + assert jnp.allclose(dm.array, density_mat(state).array) + @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) def test_parametric(gate_fn: Callable) -> None: @@ -53,6 +72,15 @@ def test_parametric(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(values), + gate.target[0], + gate.control[0], + ) + assert jnp.allclose(dm.array, density_mat(state).array) + @pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) def test_controlled_parametric(gate_fn: Callable) -> None: @@ -68,6 +96,15 @@ def test_controlled_parametric(gate_fn: Callable) -> None: apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state ) + # test density matrix is similar to pure state + dm = apply_operator( + density_mat(orig_state), + gate.unitary(values), + gate.target[0], + gate.control[0], + ) + assert jnp.allclose(dm.array, density_mat(state).array) + @pytest.mark.parametrize( ["bitstring", "expected_state"], diff --git a/tests/test_noise.py b/tests/test_noise.py new file mode 100644 index 0000000..5d6d6b7 --- /dev/null +++ b/tests/test_noise.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from typing import Callable + +import jax.numpy as jnp +import numpy as np +import pytest + +from horqrux.api import expectation, run, sample +from horqrux.apply import apply_gate +from horqrux.noise import DigitalNoiseInstance, DigitalNoiseType +from horqrux.parametric import PHASE, RX, RY, RZ +from horqrux.primitive import NOT, H, I, S, T, X, Y, Z +from horqrux.utils import ForwardMode, density_mat, product_state, random_state + +MAX_QUBITS = 7 +PARAMETRIC_GATES = (RX, RY, RZ, PHASE) +PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) + +NOISE_single_prob = ( + DigitalNoiseType.BITFLIP, + DigitalNoiseType.PHASEFLIP, + DigitalNoiseType.DEPOLARIZING, + DigitalNoiseType.AMPLITUDE_DAMPING, + DigitalNoiseType.PHASE_DAMPING, +) +ALL_NOISES = list(DigitalNoiseType) + + +def noise_instance(noise_type: DigitalNoiseType) -> DigitalNoiseInstance: + if noise_type in NOISE_single_prob: + errors = 0.1 + elif noise_type == DigitalNoiseType.PAULI_CHANNEL: + errors = (0.4, 0.5, 0.1) + else: + errors = (0.2, 0.8) + + return DigitalNoiseInstance(noise_type, error_probability=errors) + + +@pytest.mark.parametrize("noise_type", NOISE_single_prob) +def test_error_prob(noise_type: DigitalNoiseType): + with pytest.raises(ValueError): + noise = DigitalNoiseInstance(noise_type, error_probability=-0.5).kraus + with pytest.raises(ValueError): + noise = DigitalNoiseInstance(noise_type, error_probability=1.1).kraus + + +def test_error_paulichannel(): + with pytest.raises(ValueError): + noise = DigitalNoiseInstance( + DigitalNoiseType.PAULI_CHANNEL, error_probability=(0.4, 0.5, 1.1) + ).kraus + + for p in range(3): + probas = [1.0 / 3.0] * 3 + probas[p] = -0.1 + with pytest.raises(ValueError): + noise = DigitalNoiseInstance( + DigitalNoiseType.PAULI_CHANNEL, error_probability=probas + ).kraus + + probas = [0.0] * 3 + probas[p] = 1.1 + with pytest.raises(ValueError): + noise = DigitalNoiseInstance( + DigitalNoiseType.PAULI_CHANNEL, error_probability=probas + ).kraus + + +def test_error_prob_GeneralizedAmplitudeDamping(): + for p in range(2): + probas = [1.0 / 2.0] * 2 + probas[p] = -0.1 + with pytest.raises(ValueError): + noise = DigitalNoiseInstance( + DigitalNoiseType.GENERALIZED_AMPLITUDE_DAMPING, error_probability=probas + ).kraus + + probas = [0.0] * 2 + probas[p] = 1.1 + with pytest.raises(ValueError): + noise = DigitalNoiseInstance( + DigitalNoiseType.GENERALIZED_AMPLITUDE_DAMPING, error_probability=probas + ).kraus + + +@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES) +@pytest.mark.parametrize("noise_type", ALL_NOISES) +def test_noisy_primitive(gate_fn: Callable, noise_type: DigitalNoiseType) -> None: + target = np.random.randint(0, MAX_QUBITS) + noise = noise_instance(noise_type) + + noisy_gate = gate_fn(target, noise=(noise,)) + assert len(noisy_gate.noise) == 1 + + dm_shape_len = 2 * MAX_QUBITS + + orig_state = random_state(MAX_QUBITS) + output_dm = apply_gate(orig_state, noisy_gate) + + # check output is a density matrix + assert len(output_dm.array.shape) == dm_shape_len + + orig_dm = density_mat(orig_state) + assert len(orig_dm.array.shape) == dm_shape_len + output_dm2 = apply_gate( + orig_dm, + noisy_gate, + ) + assert jnp.allclose(output_dm2.array, output_dm.array) + + perfect_gate = gate_fn(target) + perfect_output = density_mat(apply_gate(orig_state, perfect_gate)) + assert not jnp.allclose(perfect_output.array, output_dm.array) + + +@pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES) +@pytest.mark.parametrize("noise_type", ALL_NOISES) +def test_noisy_parametric(gate_fn: Callable, noise_type: DigitalNoiseType) -> None: + target = np.random.randint(0, MAX_QUBITS) + noise = noise_instance(noise_type) + noisy_gate = gate_fn("theta", target, noise=(noise,)) + values = {"theta": np.random.uniform(0.1, 2 * np.pi)} + orig_state = random_state(MAX_QUBITS) + + dm_shape_len = 2 * MAX_QUBITS + + output_dm = apply_gate(orig_state, noisy_gate, values) + # check output is a density matrix + assert len(output_dm.array.shape) == dm_shape_len + + orig_dm = density_mat(orig_state) + assert len(orig_dm.array.shape) == dm_shape_len + + output_dm2 = apply_gate( + orig_dm, + noisy_gate, + values, + ) + assert jnp.allclose(output_dm2.array, output_dm.array) + + perfect_gate = gate_fn("theta", target) + perfect_output = density_mat(apply_gate(orig_state, perfect_gate, values)) + assert not jnp.allclose(perfect_output.array, output_dm.array) + + +def simple_depolarizing_test() -> None: + noise = (DigitalNoiseInstance(DigitalNoiseType.DEPOLARIZING, 0.1),) + ops = [X(0, noise=noise), X(1)] + state = product_state("00") + state_output = run(ops, state) + + # test run + assert jnp.allclose( + state_output, + jnp.array( + [ + [ + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + [[0.0 - 0.0j, 0.06666667 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + ], + [ + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.0 - 0.0j]], + [[0.0 - 0.0j, 0.0 - 0.0j], [0.0 - 0.0j, 0.93333333 - 0.0j]], + ], + ], + dtype=jnp.complex128, + ), + ) + + # test sampling + dm_state = density_mat(state) + sampling_output = sample( + dm_state, + ops, + ) + assert "11" in sampling_output.keys() + assert "01" in sampling_output.keys() + + # test expectation + exp_dm = expectation(dm_state, ops, [Z(0)], {}) + assert jnp.allclose(exp_dm, jnp.array([-0.86666667], dtype=jnp.float64)) + + # test shots expectation + exp_dm_shots = expectation( + dm_state, ops, [Z(0)], {}, forward_mode=ForwardMode.SHOTS, n_shots=1000 + ) + assert jnp.allclose(exp_dm, exp_dm_shots, atol=1e-02) diff --git a/tests/test_shots.py b/tests/test_shots.py index c98062d..5660610 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -3,9 +3,10 @@ import jax import jax.numpy as jnp -from horqrux import expectation, random_state +from horqrux import expectation, random_state, run from horqrux.parametric import RX from horqrux.primitive import Z +from horqrux.utils import density_mat N_QUBITS = 2 SHOTS_ATOL = 0.01 @@ -20,16 +21,43 @@ def test_shots() -> None: def exact(x): values = {"theta": x} - return expectation(state, ops, observables, values, "ad") + return expectation(state, ops, observables, values, diff_mode="ad") + + def exact_dm(x): + values = {"theta": x} + return expectation(density_mat(state), ops, observables, values, diff_mode="ad") def shots(x): values = {"theta": x} - return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) + return expectation( + state, ops, observables, values, diff_mode="gpsr", forward_mode="shots", n_shots=N_SHOTS + ) + + def shots_dm(x): + values = {"theta": x} + return expectation( + density_mat(state), + ops, + observables, + values, + diff_mode="gpsr", + forward_mode="shots", + n_shots=N_SHOTS, + ) + + expected_dm = density_mat(run(ops, state, {"theta": x})) + output_dm = run(ops, density_mat(state), {"theta": x}) + assert jnp.allclose(expected_dm.array, output_dm.array) exp_exact = exact(x) + exp_exact_dm = exact_dm(x) + assert jnp.allclose(exp_exact, exp_exact_dm) + exp_shots = shots(x) + exp_shots_dm = shots_dm(x) assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) + assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL) d_exact = jax.grad(lambda x: exact(x).sum()) d_shots = jax.grad(lambda x: shots(x).sum())