Skip to content

Commit

Permalink
add spectral gap via gates
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Mar 9, 2025
1 parent 12361d7 commit 14abf0b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
6 changes: 2 additions & 4 deletions horqrux/differentiation/gpsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ def finite_shots_jvp(
values = primals[0]
tangent_dict = tangents[0]

# TODO: compute spectral gap through the generator which is associated with
# a param name.
spectral_gap = 2.0
shift = jnp.pi / 2

Expand Down Expand Up @@ -274,6 +272,7 @@ def jvp_component_repeated_param(param_name: str, key: Array) -> Array:

def shift_jvp(ind: int, key: Array) -> Array:
up_key, down_key = random.split(key)
spectral_gap = gates[ind].spectral_gap # type: ignore[index]
gates_up = alter_gate_sequence(gates, ind, shift)
f_up = finite_shots_fwd(
state, gates_up, observable, values, n_shots, up_key
Expand Down Expand Up @@ -309,8 +308,6 @@ def no_shots_fwd_jvp(
values = primals[0]
tangent_dict = tangents[0]

# TODO: compute spectral gap through the generator which is associated with
# a param name.
spectral_gap = 2.0
shift = jnp.pi / 2

Expand Down Expand Up @@ -344,6 +341,7 @@ def jvp_component_repeated_param(param_name: str) -> Array:
shift_gates = param_to_gates_indices[param_name]

def shift_jvp(ind: int) -> Array:
spectral_gap = gates[ind].spectral_gap # type: ignore[index]
gates_up = alter_gate_sequence(gates, ind, shift)
f_up = no_shots_fwd(state, gates_up, observable, values)
gates_down = alter_gate_sequence(gates, ind, -shift)
Expand Down
35 changes: 35 additions & 0 deletions horqrux/primitives/parametric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import Any, Iterable

import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import register_pytree_node_class
Expand All @@ -22,6 +24,8 @@
from .primitive import Primitive

default_dtype = default_complex_dtype()
nonzero_jit = jax.jit(jnp.nonzero, static_argnames="size")
unique_jit = jax.jit(jnp.unique, static_argnames="size")


@register_pytree_node_class
Expand Down Expand Up @@ -88,6 +92,37 @@ def __repr__(self) -> str:
+ f"(target={self.target}, control={self.control}, param={self.param}, shift={self.shift})"
)

@cached_property
def eigenvals_generator(self) -> Array:
"""Get eigenvalues of the underlying operation.
Arguments:
values: Parameter values.
Returns:
Array: Eigenvalues of the operation.
"""
eig_vals_generator = jnp.linalg.eigvalsh(OPERATIONS_DICT[self.generator_name])
if is_controlled(self.control):
eig_vals_generator = jnp.concatenate(
(jnp.zeros(2 ** (len(self.control))), eig_vals_generator)
)
return eig_vals_generator

@cached_property
def spectral_gap(self) -> Array:
"""Difference between the moduli of the two largest eigenvalues of the generator.
Returns:
Array: Spectral gap value.
"""
spectrum = jnp.atleast_2d(self.eigenvals_generator)
diffs = spectrum - spectrum.T
# note for jitting, must specify a size
# atm only size 2 is acceptable given all possible generators in OPERATIONS_DICT
spectral_gap = unique_jit(jnp.abs(jnp.tril(diffs)), size=2)
return spectral_gap[nonzero_jit(spectral_gap, size=1)]


def RX(
param: float | str,
Expand Down

0 comments on commit 14abf0b

Please # to comment.