diff --git a/docs/tutorials/8a_train_sqil_sac.ipynb b/docs/tutorials/8a_train_sqil_sac.ipynb
index 942f8a290..7eadebaa5 100644
--- a/docs/tutorials/8a_train_sqil_sac.ipynb
+++ b/docs/tutorials/8a_train_sqil_sac.ipynb
@@ -7,14 +7,14 @@
     "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n",
     "# Train an Agent using Soft Q Imitation Learning with SAC\n",
     "\n",
-    "In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a HalfCheetah agent using SQIL + SAC."
+    "In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC."
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "First, we need some expert trajectories in our environment (`seals/HalfCheetah-v0`).\n",
+    "First, we need some expert trajectories in our environment (`Pendulum-v1`).\n",
     "Note that you can use other environments, but the action space must be continuous."
    ]
   },
@@ -28,7 +28,7 @@
     "from imitation.data import huggingface_utils\n",
     "\n",
     "# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
-    "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-HalfCheetah-v0\")\n",
+    "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n",
     "\n",
     "# Convert the dataset to a format usable by the imitation library.\n",
     "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
@@ -75,12 +75,11 @@
     "from imitation.util.util import make_vec_env\n",
     "import numpy as np\n",
     "from stable_baselines3 import sac\n",
-    "import seals  # noqa: F401  # needed to load \"seals/\" environments\n",
     "\n",
     "SEED = 42\n",
     "\n",
     "venv = make_vec_env(\n",
-    "    \"seals/HalfCheetah-v1\",\n",
+    "    \"Pendulum-v1\",\n",
     "    rng=np.random.default_rng(seed=SEED),\n",
     ")\n",
     "\n",