From 74f4433a48edec0f7ba396b5febb5c53721b2f55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?=
Date: Wed, 27 Nov 2024 11:46:09 +0900
Subject: [PATCH 1/6] First version of jax notebook for v5 paper

- added mesolve example
- added count stat. example
- added gradient example
- added basic structure for notebook with introduction and explanations
---
 tutorials-v5/miscellaneous/v5_paper-jax.md | 324 +++++++++++++++++++++
 1 file changed, 324 insertions(+)
 create mode 100644 tutorials-v5/miscellaneous/v5_paper-jax.md

diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md
new file mode 100644
index 00000000..8f99f8e0
--- /dev/null
+++ b/tutorials-v5/miscellaneous/v5_paper-jax.md
@@ -0,0 +1,324 @@
+---
+jupyter:
+  jupytext:
+    text_representation:
+      extension: .md
+      format_name: markdown
+      format_version: '1.3'
+    jupytext_version: 1.13.8
+  kernelspec:
+    display_name: qutip-tutorials-v5
+    language: python
+    name: python3
+---
+
+# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differentiation
+
+Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Neill Lambert (nwlambert@gmail.com)
+
+For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
+Today, many libraries like TensorFlow, CuPy, and JAX enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
+QuTiP’s flexible data layer can directly be used with these libraries and thereby drastically reduce computation time.
+Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration due to JAX's robust auto-differentiation features and widespread adoption in machine learning.
+
+In these examples we illustrate how JAX naturally integrates into QuTiP using the QuTiP-JAX package.
+As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation.
+In the second part, we focus on the auto-differentiation capabilities.
+For this we first consider the counting statistics of an open quantum system connected to an environment.
+Lastly, we look at a driven qubit system and compute the gradient of its state population with respect to the drive's frequency using `mcsolve()`.
+
+## Introduction
+
+In addition to the standrad QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed.
+This package also comes with `jax.numpy` which mirrors all of `numpy`'s functionality for seamless integration with JAX.
+
+```python
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import qutip_jax as qj
+from diffrax import PIDController, Tsit5
+from jax import default_device, devices, grad, jacfwd, jacrev, jit
+from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator,
+                   liouvillian, mcsolve, mesolve, projection, qeye, settings,
+                   sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost,
+                   steadystate, tensor)
+
+%matplotlib inline
+```
+
+An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jax_dia`.
+They allow for dense (`jax`) and custom sparse (`jax_dia`) formats.
+
+```python
+print(qeye(3, dtype="jax").dtype.__name__)
+print(qeye(3, dtype="jaxdia").dtype.__name__)
+```
+
+To use the JAX data layer within the master equation solver as well, there are two settings to choose from.
+The first is to simply add `method: diffrax` to the options parameter of the solver. 
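+For instance, a minimal sketch of the first option (the objects here are purely illustrative; any solver call in this notebook accepts the same options):
+
+```python
+# trivial unitary evolution, explicitly requesting the diffrax ODE method
+H_demo = sigmax().to("jax")  # convert an existing Qobj to the dense JAX format
+psi_demo = basis(2, 0).to("jax")
+opt_demo = {
+    "method": "diffrax",
+    "stepsize_controller": PIDController(rtol=1e-6, atol=1e-6),
+    "solver": Tsit5(),
+}
+res_demo = mesolve(H_demo, psi_demo, jnp.linspace(0, 1, 11), options=opt_demo)
+print(res_demo.final_state.dtype.__name__)  # the state stays in the JAX data layer
+```
+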
+The second is to use the `qutip_jax.set_as_default()` method.
+It automatically switches all data data tyeps to JAX compatible versions and sets the default solver method to `diffrax`.
+
+```python
+qj.set_as_default()
+```
+
+To revert this setting, we can set the function parameter `revert = True`.
+
+```python
+# qj.set_as_default(revert = True)
+```
+
+## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain
+
+Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem.
+GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
+For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential.
+However, as it is illustrated in the [QuTiPv5 paper](#References), there is a cross-over point at which using JAX is beneficial.
+
+### 1D Ising Spin Chain
+
+To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian
+
+$H = \sum_{n=1}^{N} g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$.
+
+We hereby consider $N$ spins that share an energy splitting of $g_0$ and have a coupling strength $J_0$.
+The end of the chain connects to an environment described by a Lindbladian dissipator, which we model with the collapse operator $\sigma_x^{(N)}$ and coupling rate $\gamma$.
+
+As part of the [QuTiPv5 paper](#References), we see an extensive study on the computation time depending on the dimensionality $N$.
+In this example we cannot replicate the performance of a supercomputer of course, so we rather focus on the correct implementation to solve the Lindablad equation for this system using JAX and `mesolve()`.
+
+```python
+# system parameters
+N = 4  # number of spins
+g0 = 1  # energy splitting
+J0 = 1.4  # coupling strength
+gamma = 0.1  # dissipation rate
+
+# simulation parameters
+tlist = jnp.linspace(0, 5, 100)
+opt = {
+    "normalize_output": False,
+    "store_states": True,
+    "method": "diffrax",
+    "stepsize_controller": PIDController(
+        rtol=settings.core["rtol"], atol=settings.core["atol"]
+    ),
+    "solver": Tsit5(),
+}
+```
+
+```python
+with CoreOptions(default_dtype="CSR"):
+    # Operators for individual qubits
+    sx_list, sy_list, sz_list = [], [], []
+    for i in range(N):
+        op_list = [qeye(2)] * N
+        op_list[i] = sigmax()
+        sx_list.append(tensor(op_list))
+        op_list[i] = sigmay()
+        sy_list.append(tensor(op_list))
+        op_list[i] = sigmaz()
+        sz_list.append(tensor(op_list))
+
+    # Hamiltonian - Energy splitting terms
+    H = 0.0
+    for i in range(N):
+        H += g0 * sz_list[i]
+
+    # Interaction terms
+    for n in range(N - 1):
+        H += -J0 * sx_list[n] * sx_list[n + 1]
+
+    # Collapse operator with rate gamma, acting locally on the last spin
+    c_ops = [jnp.sqrt(gamma) * sx_list[N - 1]]
+
+    # Initial state
+    state_list = [basis(2, 1)] * (N - 1)
+    state_list.append(basis(2, 0))
+    psi0 = tensor(state_list)
+
+    result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt)
+```
+
+```python
+for i, s in enumerate(result.expect):
+    plt.plot(tlist, s, label=rf"$n = {i+1}$")
+plt.xlabel("Time")
+plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$")
+plt.legend()
+plt.show()
+```
+
+## Auto-Differentiation
+
+We have seen in the previous example how the new JAX data-layer in QuTiP enables us to run calculations on the GPU.
+On top of that, JAX adds the features of auto-differentiation. 
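+
+Before looking at how this pays off in QuTiP, a one-line JAX sketch (independent of QuTiP) shows the interface we will rely on: `jax.grad` transforms a numerical function into its derivative.
+
+```python
+print(grad(jnp.sin)(0.0))  # derivative of sin at 0, i.e. cos(0) = 1.0
+```
+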
+To compute derivatives, numerical approximations (e.g., the finite difference method) often need to be employed.
+Especially for higher order derivatives, these methods can turn into costly and inaccurate calculations.
+
+Auto-differenciation, on the other hand, exploits the chain rule to compute such derivatives.
+The idea is that any numerical function can be expressed by elementary analytical functions and operations.
+Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible.
+
+Although there are many applications for this technique, in chapter we want to focus two examples where auto-differentiation becomes relevant.
+
+### Statistics of Excitations between Quantum System and Environment
+
+We consider an open quantum system that is in contact with an environment via a single jump operator.
+Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment.
+The probability distribution that describes the number of such exchanged excitations $n$ in a certain time $t$ is called the full counting statistics and denoted by $P_n(t)$.
+This statistics is a defining property that allows us to derive many experimental observables like shot noise or current.
+
+For the example here, we can calculate this statistics by using a modified version of the density operator and Lindblad master equation.
+We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$ with $\rho^n(t)$ being the density operator of the system conditioned on $n$ exchanges by time $t$, so $\text{Tr}[\rho^n(t)] = P_n(t)$.
+The master equation for this operator, including the jump operator $C$, is then given as
+
+$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} [2 e^z C G(z,t) C^\dagger - G(z,t) C^\dagger C - C^\dagger C G(z,t)]$.
+
+We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$.
+However, it also allows us to describe the counting statistics through its derivatives
+
+$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$.
+
+These derivatives are precisely where the auto-differentiation by JAX finds its application for us. 
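+
+Concretely, the code below uses the first two derivatives of $\text{Tr}[G(z,t)]$ at $z = 0$ and long times to form the average current and the zero-frequency shot noise,
+
+$I = \dfrac{1}{t} \dfrac{d}{dz} \text{Tr} [G(z,t)]|_{z=0}, \qquad S = \dfrac{1}{t} \left[ \dfrac{d^2}{dz^2} \text{Tr} [G(z,t)]|_{z=0} - \left( \dfrac{d}{dz} \text{Tr} [G(z,t)]|_{z=0} \right)^2 \right]$,
+
+which, for the single resonant level with injection rate $\Gamma_L$ and emission rate $\Gamma_R$ studied below, can be checked against the analytical results $I = \Gamma_L \Gamma_R / (\Gamma_L + \Gamma_R)$ and $S = I \left( 1 - 2 \Gamma_L \Gamma_R / (\Gamma_L + \Gamma_R)^2 \right)$.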
+
+```python
+# system parameters
+ed = 1
+GammaL = 1
+GammaR = 1
+
+# simulation parameters
+options = {
+    "method": "diffrax",
+    "normalize_output": False,
+    "stepsize_controller": PIDController(rtol=1e-7, atol=1e-7),
+    "solver": Tsit5(scan_kind="bounded"),
+    "progress_bar": False,
+}
+```
+
+```python
+with default_device(devices("gpu")[0]):
+    with CoreOptions(default_dtype="jaxdia"):
+        d = destroy(2)
+        H = ed * d.dag() * d
+        c_op_L = jnp.sqrt(GammaL) * d.dag()
+        c_op_R = jnp.sqrt(GammaR) * d
+
+        L0 = (
+            liouvillian(H)
+            + lindblad_dissipator(c_op_L)
+            - 0.5 * spre(c_op_R.dag() * c_op_R)
+            - 0.5 * spost(c_op_R.dag() * c_op_R)
+        )
+        L1 = sprepost(c_op_R, c_op_R.dag())
+
+        rho0 = steadystate(L0 + L1)
+
+        def rhoz(t, z):
+            L = L0 + jnp.exp(z) * L1  # jump term
+            tlist = jnp.linspace(0, t, 50)
+            result = mesolve(L, rho0, tlist, options=options)
+            return result.final_state.tr()
+
+        # first derivative
+        drhozdz = jacrev(rhoz, argnums=1)
+        # second derivative
+        d2rhozdz = jacfwd(drhozdz, argnums=1)
+```
+
+```python
+tf = 100
+Itest = GammaL * GammaR / (GammaL + GammaR)
+shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2)
+ncurr = drhozdz(tf, 0.0) / tf
+nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf
+
+print("===== RESULTS =====")
+print("Analytic current", Itest)
+print("Numerical current", ncurr)
+print("Analytical shot noise (2nd cumulant)", shottest)
+print("Numerical shot noise (2nd cumulant)", nshot)
+```
+
+### Driven One Qubit System & Frequency Optimization
+
+As a second example for auto-differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian
+
+$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$
+
+with the energy splitting $\omega_0$, the Rabi frequency $\Omega$, the drive frequency $\omega$, and the Pauli matrices $\sigma_{x/z}$.
+When we add dissipation to the system, the dynamics is given by the Lindblad master equation, which introduces the collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation.
+
+For this example, we are interested in the population of the excited state of the qubit
+
+$P_e(t) = \bra{e} \rho(t) \ket{e}$
+
+and its gradient with respect to the frequency $\omega$.
+
+We want to optimize this quantity by adjusting the drive frequency $\omega$.
+To achieve this, we compute the gradient of $P_e(t)$ in respect to $\omega$ by using JAX's auto-differentiation tools nad QuTiP's `mcsolve()`.
+
+```python
+# system parameters
+gamma = 0.1  # dissipation rate
+```
+
+```python
+# time dependent drive
+@jit
+def driving_coeff(t, omega):
+    return jnp.cos(omega * t)
+
+
+# system Hamiltonian
+def setup_system(omega):
+    H_0 = sigmaz()
+    H_1 = sigmax()
+    H = [H_0, [H_1, driving_coeff]]
+    return H
+```
+
+```python
+# simulation parameters
+psi0 = basis(2, 0)
+tlist = jnp.linspace(0.0, 10.0, 100)
+c_ops = [jnp.sqrt(gamma) * sigmam()]
+e_ops = [projection(2, 1, 1)]
+```
+
+```python
+# Objective function: returns final exc. state population
+def f(omega):
+    H = setup_system(omega)
+    arg = {"omega": omega}
+    result = mcsolve(H, psi0, tlist, c_ops, e_ops, ntraj=100, args=arg)
+    return result.expect[0][-1]
+```
+
+```python
+# Compute gradient of the exc. state population w.r.t. 
omega
+grad_f = grad(f)(2.0)
+```
+
+## References
+
+
+
+```python
+# TODO
+```
+
+## About
+
+```python
+about()
+```
+
+## Testing
+
+```python
+# TODO
+```

From 5ff9708ac5d8d9b91e940ba9d35bbe9044d44f08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?=
Date: Wed, 27 Nov 2024 11:50:15 +0900
Subject: [PATCH 2/6] Corrected Python display name

---
 tutorials-v5/miscellaneous/v5_paper-jax.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md
index 8f99f8e0..29c0bc95 100644
--- a/tutorials-v5/miscellaneous/v5_paper-jax.md
+++ b/tutorials-v5/miscellaneous/v5_paper-jax.md
@@ -7,7 +7,7 @@ jupyter:
     format_version: '1.3'
     jupytext_version: 1.13.8
   kernelspec:
-    display_name: qutip-tutorials-v5
+    display_name: Python 3
     language: python
     name: python3
 ---

From b00fff2ba7f16f3dbf13fe6bd37e1fd4f9ac66a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?=
Date: Thu, 28 Nov 2024 16:20:21 +0900
Subject: [PATCH 3/6] Added references, basic test

- added qutip5 paper and qutip-jax as reference
- added basic tests for current and shot noise example
- small code readability improvement
---
 tutorials-v5/miscellaneous/v5_paper-jax.md | 25 ++++++++++++----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md
index 29c0bc95..30fd9e8c 100644
--- a/tutorials-v5/miscellaneous/v5_paper-jax.md
+++ b/tutorials-v5/miscellaneous/v5_paper-jax.md
@@ -17,11 +17,11 @@ jupyter:
 Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Neill Lambert (nwlambert@gmail.com)
 
 For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
-Today, many libraries like TensorFlow, CuPy, and JAX enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
-QuTiP’s flexible data layer can directly be used with these libraries and thereby drastically reduce computation time.
-Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration due to JAX's robust auto-differentiation features and widespread adoption in machine learning.
+Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
+QuTiP’s flexible data layer can directly be used with many such libraries and thereby drastically reduce computation time.
+Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration [\[1\]] due to JAX's robust auto-differentiation features and widespread adoption in machine learning.
 
-In these examples we illustrate how JAX naturally integrates into QuTiP using the QuTiP-JAX package.
+In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]] using the QuTiP-JAX package.
 As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation.
 In the second part, we focus on the auto-differentiation capabilities.
 For this we first consider the counting statistics of an open quantum system connected to an environment. 
@@ -274,7 +274,7 @@ def driving_coeff(t, omega): # system Hamiltonian -def setup_system(omega): +def setup_system(): H_0 = sigmaz() H_1 = sigmax() H = [H_0, [H_1, driving_coeff]] @@ -292,9 +292,9 @@ e_ops = [projection(2, 1, 1)] ```python # Objective function: returns final exc. state population def f(omega): - H = setup_system(omega) + H = setup_system() arg = {"omega": omega} - result = mcsolve(H, psi0, tlist, c_ops, e_ops, ntraj=100, args=arg) + result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg) return result.expect[0][-1] ``` @@ -307,9 +307,11 @@ grad_f = grad(f)(2.0) -```python -# TODO -``` + +[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax) + +[2] [QuTiP v5: The Quantum Toolbox in Python](about:blank) + ## About @@ -320,5 +322,6 @@ about() ## Testing ```python -# TODO +assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates" +assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates." ``` From df0ad304a5c1fe411146ffa643b824fa1835f9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?= Date: Mon, 2 Dec 2024 14:57:46 +0900 Subject: [PATCH 4/6] Fixed typos, added author --- tutorials-v5/miscellaneous/v5_paper-jax.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md index 30fd9e8c..9a6b7bdb 100644 --- a/tutorials-v5/miscellaneous/v5_paper-jax.md +++ b/tutorials-v5/miscellaneous/v5_paper-jax.md @@ -7,14 +7,14 @@ jupyter: format_version: '1.3' jupytext_version: 1.13.8 kernelspec: - display_name: Python 3 + display_name: qutip-tutorials-v5 language: python name: python3 --- # QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differnetiation -Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Neill Lambert (nwlambert@gmail.com) +Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Rochisha Agarwal (rochisha.agarwal2302@gmail.com), Neill Lambert (nwlambert@gmail.com) For many years now, GPUs have been a fundamental tool for accelerating numerical tasks. Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations. @@ -29,7 +29,7 @@ Lastly, we look at a driven qubit system and computing the gradient of its state ## Introduction -In addition to the standrad QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed. +In addition to the standard QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed. This package also comes with `jax.numpy` which mirrors all of `numpy`s functionality for seamless integration with JAX. ```python @@ -57,7 +57,7 @@ print(qeye(3, dtype="jaxdia").dtype.__name__) In order to use the JAX data layer also within the master equation solver, there are two settings we can choose. First, is simply adding `method: diffrax` to the options parameter of the solver. Second, is to use the `qutip_jax.set_as_default()` method. -It automatically switches all data data tyeps to JAX compatible versions and sets the default solver method to `diffrax`. +It automatically switches all data data types to JAX compatible versions and sets the default solver method to `diffrax`. 
 
 ```python
 qj.set_as_default()
 ```
 
 As part of the [QuTiPv5 paper](#References), we see an extensive study on the computation time depending on the dimensionality $N$.
-In this example we cannot replicate the performance of a supercomputer of course, so we rather focus on the correct implementation to solve the Lindablad equation for this system using JAX and `mesolve()`.
+In this example we cannot, of course, replicate the performance of a supercomputer, so we focus instead on the correct implementation to solve the Lindblad equation for this system using JAX and `mesolve()`.
 
-Auto-differenciation, on the other hand, exploits the chain rule to compute such derivatives.
+Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives.
 The idea is that any numerical function can be expressed by elementary analytical functions and operations.
 Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible.
 
 and its gradient with respect to the frequency $\omega$.
 
 We want to optimize this quantity by adjusting the drive frequency $\omega$.
-To achieve this, we compute the gradient of $P_e(t)$ in respect to $\omega$ by using JAX's auto-differentiation tools nad QuTiP's `mcsolve()`.
+To achieve this, we compute the gradient of $P_e(t)$ with respect to $\omega$ using JAX's auto-differentiation tools and QuTiP's `mcsolve()`.

From df0ad304a5c1fe411146ffa643b824fa1835f9d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?=
Date: Mon, 9 Dec 2024 14:13:57 +0900
Subject: [PATCH 5/6] Added v5 paper, fixed typos

---
 tutorials-v5/miscellaneous/v5_paper-jax.md | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md
index 9a6b7bdb..50d0679f 100644
--- a/tutorials-v5/miscellaneous/v5_paper-jax.md
+++ b/tutorials-v5/miscellaneous/v5_paper-jax.md
@@ -7,7 +7,7 @@ jupyter:
     format_version: '1.3'
     jupytext_version: 1.13.8
   kernelspec:
-    display_name: qutip-tutorials-v5
+    display_name: Python 3 (ipykernel)
    language: python
     name: python3
 ---
@@ -19,9 +19,9 @@ Authors: Maximilian Meyer-Mölleringhof (m.meyermoelleringhof@gmail.com), Rochis
 For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
 Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
 QuTiP’s flexible data layer can directly be used with many such libraries and thereby drastically reduce computation time.
-Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration [\[1\]] due to JAX's robust auto-differentiation features and widespread adoption in machine learning. 
+Despite this large variety of frameworks, development in connection with QuTiP has centered on the QuTiP-JAX integration [\[1\]](#References) due to JAX's robust auto-differentiation features and widespread adoption in machine learning.
 
-In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]] using the QuTiP-JAX package.
+In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]](#References) using the QuTiP-JAX package.
 As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation.
 In the second part, we focus on the auto-differentiation capabilities.
 For this we first consider the counting statistics of an open quantum system connected to an environment.
@@ -74,7 +74,7 @@ To revert this setting, we can set the function parameter `revert = True`.
 Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem.
 GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
 For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential.
-However, as it is illustrated in the [QuTiPv5 paper](#References), there is a cross-over point at which using JAX is beneficial.
+However, as illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX is beneficial.
 
 ### 1D Ising Spin Chain
 
@@ -161,7 +161,7 @@ Auto-differentiation, on the other hand, exploits the chain rule to compute such
 The idea is that any numerical function can be expressed by elementary analytical functions and operations.
 Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible.
 
-Although there are many applications for this technique, in chapter we want to focus two examples where auto-differentiation becomes relevant.
+Although there are many applications for this technique, in this chapter we want to focus on two examples where auto-differentiation becomes relevant.
 
 ### Statistics of Excitations between Quantum System and Environment
 
@@ -310,7 +310,7 @@ grad_f = grad(f)(2.0)
 
 [1] [QuTiP-JAX](https://github.com/qutip/qutip-jax)
 
-[2] [QuTiP v5: The Quantum Toolbox in Python](about:blank)
+[2] [QuTiP v5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705)
 
 ## About

From a07e0e725116d4ac73bbee3228be028a5ce4652f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20Meyer-M=C3=B6lleringhof?=
Date: Fri, 28 Feb 2025 15:21:01 +0900
Subject: [PATCH 6/6] Corrected hardware choice so CI tests succeed

---
 tutorials-v5/miscellaneous/v5_paper-jax.md | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/tutorials-v5/miscellaneous/v5_paper-jax.md b/tutorials-v5/miscellaneous/v5_paper-jax.md
index 50d0679f..6d6582fa 100644
--- a/tutorials-v5/miscellaneous/v5_paper-jax.md
+++ b/tutorials-v5/miscellaneous/v5_paper-jax.md
@@ -7,7 +7,7 @@ jupyter:
     format_version: '1.3'
     jupytext_version: 1.13.8
   kernelspec:
-    display_name: Python 3 (ipykernel)
+    display_name: qutip-tutorials-v5
     language: python
     name: python3
 ---
 
 Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem. 
 GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
 For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential.
-However, as illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX is beneficial.
+However, as illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX becomes beneficial.
 
 ### 1D Ising Spin Chain
 
 ## Auto-Differentiation
 
-We have seen in the previous example how the new JAX data-layer in QuTiP enables us to run calculations on the GPU.
+We have seen in the previous example how the new JAX data-layer in QuTiP works.
 On top of that, JAX adds the features of auto-differentiation.
 To compute derivatives, numerical approximations (e.g., the finite difference method) often need to be employed.
 Especially for higher order derivatives, these methods can turn into costly and inaccurate calculations.
 
 ```
 
+When working with JAX you can choose the type of device / processor to be used.
+In our case, we will resort to the CPU since this is a simple Jupyter Notebook.
+However, when running this on your machine, you can opt for using your GPU by simply changing the argument below.
+
 ```python
-with default_device(devices("gpu")[0]):
+with default_device(devices("cpu")[0]):
     with CoreOptions(default_dtype="jaxdia"):
 
-[2] [QuTiP v5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705)
+[2] [QuTiP 5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705)
 
 ## About