diff --git a/tutorials-v5/quantum-optimal-control/Single_Qubit_RL.ipynb b/tutorials-v5/quantum-optimal-control/Single_Qubit_RL.ipynb new file mode 100644 index 0000000..e95c9e0 --- /dev/null +++ b/tutorials-v5/quantum-optimal-control/Single_Qubit_RL.ipynb @@ -0,0 +1,1627 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c3371bed", + "metadata": {}, + "source": [ + "# Quantum Optimal Control with Reinforcement Learning\n", + "\n", + "In this notebook, we will demonstrate how to use the `_RL` module to solve a quantum optimal control problem using reinforcement learning (RL). We will define a simple state transfer problem with a single qubit, where the goal is to transfer a quantum system from one state to another, and we will use the RL agent to optimize the control pulses to achieve this task.\n", + "Afterwards, we will revisit the same problem using unitary operators instead of states." + ] + }, + { + "cell_type": "markdown", + "id": "717f07c0-515d-427a-82ef-14603780b9ff", + "metadata": {}, + "source": [ + "## State to State Transfer\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "c81d453d", + "metadata": {}, + "source": [ + "### Setup and Import Required Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f1e015e6", + "metadata": {}, + "outputs": [], + "source": [ + "# If you are running this in an environment where some packages are missing, use this cell to install them:\n", + "# !pip install qutip stable-baselines3 gymnasium\n", + "\n", + "import qutip as qt\n", + "import numpy as np\n", + "from stable_baselines3 import PPO\n", + "#from qutip_qoc import Result, Objective, _TimeInterval\n", + "from qutip_qoc import Objective\n", + "#from _rl import _RL\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0131f1b2-b218-435c-a1e6-5c1acad6ac6b", + "metadata": {}, + "outputs": [], + "source": [ + "#this is just for using local files (not yet merged in github)\n", + "import sys\n", 
+ "import os\n", + "\n", + "module_path = os.path.abspath(os.path.join('..', 'Github', 'qutip-qoc', 'src', 'qutip_qoc'))\n", + "\n", + "sys.path.append(module_path)\n", + "\n", + "from _rl import _RL\n", + "from pulse_optim import optimize_pulses" + ] + }, + { + "cell_type": "markdown", + "id": "b4b725c0", + "metadata": {}, + "source": [ + "### Define the Quantum Control Problem" + ] + }, + { + "cell_type": "markdown", + "id": "7e895742", + "metadata": {}, + "source": [ + "We define the problem of transferring a quantum system from the initial state |0⟩ to the target state |+⟩. The system is controlled via three control Hamiltonians corresponding to the Pauli matrices, and a drift Hamiltonian for natural evolution of the qubit." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6c414871", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the initial and target states\n", + "initial_state = qt.basis(2, 0) # |0⟩\n", + "target_state = (qt.basis(2, 0) + qt.basis(2, 1)).unit() # |+⟩\n", + "#target_state = qt.basis(2, 1) # |1⟩\n", + "\n", + "# Define the control Hamiltonians (Pauli matrices)\n", + "H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()]\n", + "\n", + "# Define the drift Hamiltonian\n", + "w, d = 0.1, 1.0\n", + "H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax())\n", + "\n", + "# Combine the Hamiltonians into a single list\n", + "H = [H_d] + H_c\n", + "\n", + "# Define the objective\n", + "objectives = [Objective(initial=initial_state, H=H, target=target_state)]\n", + "\n", + "# Define the control parameters with bounds\n", + "control_parameters = {\n", + " \"p\": {\"bounds\": [(-13, 13)]}\n", + "}\n", + "\n", + "# Define the time interval\n", + "tlist = np.linspace(0, 10, 100)\n", + "\n", + "# Define algorithm-specific settings\n", + "algorithm_kwargs = {\n", + " \"fid_err_targ\": 0.01,\n", + " \"alg\": \"RL\",\n", + " \"max_iter\": 20000,\n", + " \"shorter_pulses\": True,\n", + "}\n", + "optimizer_kwargs = {}\n" + ] + }, + { + 
"cell_type": "markdown", + "id": "eab0e4f0-3b42-4f32-a3dc-bb752c8dbe8c", + "metadata": {}, + "source": [ + "Note that `max_iter` defines the number of episodes, the 100 in `tlist` defines the maximum number of steps per episode. \n", + "If `shorter_pulses` is True, the training will be longer as the algorithm will try to optimize the episodes using as few steps as possible in addition to checking if the target infidelity is reached.\n", + "If it is False, the algorithm stops as soon as it finds an episode with infidelity <= of the target infidelity" + ] + }, + { + "cell_type": "markdown", + "id": "a6273e17", + "metadata": {}, + "source": [ + "### Initialize and Train the RL Environment" + ] + }, + { + "cell_type": "markdown", + "id": "e7b27df4", + "metadata": {}, + "source": [ + "Now we will call the `optimize_pulses()` method, passing it the control problem we defined.\n", + "The method will create an instance of the `_RL` class, which will set up the reinforcement learning environment and start training.\n", + "Finally it returns the optimization results through an object of the `Result` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1c4b0b58", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cpu device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n", + "---------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 9.92 |\n", + "| ep_rew_mean | -2.21 |\n", + "| time/ | |\n", + "| fps | 2282 |\n", + "| iterations | 1 |\n", + "| time_elapsed | 0 |\n", + "| total_timesteps | 2048 |\n", + "---------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 10.1 |\n", + "| ep_rew_mean | -2.3 |\n", + "| time/ | |\n", + "| fps | 1871 |\n", + "| iterations | 2 |\n", + "| time_elapsed | 2 |\n", + "| total_timesteps | 4096 |\n", + "| train/ | |\n", + "| approx_kl | 0.005770128 |\n", + "| clip_fraction | 0.0454 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.26 |\n", + "| explained_variance | 0.00456 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.36 |\n", + "| n_updates | 10 |\n", + "| policy_gradient_loss | -0.00512 |\n", + "| std | 1 |\n", + "| value_loss | 3.56 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 8.16 |\n", + "| ep_rew_mean | -1.67 |\n", + "| time/ | |\n", + "| fps | 1777 |\n", + "| iterations | 3 |\n", + "| time_elapsed | 3 |\n", + "| total_timesteps | 6144 |\n", + "| train/ | |\n", + "| approx_kl | 0.0059958715 |\n", + "| clip_fraction | 0.0327 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.28 |\n", + "| explained_variance | 0.00907 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.92 |\n", + "| n_updates | 20 |\n", + "| policy_gradient_loss | -0.00476 |\n", + "| std | 1.01 |\n", + "| value_loss | 5.02 |\n", + "------------------------------------------\n", + 
"------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 10.2 |\n", + "| ep_rew_mean | -2.21 |\n", + "| time/ | |\n", + "| fps | 1707 |\n", + "| iterations | 4 |\n", + "| time_elapsed | 4 |\n", + "| total_timesteps | 8192 |\n", + "| train/ | |\n", + "| approx_kl | 0.0067452346 |\n", + "| clip_fraction | 0.0616 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.28 |\n", + "| explained_variance | 0.0347 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 2.78 |\n", + "| n_updates | 30 |\n", + "| policy_gradient_loss | -0.00704 |\n", + "| std | 1.01 |\n", + "| value_loss | 6.3 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6.49 |\n", + "| ep_rew_mean | -1.28 |\n", + "| time/ | |\n", + "| fps | 1685 |\n", + "| iterations | 5 |\n", + "| time_elapsed | 6 |\n", + "| total_timesteps | 10240 |\n", + "| train/ | |\n", + "| approx_kl | 0.005527503 |\n", + "| clip_fraction | 0.0367 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.26 |\n", + "| explained_variance | 0.0404 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 2.51 |\n", + "| n_updates | 40 |\n", + "| policy_gradient_loss | -0.00418 |\n", + "| std | 0.998 |\n", + "| value_loss | 5.73 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6.65 |\n", + "| ep_rew_mean | -1.37 |\n", + "| time/ | |\n", + "| fps | 1667 |\n", + "| iterations | 6 |\n", + "| time_elapsed | 7 |\n", + "| total_timesteps | 12288 |\n", + "| train/ | |\n", + "| approx_kl | 0.006008967 |\n", + "| clip_fraction | 0.0488 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.25 |\n", + "| explained_variance | 0.0647 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 3.31 |\n", + "| n_updates | 50 |\n", + "| policy_gradient_loss | -0.00528 |\n", + "| std | 1 |\n", + "| value_loss | 6.49 |\n", + 
"-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5.92 |\n", + "| ep_rew_mean | -1.08 |\n", + "| time/ | |\n", + "| fps | 1656 |\n", + "| iterations | 7 |\n", + "| time_elapsed | 8 |\n", + "| total_timesteps | 14336 |\n", + "| train/ | |\n", + "| approx_kl | 0.006880779 |\n", + "| clip_fraction | 0.0687 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.25 |\n", + "| explained_variance | 0.0296 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 3.2 |\n", + "| n_updates | 60 |\n", + "| policy_gradient_loss | -0.00615 |\n", + "| std | 0.995 |\n", + "| value_loss | 5.22 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6.46 |\n", + "| ep_rew_mean | -1.21 |\n", + "| time/ | |\n", + "| fps | 1636 |\n", + "| iterations | 8 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 16384 |\n", + "| train/ | |\n", + "| approx_kl | 0.006914506 |\n", + "| clip_fraction | 0.0744 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.24 |\n", + "| explained_variance | 0.105 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 2.35 |\n", + "| n_updates | 70 |\n", + "| policy_gradient_loss | -0.00785 |\n", + "| std | 0.996 |\n", + "| value_loss | 6.24 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 7.63 |\n", + "| ep_rew_mean | -1.32 |\n", + "| time/ | |\n", + "| fps | 1633 |\n", + "| iterations | 9 |\n", + "| time_elapsed | 11 |\n", + "| total_timesteps | 18432 |\n", + "| train/ | |\n", + "| approx_kl | 0.008130429 |\n", + "| clip_fraction | 0.0882 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.24 |\n", + "| explained_variance | 0.137 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.93 |\n", + "| n_updates | 80 |\n", + "| policy_gradient_loss | -0.00804 |\n", + "| std | 0.997 |\n", 
+ "| value_loss | 5.68 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6.73 |\n", + "| ep_rew_mean | -1.21 |\n", + "| time/ | |\n", + "| fps | 1630 |\n", + "| iterations | 10 |\n", + "| time_elapsed | 12 |\n", + "| total_timesteps | 20480 |\n", + "| train/ | |\n", + "| approx_kl | 0.009801803 |\n", + "| clip_fraction | 0.116 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.25 |\n", + "| explained_variance | -0.0604 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.8 |\n", + "| n_updates | 90 |\n", + "| policy_gradient_loss | -0.0123 |\n", + "| std | 1 |\n", + "| value_loss | 3.45 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 4.02 |\n", + "| ep_rew_mean | -0.668 |\n", + "| time/ | |\n", + "| fps | 1617 |\n", + "| iterations | 11 |\n", + "| time_elapsed | 13 |\n", + "| total_timesteps | 22528 |\n", + "| train/ | |\n", + "| approx_kl | 0.0070579182 |\n", + "| clip_fraction | 0.0863 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.24 |\n", + "| explained_variance | 0.0842 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.38 |\n", + "| n_updates | 100 |\n", + "| policy_gradient_loss | -0.0104 |\n", + "| std | 0.991 |\n", + "| value_loss | 3.32 |\n", + "------------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 4.29 |\n", + "| ep_rew_mean | -0.666 |\n", + "| time/ | |\n", + "| fps | 1615 |\n", + "| iterations | 12 |\n", + "| time_elapsed | 15 |\n", + "| total_timesteps | 24576 |\n", + "| train/ | |\n", + "| approx_kl | 0.01636789 |\n", + "| clip_fraction | 0.165 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.22 |\n", + "| explained_variance | 0.104 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.37 |\n", + "| n_updates | 110 |\n", + "| policy_gradient_loss | 
-0.021 |\n", + "| std | 0.986 |\n", + "| value_loss | 2.7 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 2.32 |\n", + "| ep_rew_mean | -0.258 |\n", + "| time/ | |\n", + "| fps | 1611 |\n", + "| iterations | 13 |\n", + "| time_elapsed | 16 |\n", + "| total_timesteps | 26624 |\n", + "| train/ | |\n", + "| approx_kl | 0.011109401 |\n", + "| clip_fraction | 0.12 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.18 |\n", + "| explained_variance | 0.111 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.735 |\n", + "| n_updates | 120 |\n", + "| policy_gradient_loss | -0.0191 |\n", + "| std | 0.965 |\n", + "| value_loss | 1.98 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.68 |\n", + "| ep_rew_mean | -0.14 |\n", + "| time/ | |\n", + "| fps | 1601 |\n", + "| iterations | 14 |\n", + "| time_elapsed | 17 |\n", + "| total_timesteps | 28672 |\n", + "| train/ | |\n", + "| approx_kl | 0.012810201 |\n", + "| clip_fraction | 0.164 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.12 |\n", + "| explained_variance | 0.199 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.07 |\n", + "| n_updates | 130 |\n", + "| policy_gradient_loss | -0.0239 |\n", + "| std | 0.948 |\n", + "| value_loss | 1.78 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.57 |\n", + "| ep_rew_mean | -0.128 |\n", + "| time/ | |\n", + "| fps | 1599 |\n", + "| iterations | 15 |\n", + "| time_elapsed | 19 |\n", + "| total_timesteps | 30720 |\n", + "| train/ | |\n", + "| approx_kl | 0.011240302 |\n", + "| clip_fraction | 0.151 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.06 |\n", + "| explained_variance | 0.289 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.693 |\n", + "| n_updates | 140 
|\n", + "| policy_gradient_loss | -0.0189 |\n", + "| std | 0.931 |\n", + "| value_loss | 1.54 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.28 |\n", + "| ep_rew_mean | -0.0604 |\n", + "| time/ | |\n", + "| fps | 1596 |\n", + "| iterations | 16 |\n", + "| time_elapsed | 20 |\n", + "| total_timesteps | 32768 |\n", + "| train/ | |\n", + "| approx_kl | 0.012472793 |\n", + "| clip_fraction | 0.148 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.02 |\n", + "| explained_variance | 0.419 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.593 |\n", + "| n_updates | 150 |\n", + "| policy_gradient_loss | -0.021 |\n", + "| std | 0.924 |\n", + "| value_loss | 0.859 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.43 |\n", + "| ep_rew_mean | -0.0818 |\n", + "| time/ | |\n", + "| fps | 1587 |\n", + "| iterations | 17 |\n", + "| time_elapsed | 21 |\n", + "| total_timesteps | 34816 |\n", + "| train/ | |\n", + "| approx_kl | 0.010671131 |\n", + "| clip_fraction | 0.121 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.97 |\n", + "| explained_variance | 0.407 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.159 |\n", + "| n_updates | 160 |\n", + "| policy_gradient_loss | -0.0216 |\n", + "| std | 0.904 |\n", + "| value_loss | 0.472 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.44 |\n", + "| ep_rew_mean | -0.073 |\n", + "| time/ | |\n", + "| fps | 1584 |\n", + "| iterations | 18 |\n", + "| time_elapsed | 23 |\n", + "| total_timesteps | 36864 |\n", + "| train/ | |\n", + "| approx_kl | 0.009028839 |\n", + "| clip_fraction | 0.116 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.9 |\n", + "| explained_variance | 0.376 |\n", + "| learning_rate | 0.0003 |\n", + "| 
loss | 0.0734 |\n", + "| n_updates | 170 |\n", + "| policy_gradient_loss | -0.0181 |\n", + "| std | 0.883 |\n", + "| value_loss | 0.225 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.25 |\n", + "| ep_rew_mean | -0.0551 |\n", + "| time/ | |\n", + "| fps | 1581 |\n", + "| iterations | 19 |\n", + "| time_elapsed | 24 |\n", + "| total_timesteps | 38912 |\n", + "| train/ | |\n", + "| approx_kl | 0.0109791495 |\n", + "| clip_fraction | 0.126 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.87 |\n", + "| explained_variance | 0.443 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0107 |\n", + "| n_updates | 180 |\n", + "| policy_gradient_loss | -0.0189 |\n", + "| std | 0.885 |\n", + "| value_loss | 0.204 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.08 |\n", + "| ep_rew_mean | -0.0231 |\n", + "| time/ | |\n", + "| fps | 1574 |\n", + "| iterations | 20 |\n", + "| time_elapsed | 26 |\n", + "| total_timesteps | 40960 |\n", + "| train/ | |\n", + "| approx_kl | 0.010651594 |\n", + "| clip_fraction | 0.0928 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.84 |\n", + "| explained_variance | 0.49 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0843 |\n", + "| n_updates | 190 |\n", + "| policy_gradient_loss | -0.0125 |\n", + "| std | 0.866 |\n", + "| value_loss | 0.231 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.01 |\n", + "| ep_rew_mean | -0.0072 |\n", + "| time/ | |\n", + "| fps | 1571 |\n", + "| iterations | 21 |\n", + "| time_elapsed | 27 |\n", + "| total_timesteps | 43008 |\n", + "| train/ | |\n", + "| approx_kl | 0.007608075 |\n", + "| clip_fraction | 0.0848 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.77 |\n", + "| explained_variance | 
0.392 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0338 |\n", + "| n_updates | 200 |\n", + "| policy_gradient_loss | -0.016 |\n", + "| std | 0.846 |\n", + "| value_loss | 0.0777 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1.27 |\n", + "| ep_rew_mean | -0.0515 |\n", + "| time/ | |\n", + "| fps | 1567 |\n", + "| iterations | 22 |\n", + "| time_elapsed | 28 |\n", + "| total_timesteps | 45056 |\n", + "| train/ | |\n", + "| approx_kl | 0.009603938 |\n", + "| clip_fraction | 0.0639 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.71 |\n", + "| explained_variance | 0.435 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0436 |\n", + "| n_updates | 210 |\n", + "| policy_gradient_loss | -0.0129 |\n", + "| std | 0.835 |\n", + "| value_loss | 0.0349 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 1 |\n", + "| ep_rew_mean | -0.00487 |\n", + "| time/ | |\n", + "| fps | 1561 |\n", + "| iterations | 23 |\n", + "| time_elapsed | 30 |\n", + "| total_timesteps | 47104 |\n", + "| train/ | |\n", + "| approx_kl | 0.008504309 |\n", + "| clip_fraction | 0.0841 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.71 |\n", + "| explained_variance | 0.558 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0363 |\n", + "| n_updates | 220 |\n", + "| policy_gradient_loss | -0.0113 |\n", + "| std | 0.842 |\n", + "| value_loss | 0.177 |\n", + "-----------------------------------------\n" + ] + } + ], + "source": [ + "# Initialize the RL environment and start training\n", + "rl_result = optimize_pulses(\n", + " objectives,\n", + " control_parameters,\n", + " tlist,\n", + " algorithm_kwargs,\n", + " optimizer_kwargs\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b5fa66a4", + "metadata": {}, + "source": [ + "### Analyze the Results" + ] + }, + { + "cell_type": 
"markdown", + "id": "afdc0129", + "metadata": {}, + "source": [ + "After the training is complete, we can analyze the results obtained by the RL agent. \n", + "In the above window showing the output produced by Gymansium, you can observe how during training the number of steps per episode (ep_len_mean) decreases and the average reward of the episodes (ep_rew_mean) increases." + ] + }, + { + "cell_type": "markdown", + "id": "3172a7d8-5621-4366-a7a1-847aa01974cf", + "metadata": {}, + "source": [ + "We can now see the fields of the `Result` class, this includes the final infidelity, the optimized control parameters and more." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "32521c12-5ff1-4199-a9d2-bc56bfd58b51", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Control Optimization Result\n", + "--------------------------\n", + "- Started at 2024-08-24 20:38:52\n", + "- Number of objectives: 1\n", + "- Final fidelity error: 0.0048685058051837204\n", + "- Final parameters: [[-10.849769413471222, -12.63647347688675, -13.0], 30.0]\n", + "- Number of iterations: 19691\n", + "- Reason for termination: Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid.\n", + "- Optimized time parameter: 30.0\n", + "- Ended at 2024-08-24 20:39:22 (30.0s)\n" + ] + } + ], + "source": [ + "print(rl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "340f0cc9-8329-47d3-9cba-09009b56a786", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# We can also visualize the initial and final states on the Bloch sphere\n", + "bloch_sp = qt.Bloch()\n", + "bloch_sp.add_states(initial_state)\n", + "bloch_sp.add_states(target_state)\n", + "bloch_sp.add_states(rl_result._final_states[0])\n", + "bloch_sp.show()" + ] + }, + { + "cell_type": "markdown", + "id": "abd99882-1f74-4036-aa66-2aff99758ded", + "metadata": {}, + "source": [ + "If the total number of iterations in the Result class is slightly higher than the set value, it is because the algorithm needs to complete the rollout, which consists of a certain number of episodes, before performing termination checks (as defined in the Callback class)." + ] + }, + { + "cell_type": "markdown", + "id": "a4ac563d", + "metadata": {}, + "source": [ + "## Unitary Operators\n", + "\n", + "Now we will show how to tackle a problem similar to the previous one, but this time, instead of reaching a specific target state, the goal is to start from the identity operator and evolve it in a controlled way until we obtain a specific unitary operator, such as the Hadamard gate." + ] + }, + { + "cell_type": "markdown", + "id": "0c2f5b7e-9f73-498e-bd61-fe51d9eb2879", + "metadata": {}, + "source": [ + "The control problem is similar to the previous one, we just need to change the initial state, the target state (now they are matrices) and update the objective. \n", + "We can also change the number of episodes for this task by changing `max_iter` \n", + "By setting `shorter_pulses` to False, the algorithm will stop as soon as it finds an episode that satisfies the target infidelity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3aeeb5bd-e637-4af7-b0a7-7196c24474ed", + "metadata": {}, + "outputs": [], + "source": [ + "initial = qt.qeye(2) # Identity\n", + "target = qt.gates.hadamard_transform()\n", + "\n", + "objectives=[Objective(initial, H, target)]\n", + "\n", + "algorithm_kwargs = {\n", + " \"fid_err_targ\": 0.01,\n", + " \"alg\": \"RL\",\n", + " \"max_iter\": 900,\n", + " \"shorter_pulses\": False,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6e7a9f9d-9761-457e-994c-732b11060d73", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cpu device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n", + "---------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.4 |\n", + "| time/ | |\n", + "| fps | 2351 |\n", + "| iterations | 1 |\n", + "| time_elapsed | 0 |\n", + "| total_timesteps | 2048 |\n", + "---------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.3 |\n", + "| time/ | |\n", + "| fps | 1878 |\n", + "| iterations | 2 |\n", + "| time_elapsed | 2 |\n", + "| total_timesteps | 4096 |\n", + "| train/ | |\n", + "| approx_kl | 0.004405604 |\n", + "| clip_fraction | 0.0345 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.25 |\n", + "| explained_variance | -0.0302 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.86 |\n", + "| n_updates | 10 |\n", + "| policy_gradient_loss | -0.00376 |\n", + "| std | 0.997 |\n", + "| value_loss | 28 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.1 |\n", + "| time/ | |\n", + "| fps | 1734 |\n", + "| iterations | 3 |\n", + "| time_elapsed | 3 |\n", + 
"| total_timesteps | 6144 |\n", + "| train/ | |\n", + "| approx_kl | 0.0046608616 |\n", + "| clip_fraction | 0.0235 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.25 |\n", + "| explained_variance | 0.00253 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.58 |\n", + "| n_updates | 20 |\n", + "| policy_gradient_loss | -0.00136 |\n", + "| std | 0.995 |\n", + "| value_loss | 8.41 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.8 |\n", + "| time/ | |\n", + "| fps | 1692 |\n", + "| iterations | 4 |\n", + "| time_elapsed | 4 |\n", + "| total_timesteps | 8192 |\n", + "| train/ | |\n", + "| approx_kl | 0.006713907 |\n", + "| clip_fraction | 0.0322 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.23 |\n", + "| explained_variance | 0.0167 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.87 |\n", + "| n_updates | 30 |\n", + "| policy_gradient_loss | -0.00373 |\n", + "| std | 0.984 |\n", + "| value_loss | 7.91 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.3 |\n", + "| time/ | |\n", + "| fps | 1675 |\n", + "| iterations | 5 |\n", + "| time_elapsed | 6 |\n", + "| total_timesteps | 10240 |\n", + "| train/ | |\n", + "| approx_kl | 0.0056952434 |\n", + "| clip_fraction | 0.0311 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.2 |\n", + "| explained_variance | 0.0864 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.2 |\n", + "| n_updates | 40 |\n", + "| policy_gradient_loss | -0.00337 |\n", + "| std | 0.974 |\n", + "| value_loss | 5.18 |\n", + "------------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49 |\n", + "| time/ | |\n", + "| fps | 1658 |\n", + "| iterations | 6 |\n", 
+ "| time_elapsed | 7 |\n", + "| total_timesteps | 12288 |\n", + "| train/ | |\n", + "| approx_kl | 0.0038293994 |\n", + "| clip_fraction | 0.0164 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.16 |\n", + "| explained_variance | 0.0966 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.13 |\n", + "| n_updates | 50 |\n", + "| policy_gradient_loss | -0.00192 |\n", + "| std | 0.962 |\n", + "| value_loss | 4.51 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.3 |\n", + "| time/ | |\n", + "| fps | 1649 |\n", + "| iterations | 7 |\n", + "| time_elapsed | 8 |\n", + "| total_timesteps | 14336 |\n", + "| train/ | |\n", + "| approx_kl | 0.008284866 |\n", + "| clip_fraction | 0.0638 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.14 |\n", + "| explained_variance | 0.232 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.77 |\n", + "| n_updates | 60 |\n", + "| policy_gradient_loss | -0.00626 |\n", + "| std | 0.963 |\n", + "| value_loss | 2.74 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.1 |\n", + "| time/ | |\n", + "| fps | 1633 |\n", + "| iterations | 8 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 16384 |\n", + "| train/ | |\n", + "| approx_kl | 0.0061374204 |\n", + "| clip_fraction | 0.0455 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.14 |\n", + "| explained_variance | 0.37 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.561 |\n", + "| n_updates | 70 |\n", + "| policy_gradient_loss | -0.00462 |\n", + "| std | 0.96 |\n", + "| value_loss | 2.08 |\n", + "------------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49 |\n", + "| time/ | |\n", + "| fps | 1630 
|\n", + "| iterations | 9 |\n", + "| time_elapsed | 11 |\n", + "| total_timesteps | 18432 |\n", + "| train/ | |\n", + "| approx_kl | 0.0043069483 |\n", + "| clip_fraction | 0.0229 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.14 |\n", + "| explained_variance | 0.444 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.794 |\n", + "| n_updates | 80 |\n", + "| policy_gradient_loss | -0.00385 |\n", + "| std | 0.966 |\n", + "| value_loss | 2.14 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.5 |\n", + "| time/ | |\n", + "| fps | 1627 |\n", + "| iterations | 10 |\n", + "| time_elapsed | 12 |\n", + "| total_timesteps | 20480 |\n", + "| train/ | |\n", + "| approx_kl | 0.004516308 |\n", + "| clip_fraction | 0.0264 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.15 |\n", + "| explained_variance | 0.656 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.701 |\n", + "| n_updates | 90 |\n", + "| policy_gradient_loss | -0.00309 |\n", + "| std | 0.962 |\n", + "| value_loss | 1.97 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.5 |\n", + "| time/ | |\n", + "| fps | 1623 |\n", + "| iterations | 11 |\n", + "| time_elapsed | 13 |\n", + "| total_timesteps | 22528 |\n", + "| train/ | |\n", + "| approx_kl | 0.004935258 |\n", + "| clip_fraction | 0.0244 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.14 |\n", + "| explained_variance | 0.669 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.803 |\n", + "| n_updates | 100 |\n", + "| policy_gradient_loss | -0.00345 |\n", + "| std | 0.962 |\n", + "| value_loss | 2.49 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.8 
|\n", + "| time/ | |\n", + "| fps | 1609 |\n", + "| iterations | 12 |\n", + "| time_elapsed | 15 |\n", + "| total_timesteps | 24576 |\n", + "| train/ | |\n", + "| approx_kl | 0.0066686976 |\n", + "| clip_fraction | 0.0445 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.15 |\n", + "| explained_variance | 0.64 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.789 |\n", + "| n_updates | 110 |\n", + "| policy_gradient_loss | -0.00489 |\n", + "| std | 0.967 |\n", + "| value_loss | 2.6 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.4 |\n", + "| time/ | |\n", + "| fps | 1609 |\n", + "| iterations | 13 |\n", + "| time_elapsed | 16 |\n", + "| total_timesteps | 26624 |\n", + "| train/ | |\n", + "| approx_kl | 0.005435321 |\n", + "| clip_fraction | 0.0443 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.14 |\n", + "| explained_variance | 0.644 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.918 |\n", + "| n_updates | 120 |\n", + "| policy_gradient_loss | -0.00518 |\n", + "| std | 0.958 |\n", + "| value_loss | 1.98 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.5 |\n", + "| time/ | |\n", + "| fps | 1607 |\n", + "| iterations | 14 |\n", + "| time_elapsed | 17 |\n", + "| total_timesteps | 28672 |\n", + "| train/ | |\n", + "| approx_kl | 0.004944969 |\n", + "| clip_fraction | 0.0457 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.12 |\n", + "| explained_variance | 0.737 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.21 |\n", + "| n_updates | 130 |\n", + "| policy_gradient_loss | -0.00603 |\n", + "| std | 0.951 |\n", + "| value_loss | 2.28 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean 
| 100 |\n", + "| ep_rew_mean | -48.9 |\n", + "| time/ | |\n", + "| fps | 1605 |\n", + "| iterations | 15 |\n", + "| time_elapsed | 19 |\n", + "| total_timesteps | 30720 |\n", + "| train/ | |\n", + "| approx_kl | 0.006223672 |\n", + "| clip_fraction | 0.0668 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.11 |\n", + "| explained_variance | 0.723 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.666 |\n", + "| n_updates | 140 |\n", + "| policy_gradient_loss | -0.00688 |\n", + "| std | 0.951 |\n", + "| value_loss | 1.98 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.1 |\n", + "| time/ | |\n", + "| fps | 1597 |\n", + "| iterations | 16 |\n", + "| time_elapsed | 20 |\n", + "| total_timesteps | 32768 |\n", + "| train/ | |\n", + "| approx_kl | 0.005545385 |\n", + "| clip_fraction | 0.0473 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.1 |\n", + "| explained_variance | 0.781 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.05 |\n", + "| n_updates | 150 |\n", + "| policy_gradient_loss | -0.00429 |\n", + "| std | 0.947 |\n", + "| value_loss | 2.5 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.9 |\n", + "| time/ | |\n", + "| fps | 1595 |\n", + "| iterations | 17 |\n", + "| time_elapsed | 21 |\n", + "| total_timesteps | 34816 |\n", + "| train/ | |\n", + "| approx_kl | 0.00498241 |\n", + "| clip_fraction | 0.0263 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.09 |\n", + "| explained_variance | 0.633 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.906 |\n", + "| n_updates | 160 |\n", + "| policy_gradient_loss | -0.00302 |\n", + "| std | 0.942 |\n", + "| value_loss | 2.75 |\n", + "----------------------------------------\n", + "------------------------------------------\n", + "| 
rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -49.1 |\n", + "| time/ | |\n", + "| fps | 1594 |\n", + "| iterations | 18 |\n", + "| time_elapsed | 23 |\n", + "| total_timesteps | 36864 |\n", + "| train/ | |\n", + "| approx_kl | 0.0064855358 |\n", + "| clip_fraction | 0.0426 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.08 |\n", + "| explained_variance | 0.658 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.2 |\n", + "| n_updates | 170 |\n", + "| policy_gradient_loss | -0.00508 |\n", + "| std | 0.942 |\n", + "| value_loss | 2.36 |\n", + "------------------------------------------\n", + "---------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.8 |\n", + "| time/ | |\n", + "| fps | 1595 |\n", + "| iterations | 19 |\n", + "| time_elapsed | 24 |\n", + "| total_timesteps | 38912 |\n", + "| train/ | |\n", + "| approx_kl | 0.0066174 |\n", + "| clip_fraction | 0.0362 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.07 |\n", + "| explained_variance | 0.649 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.33 |\n", + "| n_updates | 180 |\n", + "| policy_gradient_loss | -0.00309 |\n", + "| std | 0.937 |\n", + "| value_loss | 3.01 |\n", + "---------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.5 |\n", + "| time/ | |\n", + "| fps | 1592 |\n", + "| iterations | 20 |\n", + "| time_elapsed | 25 |\n", + "| total_timesteps | 40960 |\n", + "| train/ | |\n", + "| approx_kl | 0.0062293075 |\n", + "| clip_fraction | 0.0578 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.06 |\n", + "| explained_variance | 0.683 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.38 |\n", + "| n_updates | 190 |\n", + "| policy_gradient_loss | -0.00481 |\n", + "| std | 0.937 |\n", + "| value_loss | 2.4 |\n", + "------------------------------------------\n", + 
"-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.1 |\n", + "| time/ | |\n", + "| fps | 1593 |\n", + "| iterations | 21 |\n", + "| time_elapsed | 26 |\n", + "| total_timesteps | 43008 |\n", + "| train/ | |\n", + "| approx_kl | 0.006579734 |\n", + "| clip_fraction | 0.0502 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.06 |\n", + "| explained_variance | 0.727 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.28 |\n", + "| n_updates | 200 |\n", + "| policy_gradient_loss | -0.00512 |\n", + "| std | 0.934 |\n", + "| value_loss | 2.63 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.9 |\n", + "| time/ | |\n", + "| fps | 1595 |\n", + "| iterations | 22 |\n", + "| time_elapsed | 28 |\n", + "| total_timesteps | 45056 |\n", + "| train/ | |\n", + "| approx_kl | 0.0065193363 |\n", + "| clip_fraction | 0.0587 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4.03 |\n", + "| explained_variance | 0.702 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.4 |\n", + "| n_updates | 210 |\n", + "| policy_gradient_loss | -0.00485 |\n", + "| std | 0.923 |\n", + "| value_loss | 2.38 |\n", + "------------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -48.2 |\n", + "| time/ | |\n", + "| fps | 1592 |\n", + "| iterations | 23 |\n", + "| time_elapsed | 29 |\n", + "| total_timesteps | 47104 |\n", + "| train/ | |\n", + "| approx_kl | 0.0045397887 |\n", + "| clip_fraction | 0.0393 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -4 |\n", + "| explained_variance | 0.757 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.838 |\n", + "| n_updates | 220 |\n", + "| policy_gradient_loss | -0.00353 |\n", + "| std | 0.918 |\n", + "| value_loss | 2.1 |\n", + 
"------------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.9 |\n", + "| time/ | |\n", + "| fps | 1594 |\n", + "| iterations | 24 |\n", + "| time_elapsed | 30 |\n", + "| total_timesteps | 49152 |\n", + "| train/ | |\n", + "| approx_kl | 0.0059773056 |\n", + "| clip_fraction | 0.0457 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.99 |\n", + "| explained_variance | 0.713 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.11 |\n", + "| n_updates | 230 |\n", + "| policy_gradient_loss | -0.00375 |\n", + "| std | 0.916 |\n", + "| value_loss | 2.91 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.7 |\n", + "| time/ | |\n", + "| fps | 1592 |\n", + "| iterations | 25 |\n", + "| time_elapsed | 32 |\n", + "| total_timesteps | 51200 |\n", + "| train/ | |\n", + "| approx_kl | 0.006354755 |\n", + "| clip_fraction | 0.0526 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.97 |\n", + "| explained_variance | 0.692 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.18 |\n", + "| n_updates | 240 |\n", + "| policy_gradient_loss | -0.00395 |\n", + "| std | 0.905 |\n", + "| value_loss | 2.68 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.9 |\n", + "| time/ | |\n", + "| fps | 1592 |\n", + "| iterations | 26 |\n", + "| time_elapsed | 33 |\n", + "| total_timesteps | 53248 |\n", + "| train/ | |\n", + "| approx_kl | 0.0069060894 |\n", + "| clip_fraction | 0.0633 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.95 |\n", + "| explained_variance | 0.797 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.16 |\n", + "| n_updates | 250 |\n", + "| policy_gradient_loss | -0.00548 |\n", + "| std | 
0.902 |\n", + "| value_loss | 2.56 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.8 |\n", + "| time/ | |\n", + "| fps | 1592 |\n", + "| iterations | 27 |\n", + "| time_elapsed | 34 |\n", + "| total_timesteps | 55296 |\n", + "| train/ | |\n", + "| approx_kl | 0.008833875 |\n", + "| clip_fraction | 0.0865 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.93 |\n", + "| explained_variance | 0.788 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.18 |\n", + "| n_updates | 260 |\n", + "| policy_gradient_loss | -0.00951 |\n", + "| std | 0.895 |\n", + "| value_loss | 2.5 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.3 |\n", + "| time/ | |\n", + "| fps | 1588 |\n", + "| iterations | 28 |\n", + "| time_elapsed | 36 |\n", + "| total_timesteps | 57344 |\n", + "| train/ | |\n", + "| approx_kl | 0.009341501 |\n", + "| clip_fraction | 0.107 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.91 |\n", + "| explained_variance | 0.795 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.16 |\n", + "| n_updates | 270 |\n", + "| policy_gradient_loss | -0.00948 |\n", + "| std | 0.89 |\n", + "| value_loss | 2.49 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -47.3 |\n", + "| time/ | |\n", + "| fps | 1588 |\n", + "| iterations | 29 |\n", + "| time_elapsed | 37 |\n", + "| total_timesteps | 59392 |\n", + "| train/ | |\n", + "| approx_kl | 0.00845156 |\n", + "| clip_fraction | 0.0839 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.88 |\n", + "| explained_variance | 0.778 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.15 |\n", + "| n_updates | 280 |\n", + "| 
policy_gradient_loss | -0.00978 |\n", + "| std | 0.878 |\n", + "| value_loss | 2.78 |\n", + "----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -46.9 |\n", + "| time/ | |\n", + "| fps | 1589 |\n", + "| iterations | 30 |\n", + "| time_elapsed | 38 |\n", + "| total_timesteps | 61440 |\n", + "| train/ | |\n", + "| approx_kl | 0.0066356836 |\n", + "| clip_fraction | 0.0696 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.84 |\n", + "| explained_variance | 0.799 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.38 |\n", + "| n_updates | 290 |\n", + "| policy_gradient_loss | -0.00576 |\n", + "| std | 0.865 |\n", + "| value_loss | 2.67 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -46.4 |\n", + "| time/ | |\n", + "| fps | 1589 |\n", + "| iterations | 31 |\n", + "| time_elapsed | 39 |\n", + "| total_timesteps | 63488 |\n", + "| train/ | |\n", + "| approx_kl | 0.011593084 |\n", + "| clip_fraction | 0.118 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -3.79 |\n", + "| explained_variance | 0.722 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.06 |\n", + "| n_updates | 300 |\n", + "| policy_gradient_loss | -0.0136 |\n", + "| std | 0.849 |\n", + "| value_loss | 2.73 |\n", + "-----------------------------------------\n" + ] + } + ], + "source": [ + "# Initialize the RL environment and start training\n", + "rl_result = optimize_pulses(\n", + " objectives,\n", + " control_parameters,\n", + " tlist,\n", + " algorithm_kwargs,\n", + " optimizer_kwargs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9d46b2cb-e3e4-492b-84f7-2229058ff802", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Control Optimization Result\n", + 
"--------------------------\n", + "- Started at 2024-08-26 19:04:39\n", + "- Number of objectives: 1\n", + "- Final fidelity error: 0.002551819763576102\n", + "- Final parameters: [[3.4188188314437866, 13.0, 13.0], [13.0, 13.0, -13.0], [6.511582612991333, 13.0, 8.964270651340485], [-13.0, 7.0920491218566895, 13.0], [13.0, 13.0, 7.403999388217926], [-7.155267655849457, 6.941469728946686, -0.06780803203582764], [0.3478652387857437, 13.0, -13.0], [0.9450445175170898, 13.0, 10.478427648544312], [-2.7053715586662292, 6.866261005401611, -2.9162583351135254], [-12.548729062080383, 13.0, 0.3780375272035599], [5.524807989597321, -12.404477953910828, -1.3450259938836098], [9.932345569133759, 6.742649078369141, 0.8036832511425018], [10.077520310878754, 10.17627239227295, -0.3489200174808502], [-8.5400710105896, -13.0, 4.9227414727211], [3.7715193927288055, 13.0, -13.0], [-9.494724571704865, -13.0, -12.165066242218018], [-13.0, 0.28259196877479553, -0.2347586750984192], [-13.0, 13.0, -1.1993653178215027], [-13.0, 13.0, -3.2808731496334076], [-2.0000401735305786, 13.0, 13.0], [-6.9624924659729, -0.8900302052497864, 13.0], [13.0, 0.22511941194534302, -13.0], [12.91728287935257, 13.0, 13.0], [13.0, 13.0, -13.0], [5.803696125745773, 9.111721158027649, 13.0], [13.0, 13.0, -7.657740592956543], [13.0, 13.0, -6.739334225654602], [13.0, -0.5586533173918724, 13.0], [-3.2091493606567383, 9.593821465969086, -13.0], [13.0, -13.0, 6.971161603927612], [13.0, -4.907809913158417, -4.819970577955246], [3.572271406650543, 9.932078242301941, 13.0], [13.0, -13.0, -4.883280158042908], [9.99718976020813, 13.0, 7.891545414924622], [1.17414128780365, 11.88118064403534, 13.0], [-13.0, 13.0, 1.8821984380483627], [9.180704653263092, -9.39372307062149, 2.399792581796646], [-3.331559866666794, -0.2218501791357994, 13.0], [3.0872934609651566, 13.0, -4.681073755025864], [-0.13288933038711548, -6.234242260456085, 1.472162663936615], [3.661263346672058, 1.5755969360470772, -12.517065167427063], 
[1.5426646918058395, -4.089141249656677, 13.0], [4.887776285409927, -13.0, -10.995292067527771], [-13.0, -0.45841602236032486, 13.0], [3.2520944476127625, 3.0989434868097305, 6.663681507110596], [-13.0, -6.065039873123169, 11.589093923568726], [1.3966505825519562, 13.0, 1.652427151799202], [-13.0, -1.0330521911382675, 13.0], [1.4274927377700806, 13.0, 8.263666689395905], [-6.012888669967651, -1.4705532789230347, -9.972679376602173], [-7.9508408308029175, -11.454957842826843, -8.144261479377747], [2.247176080942154, 1.2720030546188354, 13.0], [13.0, 13.0, -11.275928676128387], [-0.48115187883377075, -11.333942472934723, 7.231829285621643], [2.05837282538414, -12.688138842582703, -13.0], [-7.996112048625946, -8.991393089294434, -7.918865442276001], [-13.0, 10.492306172847748, 3.85249462723732], [6.597911357879639, -13.0, 12.52410089969635], [13.0, -1.6222021728754044, 2.7806494385004044], [7.98918479681015, 5.118352651596069, -3.199215844273567], [13.0, -13.0, -11.10127592086792], [-0.3817961812019348, 13.0, -5.7455726861953735], [-13.0, 3.0795744955539703, 13.0], [-2.201941668987274, -3.7015785574913025, -13.0], [0.7181568741798401, -7.438749551773071, 10.924750328063965], [13.0, 3.264466255903244, -1.135591208934784], [13.0, -3.5049081444740295, 7.498368859291077], [13.0, -3.6533349752426147, -12.531098663806915], [-13.0, 6.673511385917664, 13.0], [-12.567341208457947, 9.32302713394165, 2.2332334369421005], [4.962125301361084, 10.5961754322052, -9.708646476268768], [2.4967450499534607, -9.764376759529114, 13.0], [-3.2077371776103973, 13.0, -1.0517201274633408], [-13.0, 13.0, -13.0], [-10.395427703857422, 11.5837721824646, 7.713315904140472], [-1.5034708976745605, -3.695146828889847, -9.910763382911682], [3.4931019842624664, -6.954262673854828, 13.0], [13.0, -4.445300787687302, -9.596865117549896], [13.0, -13.0, 11.766752362251282], [4.538477748632431, 13.0, 6.2513333559036255], [-6.342582851648331, 0.7486081123352051, 13.0], [-2.477048873901367, -4.575975179672241, 
-3.5929694771766663], [-7.137736439704895, -3.73245170712471, -2.05170476436615], [-9.719637095928192, 0.4852590262889862, 5.271544873714447], [11.546406865119934, -0.5641068816184998, 13.0], [13.0, -13.0, -1.4608764350414276], 42.0]\n", + "- Number of iterations: 641\n", + "- Reason for termination: Stop training because an episode with infidelity <= target infidelity was found\n", + "- Optimized time parameter: 42.0\n", + "- Ended at 2024-08-26 19:05:21 (42.0s)\n" + ] + } + ], + "source": [ + "print(rl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a8d8739f-1c54-432d-b2ae-ff1e71d217d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'hinton')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# We can show in this case the hinton matrix\n", + "fig, ax = qt.hinton(rl_result._final_states[0])\n", + "ax.set_title('hinton')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5022f446-a4de-4961-9e25-4e3b3488f2e7", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# for the hadamard matrix\n", + "U = qt.gates.hadamard_transform()\n", + "fig, ax = qt.hinton(U)" + ] + }, + { + "cell_type": "markdown", + "id": "cababff9-b1f3-41ef-b0e9-d0fea38a9e4b", + "metadata": {}, + "source": [ + "We are using PSU norm in the infidelity calculation, so the found transformation is correct, independently of the global phase." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3812ffe-d30c-47f1-84a6-f33ddc76edea", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials-v5/quantum-optimal-control/Two_Qubits_RL.ipynb b/tutorials-v5/quantum-optimal-control/Two_Qubits_RL.ipynb new file mode 100644 index 0000000..182aef7 --- /dev/null +++ b/tutorials-v5/quantum-optimal-control/Two_Qubits_RL.ipynb @@ -0,0 +1,713 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c3371bed", + "metadata": {}, + "source": [ + "# Quantum Optimal Control with Reinforcement Learning\n", + "\n", + "In this notebook, we will demonstrate how to use the `_RL` module to solve a quantum optimal control problem using reinforcement learning (RL).\n", + "The goal is to use 2 Qubits to realize CNOT gate. 
The gate acts on a control qubit and a target qubit: if the control qubit is in the state |0⟩, the target qubit remains unchanged; if the control qubit is in the state |1⟩, the CNOT gate flips the state of the target qubit.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c81d453d", + "metadata": {}, + "source": [ + "### Setup and Import Required Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f1e015e6", + "metadata": {}, + "outputs": [], + "source": [ + "# If you are running this in an environment where some packages are missing, use this cell to install them:\n", + "# !pip install qutip stable-baselines3 gymnasium\n", + "\n", + "import qutip as qt\n", + "import numpy as np\n", + "from stable_baselines3 import PPO\n", + "from qutip_qoc import Objective\n", + "#from _rl import _RL\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0131f1b2-b218-435c-a1e6-5c1acad6ac6b", + "metadata": {}, + "outputs": [], + "source": [ + "# This is only needed to use local files (not yet merged on GitHub)\n", + "import sys\n", + "import os\n", + "\n", + "module_path = os.path.abspath(os.path.join('..', 'Github', 'qutip-qoc', 'src', 'qutip_qoc'))\n", + "\n", + "sys.path.append(module_path)\n", + "\n", + "from _rl import _RL\n", + "from pulse_optim import optimize_pulses" + ] + }, + { + "cell_type": "markdown", + "id": "b4b725c0", + "metadata": {}, + "source": [ + "### Define the Quantum Control Problem" + ] + }, + { + "cell_type": "markdown", + "id": "7e895742", + "metadata": {}, + "source": [ + "The system starts from an initial state represented by the identity on two qubits, with the goal of achieving a CNOT gate as the target state. To accomplish this, control operators based on the Pauli matrices are defined, each acting on one of the two qubits. 
Additionally, a drift Hamiltonian is introduced to account for interactions between the qubits and noise, thereby modeling the dynamics of the open quantum system." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6c414871", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the initial and target states\n", + "initial = qt.tensor(qt.qeye(2), qt.qeye(2))\n", + "target = qt.gates.cnot()\n", + "\n", + "# convert to superoperator (for open system)\n", + "initial = qt.sprepost(initial, initial.dag())\n", + "target = qt.sprepost(target, target.dag())\n", + "\n", + "# single qubit control operators\n", + "sx, sy, sz = qt.sigmax(), qt.sigmay(), qt.sigmaz()\n", + "identity = qt.qeye(2)\n", + "\n", + "# two qubit control operators\n", + "i_sx, sx_i = qt.tensor(sx, identity), qt.tensor(identity, sx)\n", + "i_sy, sy_i = qt.tensor(sy, identity), qt.tensor(identity, sy)\n", + "i_sz, sz_i = qt.tensor(sz, identity), qt.tensor(identity, sz)\n", + "\n", + "# Define the control Hamiltonians\n", + "Hc = [i_sx, i_sy, i_sz, sx_i, sy_i, sz_i]\n", + "Hc = [qt.liouvillian(H) for H in Hc]\n", + "\n", + "# drift and noise term for a two-qubit system\n", + "omega, delta, gamma = 0.1, 1.0, 0.1\n", + "i_sm, sm_i = qt.tensor(qt.sigmam(), identity), qt.tensor(identity, qt.sigmam())\n", + "\n", + "# energy levels and interaction\n", + "Hd = omega * (i_sz + sz_i) + delta * i_sz * sz_i\n", + "Hd = qt.liouvillian(H=Hd, c_ops=[gamma * (i_sm + sm_i)])\n", + "\n", + "# combined operator list\n", + "H = [Hd, Hc[0], Hc[1], Hc[2], Hc[3], Hc[4], Hc[5]]\n", + "\n", + "# Define the objective\n", + "objectives = [Objective(initial, H, target)]\n", + "\n", + "# Define the control parameters with bounds\n", + "control_parameters = {\n", + " \"p\": {\"bounds\": [(-30, 30)]}\n", + "}\n", + "\n", + "# Define the time interval\n", + "tlist = np.linspace(0, np.pi, 100)\n", + "\n", + "# Define algorithm-specific settings\n", + "algorithm_kwargs = {\n", + " \"fid_err_targ\": 0.01,\n", + " 
\"alg\": \"RL\",\n", + " \"max_iter\": 400,\n", + " \"shorter_pulses\": False,\n", + "}\n", + "optimizer_kwargs = {}\n" + ] + }, + { + "cell_type": "markdown", + "id": "eab0e4f0-3b42-4f32-a3dc-bb752c8dbe8c", + "metadata": {}, + "source": [ + "Note that `max_iter` defines the number of episodes, the 100 in `tlist` defines the maximum number of steps per episode" + ] + }, + { + "cell_type": "markdown", + "id": "a6273e17", + "metadata": {}, + "source": [ + "### Initialize and Train the RL Environment" + ] + }, + { + "cell_type": "markdown", + "id": "e7b27df4", + "metadata": {}, + "source": [ + "Now we will call the `optimize_pulses()` method, passing it the control problem we defined.\n", + "The method will create an instance of the `_RL` class, which will set up the reinforcement learning environment and start training.\n", + "Finally it returns the optimization results through an object of the `Result` class." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1c4b0b58", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cpu device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n", + "---------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 573 |\n", + "| iterations | 1 |\n", + "| time_elapsed | 3 |\n", + "| total_timesteps | 2048 |\n", + "---------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 520 |\n", + "| iterations | 2 |\n", + "| time_elapsed | 7 |\n", + "| total_timesteps | 4096 |\n", + "| train/ | |\n", + "| approx_kl | 0.009919285 |\n", + "| clip_fraction | 0.0753 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.51 |\n", + "| explained_variance | 0.0116 |\n", + "| learning_rate | 
0.0003 |\n", + "| loss | 0.908 |\n", + "| n_updates | 10 |\n", + "| policy_gradient_loss | -0.0134 |\n", + "| std | 1 |\n", + "| value_loss | 36.2 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 501 |\n", + "| iterations | 3 |\n", + "| time_elapsed | 12 |\n", + "| total_timesteps | 6144 |\n", + "| train/ | |\n", + "| approx_kl | 0.008348271 |\n", + "| clip_fraction | 0.0497 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.51 |\n", + "| explained_variance | 0.641 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 8.07 |\n", + "| n_updates | 20 |\n", + "| policy_gradient_loss | -0.0114 |\n", + "| std | 1 |\n", + "| value_loss | 49.9 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 495 |\n", + "| iterations | 4 |\n", + "| time_elapsed | 16 |\n", + "| total_timesteps | 8192 |\n", + "| train/ | |\n", + "| approx_kl | 0.010809683 |\n", + "| clip_fraction | 0.085 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.54 |\n", + "| explained_variance | 0.731 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 4.86 |\n", + "| n_updates | 30 |\n", + "| policy_gradient_loss | -0.0133 |\n", + "| std | 1.01 |\n", + "| value_loss | 35.4 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 493 |\n", + "| iterations | 5 |\n", + "| time_elapsed | 20 |\n", + "| total_timesteps | 10240 |\n", + "| train/ | |\n", + "| approx_kl | 0.0126280505 |\n", + "| clip_fraction | 0.113 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.54 |\n", + "| explained_variance | 0.848 |\n", + "| 
learning_rate | 0.0003 |\n", + "| loss | 2.63 |\n", + "| n_updates | 40 |\n", + "| policy_gradient_loss | -0.0143 |\n", + "| std | 1 |\n", + "| value_loss | 21.6 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 489 |\n", + "| iterations | 6 |\n", + "| time_elapsed | 25 |\n", + "| total_timesteps | 12288 |\n", + "| train/ | |\n", + "| approx_kl | 0.010566739 |\n", + "| clip_fraction | 0.0949 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.5 |\n", + "| explained_variance | 0.917 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.961 |\n", + "| n_updates | 50 |\n", + "| policy_gradient_loss | -0.00888 |\n", + "| std | 0.994 |\n", + "| value_loss | 12.3 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 490 |\n", + "| iterations | 7 |\n", + "| time_elapsed | 29 |\n", + "| total_timesteps | 14336 |\n", + "| train/ | |\n", + "| approx_kl | 0.0142566655 |\n", + "| clip_fraction | 0.14 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.45 |\n", + "| explained_variance | 0.957 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.301 |\n", + "| n_updates | 60 |\n", + "| policy_gradient_loss | -0.0113 |\n", + "| std | 0.984 |\n", + "| value_loss | 5.71 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.5 |\n", + "| time/ | |\n", + "| fps | 492 |\n", + "| iterations | 8 |\n", + "| time_elapsed | 33 |\n", + "| total_timesteps | 16384 |\n", + "| train/ | |\n", + "| approx_kl | 0.015264266 |\n", + "| clip_fraction | 0.148 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.4 |\n", + "| 
explained_variance | 0.981 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.15 |\n", + "| n_updates | 70 |\n", + "| policy_gradient_loss | -0.0133 |\n", + "| std | 0.982 |\n", + "| value_loss | 2.23 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 491 |\n", + "| iterations | 9 |\n", + "| time_elapsed | 37 |\n", + "| total_timesteps | 18432 |\n", + "| train/ | |\n", + "| approx_kl | 0.017051091 |\n", + "| clip_fraction | 0.175 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.38 |\n", + "| explained_variance | 0.992 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.107 |\n", + "| n_updates | 80 |\n", + "| policy_gradient_loss | -0.019 |\n", + "| std | 0.974 |\n", + "| value_loss | 1.01 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 491 |\n", + "| iterations | 10 |\n", + "| time_elapsed | 41 |\n", + "| total_timesteps | 20480 |\n", + "| train/ | |\n", + "| approx_kl | 0.016767085 |\n", + "| clip_fraction | 0.199 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.33 |\n", + "| explained_variance | 0.996 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.0812 |\n", + "| n_updates | 90 |\n", + "| policy_gradient_loss | -0.0207 |\n", + "| std | 0.968 |\n", + "| value_loss | 0.538 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 491 |\n", + "| iterations | 11 |\n", + "| time_elapsed | 45 |\n", + "| total_timesteps | 22528 |\n", + "| train/ | |\n", + "| approx_kl | 0.016128123 |\n", + "| clip_fraction | 0.186 |\n", + "| clip_range | 0.2 |\n", + "| 
entropy_loss | -8.28 |\n", + "| explained_variance | 0.995 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.29 |\n", + "| n_updates | 100 |\n", + "| policy_gradient_loss | -0.0204 |\n", + "| std | 0.96 |\n", + "| value_loss | 0.619 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 491 |\n", + "| iterations | 12 |\n", + "| time_elapsed | 50 |\n", + "| total_timesteps | 24576 |\n", + "| train/ | |\n", + "| approx_kl | 0.014264241 |\n", + "| clip_fraction | 0.183 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.26 |\n", + "| explained_variance | 0.995 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.233 |\n", + "| n_updates | 110 |\n", + "| policy_gradient_loss | -0.02 |\n", + "| std | 0.957 |\n", + "| value_loss | 0.672 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.6 |\n", + "| time/ | |\n", + "| fps | 490 |\n", + "| iterations | 13 |\n", + "| time_elapsed | 54 |\n", + "| total_timesteps | 26624 |\n", + "| train/ | |\n", + "| approx_kl | 0.01653198 |\n", + "| clip_fraction | 0.199 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.23 |\n", + "| explained_variance | 0.996 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.528 |\n", + "| n_updates | 120 |\n", + "| policy_gradient_loss | -0.0205 |\n", + "| std | 0.954 |\n", + "| value_loss | 0.824 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.8 |\n", + "| time/ | |\n", + "| fps | 490 |\n", + "| iterations | 14 |\n", + "| time_elapsed | 58 |\n", + "| total_timesteps | 28672 |\n", + "| train/ | |\n", + "| approx_kl | 0.018434964 |\n", + "| clip_fraction | 0.207 |\n", + "| 
clip_range | 0.2 |\n", + "| entropy_loss | -8.21 |\n", + "| explained_variance | 0.994 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.299 |\n", + "| n_updates | 130 |\n", + "| policy_gradient_loss | -0.0274 |\n", + "| std | 0.95 |\n", + "| value_loss | 0.804 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -88.9 |\n", + "| time/ | |\n", + "| fps | 488 |\n", + "| iterations | 15 |\n", + "| time_elapsed | 62 |\n", + "| total_timesteps | 30720 |\n", + "| train/ | |\n", + "| approx_kl | 0.018158613 |\n", + "| clip_fraction | 0.196 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.17 |\n", + "| explained_variance | 0.995 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.368 |\n", + "| n_updates | 140 |\n", + "| policy_gradient_loss | -0.0249 |\n", + "| std | 0.944 |\n", + "| value_loss | 0.86 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -89.1 |\n", + "| time/ | |\n", + "| fps | 486 |\n", + "| iterations | 16 |\n", + "| time_elapsed | 67 |\n", + "| total_timesteps | 32768 |\n", + "| train/ | |\n", + "| approx_kl | 0.017807882 |\n", + "| clip_fraction | 0.224 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.13 |\n", + "| explained_variance | 0.995 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.318 |\n", + "| n_updates | 150 |\n", + "| policy_gradient_loss | -0.028 |\n", + "| std | 0.939 |\n", + "| value_loss | 0.895 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -89.4 |\n", + "| time/ | |\n", + "| fps | 485 |\n", + "| iterations | 17 |\n", + "| time_elapsed | 71 |\n", + "| total_timesteps | 34816 |\n", + "| train/ | |\n", + "| approx_kl | 0.019113595 |\n", + "| 
clip_fraction | 0.209 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.07 |\n", + "| explained_variance | 0.995 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.192 |\n", + "| n_updates | 160 |\n", + "| policy_gradient_loss | -0.024 |\n", + "| std | 0.929 |\n", + "| value_loss | 0.63 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -89.5 |\n", + "| time/ | |\n", + "| fps | 483 |\n", + "| iterations | 18 |\n", + "| time_elapsed | 76 |\n", + "| total_timesteps | 36864 |\n", + "| train/ | |\n", + "| approx_kl | 0.02225456 |\n", + "| clip_fraction | 0.218 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.01 |\n", + "| explained_variance | 0.997 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.176 |\n", + "| n_updates | 170 |\n", + "| policy_gradient_loss | -0.0213 |\n", + "| std | 0.924 |\n", + "| value_loss | 0.488 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 100 |\n", + "| ep_rew_mean | -89.6 |\n", + "| time/ | |\n", + "| fps | 482 |\n", + "| iterations | 19 |\n", + "| time_elapsed | 80 |\n", + "| total_timesteps | 38912 |\n", + "| train/ | |\n", + "| approx_kl | 0.015925717 |\n", + "| clip_fraction | 0.172 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.96 |\n", + "| explained_variance | 0.996 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 0.223 |\n", + "| n_updates | 180 |\n", + "| policy_gradient_loss | -0.0178 |\n", + "| std | 0.915 |\n", + "| value_loss | 0.578 |\n", + "-----------------------------------------\n" + ] + } + ], + "source": [ + "# Initialize the RL environment and start training\n", + "rl_result = optimize_pulses(\n", + " objectives,\n", + " control_parameters,\n", + " tlist,\n", + " algorithm_kwargs,\n", + " optimizer_kwargs\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": 
"b5fa66a4", + "metadata": {}, + "source": [ + "### Analyze the Results" + ] + }, + { + "cell_type": "markdown", + "id": "afdc0129", + "metadata": {}, + "source": [ + "After the training is complete, we can analyze the results obtained by the RL agent. \n", + "In the above window showing the output produced by Gymansium, you can observe how during training the number of steps per episode (ep_len_mean) decreases and the average reward of the episodes (ep_rew_mean) increases." + ] + }, + { + "cell_type": "markdown", + "id": "3172a7d8-5621-4366-a7a1-847aa01974cf", + "metadata": {}, + "source": [ + "We can now see the fields of the `Result` class, this includes the final infidelity, the optimized control parameters and more." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "32521c12-5ff1-4199-a9d2-bc56bfd58b51", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Control Optimization Result\n", + "--------------------------\n", + "- Started at 2024-08-26 18:54:31\n", + "- Number of objectives: 1\n", + "- Final fidelity error: 0.9678450426696734\n", + "- Final parameters: [[-3.1087526679039, -8.416411578655243, 30.0, 17.213022708892822, 14.167898297309875, 30.0], [30.0, -18.649579882621765, 30.0, 4.698044657707214, 8.504285216331482, -13.694652915000916], [30.0, 30.0, -30.0, -30.0, 0.3120405972003937, -30.0], [30.0, -9.982881546020508, -7.644322514533997, -30.0, -21.7467999458313, 22.356192469596863], [-16.941781640052795, 23.70549201965332, -5.7092490792274475, -3.1341876089572906, -30.0, -22.90980041027069], [2.145122140645981, -29.785653948783875, 30.0, 24.86776828765869, -30.0, 30.0], [18.50911259651184, 20.09882926940918, -21.215654611587524, 2.6178836077451706, -10.580383837223053, -30.0], [30.0, 28.613587617874146, 30.0, 1.297985091805458, -11.024004220962524, 17.568976879119873], [-12.576630413532257, 4.3459005653858185, -10.798923969268799, -3.521493226289749, 
24.44423496723175, 9.873266816139221], [3.5593581944704056, 30.0, -30.0, 30.0, -29.55104112625122, 2.953566312789917], [-17.086318731307983, -30.0, 30.0, -30.0, -2.630687803030014, -30.0], [-30.0, -8.527162671089172, 30.0, 3.3770348131656647, -17.610286474227905, 30.0], [21.06366991996765, -30.0, -30.0, 28.026187419891357, -30.0, -30.0], [-0.5829647183418274, -30.0, 0.39471566677093506, -29.881417751312256, -14.624554216861725, 30.0], [3.309006616473198, -20.649999976158142, -27.216011881828308, -15.278343558311462, -7.579786777496338, -30.0], [30.0, -25.959577560424805, -8.394513130187988, -30.0, 19.242825508117676, 30.0], [8.013235330581665, -24.82700228691101, 30.0, 15.088427066802979, 22.54744291305542, -0.40182530879974365], [27.49415874481201, -24.548481702804565, 17.323461771011353, -30.0, -0.5638691782951355, -14.158146679401398], [-1.352214775979519, 30.0, 30.0, 30.0, 2.332618832588196, -11.007753610610962], [10.588710308074951, 7.763637900352478, 8.27732115983963, -30.0, -0.5638763308525085, 12.999211549758911], [-30.0, -0.6728419661521912, 26.35693609714508, 30.0, -30.0, 9.728605449199677], [-13.936073184013367, -30.0, -11.035144329071045, 30.0, -18.027353882789612, -30.0], [30.0, -22.465408444404602, 0.19662916660308838, 2.70523801445961, -30.0, 30.0], [-18.623807430267334, 20.45433282852173, -25.64580202102661, 30.0, -30.0, -30.0], [30.0, -30.0, 22.24388301372528, 8.658326268196106, -9.13745105266571, 30.0], [-30.0, 21.34746015071869, -30.0, 20.93157649040222, -30.0, -11.851279735565186], [30.0, 8.799038529396057, -30.0, 30.0, 30.0, 5.615948438644409], [-18.51969301700592, -4.786216467618942, -30.0, -11.014840006828308, -30.0, -30.0], [1.8015557527542114, 30.0, 25.55101454257965, 30.0, 7.317472994327545, 7.97799825668335], [10.583961009979248, -30.0, 30.0, 18.709394931793213, -30.0, 30.0], [5.546256601810455, 30.0, -30.0, -30.0, -9.384287595748901, -30.0], [2.721581608057022, 30.0, -17.415393590927124, 30.0, -30.0, 7.490333318710327], 
[-22.799490094184875, 30.0, 12.259226739406586, 26.91823661327362, -1.8036991730332375, 1.768445000052452], [-8.539204895496368, 30.0, 21.05818748474121, -8.289629817008972, -30.0, -29.166247844696045], [-28.006927371025085, -10.715399980545044, -20.410330295562744, 7.698541581630707, 20.006970763206482, -30.0], [19.758925437927246, 19.55112934112549, 3.5737384110689163, 30.0, -12.892876267433167, -23.32941770553589], [2.4562114477157593, 30.0, 9.713201522827148, 3.1446851044893265, 16.7449289560318, 20.437912344932556], [-30.0, -30.0, 1.2014302611351013, -7.873741686344147, -30.0, -4.170540869235992], [-21.027299165725708, -30.0, 30.0, 30.0, -1.749400980770588, 28.96402359008789], [-30.0, -30.0, 10.804772078990936, -30.0, -23.41088891029358, -30.0], [30.0, 28.936269879341125, 30.0, 12.618644535541534, 26.150214672088623, 30.0], [-23.830032348632812, 12.627696096897125, 6.235614717006683, 30.0, -30.0, -30.0], [30.0, -14.821178913116455, 28.844250440597534, 4.8175786435604095, -14.676317274570465, 5.136289000511169], [-30.0, 30.0, -3.2001489400863647, 8.161829710006714, 30.0, 30.0], [-11.257586181163788, 30.0, 22.495551109313965, 14.006163775920868, -30.0, -9.58780825138092], [30.0, -3.2841461151838303, -2.8722088783979416, -30.0, -22.224994897842407, -20.68937122821808], [-3.8243257999420166, 30.0, -13.638724386692047, 30.0, -9.664212763309479, -20.168482661247253], [3.2078522443771362, 6.73091322183609, 27.41309881210327, 8.935953676700592, 0.37857480347156525, 1.48747980594635], [5.261540114879608, 23.693813681602478, 23.72292995452881, -30.0, -10.060108602046967, 30.0], [-6.126531958580017, 4.584648460149765, -13.129574060440063, 18.980677127838135, 30.0, -30.0], [30.0, -17.47907280921936, 5.239324271678925, 30.0, 18.152443170547485, 7.376201748847961], [-30.0, -30.0, 10.558403134346008, -11.307235658168793, 30.0, 22.211179733276367], [29.367637038230896, -11.977027952671051, 11.857212781906128, -20.733482837677002, 3.149323984980583, -9.324101507663727], 
[23.380256295204163, -22.276305556297302, 30.0, -9.564227163791656, -30.0, 24.627222418785095], [30.0, 22.601636052131653, -16.67562246322632, -20.180422067642212, -10.760060548782349, -30.0], [-30.0, -10.561510026454926, 8.180200159549713, 9.38538283109665, 20.21921396255493, 20.276060700416565], [30.0, -6.746397614479065, -21.764763593673706, -0.12885242700576782, 11.00729763507843, -30.0], [-9.224560260772705, -8.080736696720123, 30.0, 30.0, -1.0702532529830933, 1.7416155338287354], [5.196931958198547, 9.266873002052307, 30.0, 30.0, -9.557740688323975, 30.0], [16.014232635498047, 15.875131487846375, -7.712865471839905, 26.462870836257935, -10.527227818965912, -30.0], [-20.52076756954193, -30.0, 30.0, 22.811327576637268, 30.0, 20.069124698638916], [-23.37564468383789, -30.0, 24.346659779548645, -22.068958282470703, -9.679112434387207, -7.289157807826996], [-4.739203155040741, -30.0, -17.470232248306274, -16.068751215934753, 22.664607167243958, 17.942445874214172], [-10.336931347846985, -19.98522698879242, -24.442450404167175, -25.539121627807617, -30.0, -30.0], [-5.9700992703437805, -30.0, 22.532195448875427, 15.415359735488892, -25.71263015270233, 30.0], [30.0, -4.535705745220184, 30.0, 21.750606894493103, -30.0, -30.0], [-26.964715719223022, -15.222809314727783, 22.532747983932495, 30.0, -28.782237768173218, 30.0], [30.0, -3.271692842245102, 7.541305124759674, 23.899844884872437, -30.0, -19.538832306861877], [30.0, 11.78169071674347, 9.983564615249634, 17.42326319217682, 16.275426149368286, 22.99768567085266], [-10.316378474235535, -30.0, -12.636566162109375, 28.725385665893555, -21.886146068572998, -30.0], [-30.0, 25.093143582344055, -29.264066219329834, 24.33700919151306, -30.0, 30.0], [1.1465048789978027, 0.2423576731234789, -8.476797938346863, 30.0, -30.0, -20.919764041900635], [-14.426280856132507, 28.92910659313202, -30.0, -30.0, 15.197673439979553, 3.501485288143158], [24.403620958328247, -12.611689567565918, 20.203712582588196, 25.027851462364197, 
-5.033597052097321, -14.513703882694244], [-11.130014955997467, -21.518633365631104, -3.0117693543434143, 9.756401181221008, 30.0, 30.0], [29.35376286506653, 25.172159671783447, 8.696078360080719, -12.456485331058502, -14.965754449367523, -30.0], [30.0, -6.100337952375412, 2.966141849756241, -24.583967328071594, -3.6252542585134506, 2.9131007194519043], [9.153912663459778, -13.541809916496277, -6.093204617500305, -16.800130605697632, 10.854100584983826, 20.1009339094162], [-26.10504984855652, 26.41999840736389, 18.10444414615631, 9.271708130836487, -14.027472138404846, 15.704569816589355], [-0.4124414920806885, -30.0, 24.955875277519226, -18.137421011924744, 30.0, -15.995977520942688], [30.0, -5.882952064275742, -16.29595935344696, 6.000301837921143, 29.16996717453003, -30.0], [2.2801566123962402, 8.071172833442688, 2.7869027853012085, -21.832353472709656, -24.894644021987915, 24.334484338760376], [21.568551063537598, 28.29891085624695, -1.2894871830940247, -30.0, 25.257071256637573, -19.727632999420166], [2.6807785034179688, -26.929922103881836, -30.0, 0.5336201190948486, 28.995829224586487, 30.0], [17.6004695892334, 30.0, 14.325123131275177, 6.6842275857925415, 6.812509596347809, -29.09172534942627], [30.0, 22.988394498825073, 29.93071138858795, -30.0, -24.49855327606201, 30.0], [12.698577046394348, -27.501500844955444, 9.532395601272583, -26.67670726776123, 30.0, -30.0], [30.0, 8.755772709846497, 27.297150492668152, -17.473171949386597, -19.227054119110107, 19.75345730781555], [-16.802916526794434, 30.0, 23.78861367702484, -30.0, 7.9971760511398315, -13.027499914169312], [29.833139777183533, -5.583303272724152, 6.027087718248367, 30.0, 18.115296363830566, 30.0], [-25.46385169029236, 5.029732435941696, -30.0, 0.0949627161026001, 3.2119931280612946, -30.0], [-27.962322235107422, 30.0, 5.715133398771286, -2.7688899636268616, -12.89980262517929, 30.0], [-19.072067141532898, 14.59810495376587, 20.751537680625916, -12.960581481456757, -18.418128490447998, 
-25.064170360565186], [27.535747289657593, -2.6648537814617157, 30.0, 30.0, -23.544312715530396, -8.650515675544739], [30.0, -6.306121498346329, -30.0, 8.04179459810257, -14.27808701992035, 29.375606775283813], [-13.90311062335968, 22.851226329803467, -30.0, 6.758987009525299, 18.53027880191803, -24.978879690170288], [28.30285906791687, -30.0, -30.0, 30.0, 22.073732614517212, 12.210258543491364], [-30.0, -25.621742606163025, 7.907553613185883, -2.5500932335853577, -30.0, 0.6186537444591522], [-30.0, 14.853196442127228, -3.790283203125, -2.2404947876930237, 11.038969159126282, 5.507649779319763], [-10.339385569095612, 30.0, 18.92696499824524, 15.41258454322815, -30.0, -23.514423966407776], 83.0]\n", + "- Number of iterations: 400\n", + "- Reason for termination: Reached 400 episodes, stopping training.\n", + "- Optimized time parameter: 83.0\n", + "- Ended at 2024-08-26 18:55:54 (83.0s)\n" + ] + } + ], + "source": [ + "print(rl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0742940e-889c-4bac-bc5b-61ed07a78c46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'hinton')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# We can show the hinton matrix\n", + "fig, ax = qt.hinton(rl_result._final_states[0])\n", + "ax.set_title('hinton')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5802870-966c-4dd6-9fdc-df971cbfb0c7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}