-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
QuTiPv5 Paper Notebook: JAX #116
Open
Langhaarzombie
wants to merge
7
commits into
qutip:main
Choose a base branch
from
Langhaarzombie:feature/ntbk_jax
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
74f4433
First version of jax notebook for v5 paper
Langhaarzombie 5ff9708
Corrected python display name
Langhaarzombie b00fff2
Added references, basic test
Langhaarzombie df0ad30
Fixed typos, added author
Langhaarzombie 2e54134
Added v5 paper, fixed typos
Langhaarzombie dcb7d77
Merge branch 'qutip:main' into feature/ntbk_jax
Langhaarzombie a07e0e7
Corrected hardware choice so CI test succeed
Langhaarzombie File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,331 @@ | ||
--- | ||
jupyter: | ||
jupytext: | ||
text_representation: | ||
extension: .md | ||
format_name: markdown | ||
format_version: '1.3' | ||
jupytext_version: 1.13.8 | ||
kernelspec: | ||
display_name: qutip-tutorials-v5 | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differnetiation | ||
|
||
Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Rochisha Agarwal (rochisha.agarwal2302@gmail.com), Neill Lambert (nwlambert@gmail.com) | ||
|
||
For many years now, GPUs have been a fundamental tool for accelerating numerical tasks. | ||
Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations. | ||
QuTiP’s flexible data layer can directly be used with many such libraries and thereby drastically reduce computation time. | ||
Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration [\[1\]](#References) due to JAX's robust auto-differentiation features and widespread adoption in machine learning. | ||
|
||
In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]](#References) using the QuTiP-JAX package. | ||
As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation. | ||
In the second part, we focus on the auto-differentiation capabilities. | ||
For this we first consider the counting statistics of an open quantum system connected to an environment. | ||
Lastly, we look at a driven qubit system and computing the gradient of its state population in respect to the drive's frequency using `mcsolve()`. | ||
|
||
## Introduction | ||
|
||
In addition to the standard QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed. | ||
This package also comes with `jax.numpy` which mirrors all of `numpy`s functionality for seamless integration with JAX. | ||
|
||
```python | ||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
import qutip_jax as qj | ||
from diffrax import PIDController, Tsit5 | ||
from jax import default_device, devices, grad, jacfwd, jacrev, jit | ||
from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator, | ||
liouvillian, mcsolve, mesolve, projection, qeye, settings, | ||
sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost, | ||
steadystate, tensor) | ||
|
||
%matplotlib inline | ||
``` | ||
|
||
An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jax_dia`. | ||
They allow for dense (`jax`) and custom sparse (`jax_dia`) formats. | ||
|
||
```python | ||
print(qeye(3, dtype="jax").dtype.__name__) | ||
print(qeye(3, dtype="jaxdia").dtype.__name__) | ||
``` | ||
|
||
In order to use the JAX data layer also within the master equation solver, there are two settings we can choose. | ||
First, is simply adding `method: diffrax` to the options parameter of the solver. | ||
Second, is to use the `qutip_jax.set_as_default()` method. | ||
It automatically switches all data data types to JAX compatible versions and sets the default solver method to `diffrax`. | ||
|
||
```python | ||
qj.set_as_default() | ||
``` | ||
|
||
To revert this setting, we can set the function parameter `revert = True`. | ||
|
||
```python | ||
# qj.set_as_default(revert = True) | ||
``` | ||
|
||
## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain | ||
|
||
Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem. | ||
GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations. | ||
For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential. | ||
However, as it is illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX becomes beneficial. | ||
|
||
### 1D Ising Spin Chain | ||
|
||
To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian | ||
|
||
$H = \sum_{i=1}^N g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$. | ||
|
||
We hereby consider $N$ spins that share an energy splitting of $g_0$ and have a coupling strength $J_0$. | ||
The end of the chain connects to an environment described by a Lindbladian dissipator which we model with the collapse operator $\sigma_x^{(N-1)}$ and coupling rate $\gamma$. | ||
|
||
As part of the [QuTiPv5 paper](#References), we see an extensive study on the computation time depending on the dimensionality $N$. | ||
In this example we cannot replicate the performance of a supercomputer of course, so we rather focus on the correct implementation to solve the Lindblad equation for this system using JAX and `mesolve()`. | ||
|
||
```python | ||
# system parameters | ||
N = 4 # number of spins | ||
g0 = 1 # energy splitting | ||
J0 = 1.4 # coupling strength | ||
gamma = 0.1 # dissipation rate | ||
|
||
# simulation parameters | ||
tlist = jnp.linspace(0, 5, 100) | ||
opt = { | ||
"normalize_output": False, | ||
"store_states": True, | ||
"method": "diffrax", | ||
"stepsize_controller": PIDController( | ||
rtol=settings.core["rtol"], atol=settings.core["atol"] | ||
), | ||
"solver": Tsit5(), | ||
} | ||
``` | ||
|
||
```python | ||
with CoreOptions(default_dtype="CSR"): | ||
# Operators for individual qubits | ||
sx_list, sy_list, sz_list = [], [], [] | ||
for i in range(N): | ||
op_list = [qeye(2)] * N | ||
op_list[i] = sigmax() | ||
sx_list.append(tensor(op_list)) | ||
op_list[i] = sigmay() | ||
sy_list.append(tensor(op_list)) | ||
op_list[i] = sigmaz() | ||
sz_list.append(tensor(op_list)) | ||
|
||
# Hamiltonian - Energy splitting terms | ||
H = 0.0 | ||
for i in range(N): | ||
H += g0 * sz_list[i] | ||
|
||
# Interaction terms | ||
for n in range(N - 1): | ||
H += -J0 * sx_list[n] * sx_list[n + 1] | ||
|
||
# Collapse operator acting locally on single spin | ||
c_ops = [gamma * sx_list[N - 1]] | ||
|
||
# Initial state | ||
state_list = [basis(2, 1)] * (N - 1) | ||
state_list.append(basis(2, 0)) | ||
psi0 = tensor(state_list) | ||
|
||
result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt) | ||
``` | ||
|
||
```python | ||
for i, s in enumerate(result.expect): | ||
plt.plot(tlist, s, label=rf"$n = {i+1}$") | ||
plt.xlabel("Time") | ||
plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$") | ||
plt.legend() | ||
plt.show() | ||
``` | ||
|
||
## Auto-Differentiation | ||
|
||
We have seen in the previous example how the new JAX data-layer in QuTiP works. | ||
On top of that, JAX adds the features of auto-differentiation. | ||
To compute derivatives, it is often numerical approximations (e.g., finite difference method) that need to be employed. | ||
Especially for higher order derivatives, these methods can turn into costly and inaccurate calculations. | ||
|
||
Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives. | ||
The idea is that any numerical function can be expressed by elementary analytical functions and operations. | ||
Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible. | ||
|
||
Although there are many applications for this technique, in this chapter we want to focus on two examples where auto-differentiation becomes relevant. | ||
|
||
### Statistics of Excitations between Quantum System and Environment | ||
|
||
We consider an open quantum system that is in contact with an evironment via a single jump operator. | ||
Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment. | ||
The probability distribution that describes the number of such exchanged excitations $n$ in a certain time $t$ is called the full counting statistics and denoted by $P_n(t)$. | ||
This statistics is a defining property that allows to derive many experimental observables like shot noise or current. | ||
|
||
For the example here, we can calculate this statistics by using a modified version of the density operator and Lindblad master equation. | ||
We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$ with $\rho^n(t)$ being the density operator of the system conditioned on $n$ exchanges by time $t$, so $\text{Tr}[\rho^n(t)] = P_n(t)$. | ||
The master equation for this operator, including the jump operator $C$, is then given as | ||
|
||
$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} [2 e^z C \rho(t)C^\dagger - \rho C^\dagger C - C^\dagger C \rho(t)]$. | ||
|
||
We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$. | ||
However, it also allows us to describe the counting statistics through its derivatives | ||
|
||
$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$. | ||
|
||
These derivatives are precisely where the auto-differention by JAX finds its application for us. | ||
|
||
```python | ||
# system parameters | ||
ed = 1 | ||
GammaL = 1 | ||
GammaR = 1 | ||
|
||
# simulation parameters | ||
options = { | ||
"method": "diffrax", | ||
"normalize_output": False, | ||
"stepsize_controller": PIDController(rtol=1e-7, atol=1e-7), | ||
"solver": Tsit5(scan_kind="bounded"), | ||
"progress_bar": False, | ||
} | ||
``` | ||
|
||
When working with JAX you can choose the type of device / processor to be used. | ||
In our case, we will resort to the CPU since this is a simple Jupyter Notebook. | ||
However, when running this on your machine, you can opt for using your GPU by simpy changing the argument below. | ||
|
||
```python | ||
with default_device(devices("cpu")[0]): | ||
with CoreOptions(default_dtype="jaxdia"): | ||
d = destroy(2) | ||
H = ed * d.dag() * d | ||
c_op_L = jnp.sqrt(GammaL) * d.dag() | ||
c_op_R = jnp.sqrt(GammaR) * d | ||
|
||
L0 = ( | ||
liouvillian(H) | ||
+ lindblad_dissipator(c_op_L) | ||
- 0.5 * spre(c_op_R.dag() * c_op_R) | ||
- 0.5 * spost(c_op_R.dag() * c_op_R) | ||
) | ||
L1 = sprepost(c_op_R, c_op_R.dag()) | ||
|
||
rho0 = steadystate(L0 + L1) | ||
|
||
def rhoz(t, z): | ||
L = L0 + jnp.exp(z) * L1 # jump term | ||
tlist = jnp.linspace(0, t, 50) | ||
result = mesolve(L, rho0, tlist, options=options) | ||
return result.final_state.tr() | ||
|
||
# first derivative | ||
drhozdz = jacrev(rhoz, argnums=1) | ||
# second derivative | ||
d2rhozdz = jacfwd(drhozdz, argnums=1) | ||
``` | ||
|
||
```python | ||
tf = 100 | ||
Itest = GammaL * GammaR / (GammaL + GammaR) | ||
shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2) | ||
ncurr = drhozdz(tf, 0.0) / tf | ||
nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf | ||
|
||
print("===== RESULTS =====") | ||
print("Analytic current", Itest) | ||
print("Numerical current", ncurr) | ||
print("Analytical shot noise (2nd cumulant)", shottest) | ||
print("Numerical shot noise (2nd cumulant)", nshot) | ||
``` | ||
|
||
### Driven One Qubit System & Frequency Optimization | ||
|
||
As a second example for auto differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian | ||
|
||
$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$ | ||
|
||
with the energy splitting $\omega_0$, $\Omega$ as the Rabi frequency, the drive frequency $\omega$ and $\sigma_{x/z}$ are Pauli matrices. | ||
When we add dissipation to the system, the dynamics is given by the Lindblad master equation, which introduces collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation. | ||
|
||
For this example, we are interested in the population of the excited state of the qubit | ||
|
||
$P_e(t) = \bra{e} \rho(t) \ket{e}$ | ||
|
||
and its gradient with respect to the frequency $\omega$. | ||
|
||
We want to optimize this quantity by adjusting the drive frequency $\omega$. | ||
To achieve this, we compute the gradient of $P_e(t)$ in respect to $\omega$ by using JAX's auto-differentiation tools and QuTiP's `mcsolve()`. | ||
|
||
```python | ||
# system parameters | ||
gamma = 0.1 # dissipation rate | ||
``` | ||
|
||
```python | ||
# time dependent drive | ||
@jit | ||
def driving_coeff(t, omega): | ||
return jnp.cos(omega * t) | ||
|
||
|
||
# system Hamiltonian | ||
def setup_system(): | ||
H_0 = sigmaz() | ||
H_1 = sigmax() | ||
H = [H_0, [H_1, driving_coeff]] | ||
return H | ||
``` | ||
|
||
```python | ||
# simulation parameters | ||
psi0 = basis(2, 0) | ||
tlist = jnp.linspace(0.0, 10.0, 100) | ||
c_ops = [jnp.sqrt(gamma) * sigmam()] | ||
e_ops = [projection(2, 1, 1)] | ||
``` | ||
|
||
```python | ||
# Objective function: returns final exc. state population | ||
def f(omega): | ||
H = setup_system() | ||
arg = {"omega": omega} | ||
result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg) | ||
return result.expect[0][-1] | ||
``` | ||
|
||
```python | ||
# Compute gradient of the exc. state population w.r.t. omega | ||
grad_f = grad(f)(2.0) | ||
``` | ||
|
||
## References | ||
|
||
|
||
|
||
|
||
[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax) | ||
|
||
[2] [QuTiP 5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705) | ||
|
||
|
||
## About | ||
|
||
```python | ||
about() | ||
``` | ||
|
||
## Testing | ||
|
||
```python | ||
assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates" | ||
assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates." | ||
Langhaarzombie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
``` |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could add some funky animation plot!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it normal that this takes >20min to compute (on a normal-ish laptop)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah ouch. i tried a different example the other day and it was fairly slow too. we could reduce the number of points in tlist, but perhaps just better not to add here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From my experience, a long
tlist
usually causes a long compilation time (not computation time!), because we have a for loop in the qutip base solver. This for loop will be flattened by jax compilation. So each step in the for loop is compiled separately. This can be speeded up by rewriting it with a JAX loop. But I don't know how hard it is because in each step we add something to the result class.