Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

QuTiPv5 Paper Notebook: JAX #116

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 331 additions & 0 deletions tutorials-v5/miscellaneous/v5_paper-jax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
---
jupyter:
jupytext:
text_representation:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.13.8
kernelspec:
display_name: qutip-tutorials-v5
language: python
name: python3
---

# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differnetiation

Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Rochisha Agarwal (rochisha.agarwal2302@gmail.com), Neill Lambert (nwlambert@gmail.com)

For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
QuTiP’s flexible data layer can directly be used with many such libraries and thereby drastically reduce computation time.
Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration [\[1\]](#References) due to JAX's robust auto-differentiation features and widespread adoption in machine learning.

In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]](#References) using the QuTiP-JAX package.
As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation.
In the second part, we focus on the auto-differentiation capabilities.
For this we first consider the counting statistics of an open quantum system connected to an environment.
Lastly, we look at a driven qubit system and computing the gradient of its state population in respect to the drive's frequency using `mcsolve()`.

## Introduction

In addition to the standard QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed.
This package also comes with `jax.numpy` which mirrors all of `numpy`s functionality for seamless integration with JAX.

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qutip_jax as qj
from diffrax import PIDController, Tsit5
from jax import default_device, devices, grad, jacfwd, jacrev, jit
from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator,
liouvillian, mcsolve, mesolve, projection, qeye, settings,
sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost,
steadystate, tensor)

%matplotlib inline
```

An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jax_dia`.
They allow for dense (`jax`) and custom sparse (`jax_dia`) formats.

```python
print(qeye(3, dtype="jax").dtype.__name__)
print(qeye(3, dtype="jaxdia").dtype.__name__)
```

In order to use the JAX data layer also within the master equation solver, there are two settings we can choose.
First, is simply adding `method: diffrax` to the options parameter of the solver.
Second, is to use the `qutip_jax.set_as_default()` method.
It automatically switches all data data types to JAX compatible versions and sets the default solver method to `diffrax`.

```python
qj.set_as_default()
```

To revert this setting, we can set the function parameter `revert = True`.

```python
# qj.set_as_default(revert = True)
```

## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain

Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem.
GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential.
However, as it is illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX becomes beneficial.

### 1D Ising Spin Chain

To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian

$H = \sum_{i=1}^N g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$.

We hereby consider $N$ spins that share an energy splitting of $g_0$ and have a coupling strength $J_0$.
The end of the chain connects to an environment described by a Lindbladian dissipator which we model with the collapse operator $\sigma_x^{(N-1)}$ and coupling rate $\gamma$.

As part of the [QuTiPv5 paper](#References), we see an extensive study on the computation time depending on the dimensionality $N$.
In this example we cannot replicate the performance of a supercomputer of course, so we rather focus on the correct implementation to solve the Lindblad equation for this system using JAX and `mesolve()`.

```python
# system parameters
N = 4 # number of spins
g0 = 1 # energy splitting
J0 = 1.4 # coupling strength
gamma = 0.1 # dissipation rate

# simulation parameters
tlist = jnp.linspace(0, 5, 100)
opt = {
"normalize_output": False,
"store_states": True,
"method": "diffrax",
"stepsize_controller": PIDController(
rtol=settings.core["rtol"], atol=settings.core["atol"]
),
"solver": Tsit5(),
}
```

```python
with CoreOptions(default_dtype="CSR"):
# Operators for individual qubits
sx_list, sy_list, sz_list = [], [], []
for i in range(N):
op_list = [qeye(2)] * N
op_list[i] = sigmax()
sx_list.append(tensor(op_list))
op_list[i] = sigmay()
sy_list.append(tensor(op_list))
op_list[i] = sigmaz()
sz_list.append(tensor(op_list))

# Hamiltonian - Energy splitting terms
H = 0.0
for i in range(N):
H += g0 * sz_list[i]

# Interaction terms
for n in range(N - 1):
H += -J0 * sx_list[n] * sx_list[n + 1]

# Collapse operator acting locally on single spin
c_ops = [gamma * sx_list[N - 1]]

# Initial state
state_list = [basis(2, 1)] * (N - 1)
state_list.append(basis(2, 0))
psi0 = tensor(state_list)

result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt)
```

```python
for i, s in enumerate(result.expect):
plt.plot(tlist, s, label=rf"$n = {i+1}$")
plt.xlabel("Time")
plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$")
plt.legend()
plt.show()
```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add some funky animation plot!

%matplotlib notebook
from IPython.display import HTML
ax, ani = qt.anim_matrix_histogram(result_ising)
HTML(ani.to_jshtml(fps=15))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it normal that this takes >20min to compute (on a normal-ish laptop)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ouch. i tried a different example the other day and it was fairly slow too. we could reduce the number of points in tlist, but perhaps just better not to add here!

Copy link
Member

@BoxiLi BoxiLi Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my experience, a long tlist usually causes a long compilation time (not computation time!), because we have a for loop in the qutip base solver. This for loop will be flattened by jax compilation. So each step in the for loop is compiled separately. This can be speeded up by rewriting it with a JAX loop. But I don't know how hard it is because in each step we add something to the result class.

## Auto-Differentiation

We have seen in the previous example how the new JAX data-layer in QuTiP works.
On top of that, JAX adds the features of auto-differentiation.
To compute derivatives, it is often numerical approximations (e.g., finite difference method) that need to be employed.
Especially for higher order derivatives, these methods can turn into costly and inaccurate calculations.

Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives.
The idea is that any numerical function can be expressed by elementary analytical functions and operations.
Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible.

Although there are many applications for this technique, in this chapter we want to focus on two examples where auto-differentiation becomes relevant.

### Statistics of Excitations between Quantum System and Environment

We consider an open quantum system that is in contact with an evironment via a single jump operator.
Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment.
The probability distribution that describes the number of such exchanged excitations $n$ in a certain time $t$ is called the full counting statistics and denoted by $P_n(t)$.
This statistics is a defining property that allows to derive many experimental observables like shot noise or current.

For the example here, we can calculate this statistics by using a modified version of the density operator and Lindblad master equation.
We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$ with $\rho^n(t)$ being the density operator of the system conditioned on $n$ exchanges by time $t$, so $\text{Tr}[\rho^n(t)] = P_n(t)$.
The master equation for this operator, including the jump operator $C$, is then given as

$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} [2 e^z C \rho(t)C^\dagger - \rho C^\dagger C - C^\dagger C \rho(t)]$.

We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$.
However, it also allows us to describe the counting statistics through its derivatives

$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$.

These derivatives are precisely where the auto-differention by JAX finds its application for us.

```python
# system parameters
ed = 1
GammaL = 1
GammaR = 1

# simulation parameters
options = {
"method": "diffrax",
"normalize_output": False,
"stepsize_controller": PIDController(rtol=1e-7, atol=1e-7),
"solver": Tsit5(scan_kind="bounded"),
"progress_bar": False,
}
```

When working with JAX you can choose the type of device / processor to be used.
In our case, we will resort to the CPU since this is a simple Jupyter Notebook.
However, when running this on your machine, you can opt for using your GPU by simpy changing the argument below.

```python
with default_device(devices("cpu")[0]):
with CoreOptions(default_dtype="jaxdia"):
d = destroy(2)
H = ed * d.dag() * d
c_op_L = jnp.sqrt(GammaL) * d.dag()
c_op_R = jnp.sqrt(GammaR) * d

L0 = (
liouvillian(H)
+ lindblad_dissipator(c_op_L)
- 0.5 * spre(c_op_R.dag() * c_op_R)
- 0.5 * spost(c_op_R.dag() * c_op_R)
)
L1 = sprepost(c_op_R, c_op_R.dag())

rho0 = steadystate(L0 + L1)

def rhoz(t, z):
L = L0 + jnp.exp(z) * L1 # jump term
tlist = jnp.linspace(0, t, 50)
result = mesolve(L, rho0, tlist, options=options)
return result.final_state.tr()

# first derivative
drhozdz = jacrev(rhoz, argnums=1)
# second derivative
d2rhozdz = jacfwd(drhozdz, argnums=1)
```

```python
tf = 100
Itest = GammaL * GammaR / (GammaL + GammaR)
shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2)
ncurr = drhozdz(tf, 0.0) / tf
nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf

print("===== RESULTS =====")
print("Analytic current", Itest)
print("Numerical current", ncurr)
print("Analytical shot noise (2nd cumulant)", shottest)
print("Numerical shot noise (2nd cumulant)", nshot)
```

### Driven One Qubit System & Frequency Optimization

As a second example for auto differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian

$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$

with the energy splitting $\omega_0$, $\Omega$ as the Rabi frequency, the drive frequency $\omega$ and $\sigma_{x/z}$ are Pauli matrices.
When we add dissipation to the system, the dynamics is given by the Lindblad master equation, which introduces collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation.

For this example, we are interested in the population of the excited state of the qubit

$P_e(t) = \bra{e} \rho(t) \ket{e}$

and its gradient with respect to the frequency $\omega$.

We want to optimize this quantity by adjusting the drive frequency $\omega$.
To achieve this, we compute the gradient of $P_e(t)$ in respect to $\omega$ by using JAX's auto-differentiation tools and QuTiP's `mcsolve()`.

```python
# system parameters
gamma = 0.1 # dissipation rate
```

```python
# time dependent drive
@jit
def driving_coeff(t, omega):
return jnp.cos(omega * t)


# system Hamiltonian
def setup_system():
H_0 = sigmaz()
H_1 = sigmax()
H = [H_0, [H_1, driving_coeff]]
return H
```

```python
# simulation parameters
psi0 = basis(2, 0)
tlist = jnp.linspace(0.0, 10.0, 100)
c_ops = [jnp.sqrt(gamma) * sigmam()]
e_ops = [projection(2, 1, 1)]
```

```python
# Objective function: returns final exc. state population
def f(omega):
H = setup_system()
arg = {"omega": omega}
result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg)
return result.expect[0][-1]
```

```python
# Compute gradient of the exc. state population w.r.t. omega
grad_f = grad(f)(2.0)
```

## References




[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax)

[2] [QuTiP 5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705)


## About

```python
about()
```

## Testing

```python
assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates"
assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates."
```
Loading