Remove MuJoCo dependency from SQIL notebook #800

Merged on Oct 5, 2023 (3 commits)
Changes from 1 commit
9 changes: 4 additions & 5 deletions docs/tutorials/8a_train_sqil_sac.ipynb
@@ -7,14 +7,14 @@
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n",
"# Train an Agent using Soft Q Imitation Learning with SAC\n",
"\n",
-"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a HalfCheetah agent using SQIL + SAC."
+"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
-"First, we need some expert trajectories in our environment (`seals/HalfCheetah-v0`).\n",
+"First, we need some expert trajectories in our environment (`Pendulum-v1`).\n",
"Note that you can use other environments, but the action space must be continuous."
]
},
@@ -28,7 +28,7 @@
"from imitation.data import huggingface_utils\n",
"\n",
"# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
-"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-HalfCheetah-v0\")\n",
+"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n",
"\n",
"# Convert the dataset to a format usable by the imitation library.\n",
"expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
@@ -75,12 +75,11 @@
"from imitation.util.util import make_vec_env\n",
"import numpy as np\n",
"from stable_baselines3 import sac\n",
-"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"\n",
"SEED = 42\n",
"\n",
"venv = make_vec_env(\n",
-"    \"seals/HalfCheetah-v0\",\n",
+"    \"Pendulum-v1\",\n",
"    rng=np.random.default_rng(seed=SEED),\n",
")\n",
"\n",
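A quick way to check that a change like this leaves no `seals` or HalfCheetah references behind is to scan the notebook JSON for them. The following is a minimal sketch using only the standard library. The helper name `find_stale_references`, the marker strings, and the tiny in-memory notebook are hypothetical and chosen for illustration; none of them are part of `imitation` or of this PR.

```python
import json

# Hypothetical marker strings: anything that would hint at a leftover
# MuJoCo / seals dependency in the tutorial.
STALE_MARKERS = ("seals", "HalfCheetah")


def find_stale_references(notebook_json, markers=STALE_MARKERS):
    """Return (cell_index, line) pairs for source lines containing a marker."""
    notebook = json.loads(notebook_json)
    hits = []
    for index, cell in enumerate(notebook.get("cells", [])):
        source = cell.get("source", [])
        # A cell's "source" may be a single string or a list of strings.
        if isinstance(source, str):
            source = source.splitlines(keepends=True)
        for line in source:
            if any(marker in line for marker in markers):
                hits.append((index, line.rstrip("\n")))
    return hits


# Illustrative notebook: the environment was switched to Pendulum-v1,
# but the now-unneeded seals import was (deliberately) left in.
example = json.dumps(
    {
        "cells": [
            {
                "cell_type": "markdown",
                "source": ["Expert trajectories for `Pendulum-v1`.\n"],
            },
            {
                "cell_type": "code",
                "source": [
                    "import seals  # noqa: F401\n",
                    'venv = make_vec_env("Pendulum-v1")\n',
                ],
            },
        ]
    }
)

hits = find_stale_references(example)
print(hits)  # [(1, 'import seals  # noqa: F401')]
```

To run this against the real file, read `docs/tutorials/8a_train_sqil_sac.ipynb` as text and pass its contents to `find_stale_references`; an empty list suggests the notebook no longer mentions the old environment.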