Skip to content

Commit

Permalink
feat: Small update to the Python API (#76)
Browse files Browse the repository at this point in the history
This PR contains a few small updates to the Python API for convenience.
Essentially it simply exposes the following methods:
* `VecGraph`:
    * `.adjoint`
    * `.plug`
    * `.clone`
* `Decomposer`:
    * `.done`
    * `.save`
    * `.decomp_parallel`

---------

Co-authored-by: Rafael Haenel <rhaenel@photonic.com>
  • Loading branch information
ColbyDeLisle and Rafael Haenel authored Oct 29, 2024
1 parent c13bf87 commit a7c3080
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 21 deletions.
12 changes: 3 additions & 9 deletions pybindings/quizx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from . import _quizx, simplify
from . import simplify
from .graph import VecGraph
from .circuit import Circuit
from .circuit import Circuit, extract_circuit
from .decompose import Decomposer
from ._quizx import Scalar

__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar"]


def extract_circuit(g):
c = Circuit()
c._c = _quizx.extract_circuit(g._g)
return c
__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar", "extract_circuit"]
6 changes: 6 additions & 0 deletions pybindings/quizx/_quizx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class VecGraph:
def outputs(self) -> list[int]: ...
def num_outputs(self) -> int: ...
def set_outputs(self, outputs: list[int]) -> None: ...
def adjoint(self) -> None: ...
def plug(self, other: "VecGraph") -> None: ...
def clone(self) -> "VecGraph": ...

@final
class Circuit:
Expand Down Expand Up @@ -95,10 +98,13 @@ class Decomposer:
def empty() -> Decomposer: ...
def __init__(self, g: VecGraph) -> None: ...
def graphs(self) -> list[VecGraph]: ...
def done(self) -> list[VecGraph]: ...
def save(self, b: bool) -> None: ...
def apply_optimizations(self, b: bool) -> None: ...
def max_terms(self) -> int: ...
def decomp_top(self) -> None: ...
def decomp_all(self) -> None: ...
def decomp_parallel(self, depth: int) -> None: ...
def decomp_until_depth(self, depth: int) -> None: ...
def use_cats(self, b: bool) -> None: ...
def get_nterms(self) -> int: ...
Expand Down
6 changes: 6 additions & 0 deletions pybindings/quizx/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from .graph import VecGraph


def extract_circuit(g: VecGraph) -> "Circuit":
c = Circuit()
c._c = _quizx.extract_circuit(g.get_raw_graph())
return c


class Circuit:
def __init__(self):
self._c = None
Expand Down
15 changes: 12 additions & 3 deletions pybindings/quizx/decompose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

from . import _quizx
from .graph import VecGraph
Expand All @@ -15,8 +15,14 @@ def __init__(self, graph: Optional[VecGraph] = None):
else:
self._d = _quizx.Decomposer(graph.get_raw_graph())

def graphs(self) -> List[VecGraph]:
return [VecGraph(g) for g in self._d.graphs()]
def graphs(self) -> list[VecGraph]:
return [VecGraph.from_raw_graph(g) for g in self._d.graphs()]

def done(self) -> list[VecGraph]:
return [VecGraph.from_raw_graph(g) for g in self._d.done()]

def save(self, b: bool):
self._d.save(b)

def apply_optimizations(self, b: bool):
self._d.apply_optimizations(b)
Expand All @@ -30,6 +36,9 @@ def decomp_top(self):
def decomp_all(self):
self._d.decomp_all()

def decomp_parallel(self, depth: int = 4):
self._d.decomp_parallel(depth)

def decomp_until_depth(self, depth: int):
self._d.decomp_until_depth(depth)

Expand Down
25 changes: 17 additions & 8 deletions pybindings/quizx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .scalar import from_pyzx_scalar, to_pyzx_scalar
from fractions import Fraction
from typing import Tuple, Dict, Any, Optional
from typing import Tuple, Dict, Any
from pyzx.graph.base import BaseGraph # type: ignore
from pyzx.utils import VertexType, EdgeType # type: ignore
from pyzx.graph.scalar import Scalar
Expand All @@ -32,12 +32,9 @@ class VecGraph(BaseGraph[int, Tuple[int, int]]):

# The documentation of what these methods do
# can be found in base.BaseGraph
def __init__(self, rust_graph: Optional[_quizx.VecGraph] = None):
if rust_graph:
self._g = rust_graph
else:
self._g = _quizx.VecGraph()
BaseGraph.__init__(self)
def __init__(self) -> None:
self._g = _quizx.VecGraph()
super().__init__()
self._vdata: Dict[int, Any] = dict()

def get_raw_graph(self) -> _quizx.VecGraph:
Expand Down Expand Up @@ -172,7 +169,7 @@ def vertices_in_range(self, start, end):
for v in self.vertices():
if not start < v < end:
continue
if all(start < v2 < end for v2 in self.graph[v]):
if all(start < v2 < end for v2 in self.neighbors(v)):
yield v

def edges(self):
Expand Down Expand Up @@ -328,3 +325,15 @@ def scalar(self, s: Scalar):

def is_ground(self, vertex):
return False

def adjoint(self):
self._g.adjoint()

def plug(self, other: "VecGraph"):
if other._g is self._g:
self._g.plug(other._g.clone())
else:
self._g.plug(other._g)

def clone(self) -> "VecGraph":
return VecGraph.from_raw_graph(self._g.clone())
27 changes: 27 additions & 0 deletions pybindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,18 @@ impl VecGraph {
fn set_scalar(&mut self, scalar: Scalar) {
*self.g.scalar_mut() = scalar.into();
}

fn adjoint(&mut self) {
self.g.adjoint()
}

fn plug(&mut self, other: &VecGraph) {
self.g.plug(&other.g);
}

fn clone(&self) -> VecGraph {
VecGraph { g: self.g.clone() }
}
}

#[pyclass]
Expand Down Expand Up @@ -354,6 +366,18 @@ impl Decomposer {
Ok(gs)
}

fn done(&self) -> PyResult<Vec<VecGraph>> {
let mut gs = vec![];
for g in &self.d.done {
gs.push(VecGraph { g: g.clone() });
}
Ok(gs)
}

fn save(&mut self, b: bool) {
self.d.save(b);
}

fn apply_optimizations(&mut self, b: bool) {
if b {
self.d.with_simp(quizx::decompose::SimpFunc::FullSimp);
Expand All @@ -374,6 +398,9 @@ impl Decomposer {
fn decomp_until_depth(&mut self, depth: usize) {
self.d.decomp_until_depth(depth);
}
fn decomp_parallel(&mut self, depth: usize) {
self.d = self.d.clone().decomp_parallel(depth);
}
fn use_cats(&mut self, b: bool) {
self.d.use_cats(b);
}
Expand Down
2 changes: 1 addition & 1 deletion quizx/src/json/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl JsonScalar {

// In the Clifford+T case where we have Scalar4, we can extract factors of sqrt(2) directly from the
// coefficients. Since the coefficients are reduced, sqrt(2) is represented as
// [1, 0, +-1, 0], [0, 1, +-1, 0], where the +- lead to phase contributions already extracted in `phase`
// [1, 0, +-1, 0], [0, 1, 0, +-1], where the +- lead to phase contributions already extracted in `phase`
let (power_sqrt2, floatfactor) =
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
[a, 0, b, 0] | [0, a, 0, b]
Expand Down

0 comments on commit a7c3080

Please # to comment.