Skip to content

Commit

Permalink
Merge pull request #109 from LucasAlegre/chore/update-gymnasium-1.0
Browse files Browse the repository at this point in the history
Migration to gymnasium 1.0
  • Loading branch information
ffelten authored Oct 16, 2024
2 parents 0d9b21e + fb05165 commit 76147d3
Show file tree
Hide file tree
Showing 40 changed files with 149 additions and 92 deletions.
36 changes: 10 additions & 26 deletions .github/workflows/build-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
#
# derived from https://github.com/Farama-Foundation/PettingZoo/blob/e230f4d80a5df3baf9bd905149f6d4e8ce22be31/.github/workflows/build-publish.yml
name: build-publish
name: Build artifact for PyPI

on:
push:
Expand All @@ -16,35 +16,18 @@ on:

jobs:
build-wheels:
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- os: ubuntu-latest
python: 38
platform: manylinux_x86_64
- os: ubuntu-latest
python: 39
platform: manylinux_x86_64
- os: ubuntu-latest
python: 310
platform: manylinux_x86_64
- os: ubuntu-latest
python: 311
platform: manylinux_x86_64
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- uses: actions/checkout@v4
- uses: actions/setup-python@v5

- name: Install dependencies
run: python -m pip install --upgrade pip setuptools build
run: pipx install build
- name: Build sdist and wheels
run: python -m build
run: pyproject-build
- name: Store wheels
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
path: dist

Expand All @@ -55,10 +38,11 @@ jobs:
if: github.event_name == 'release' && github.event.action == 'published'
steps:
- name: Download dists
uses: actions/download-artifact@v4.1.7
uses: actions/download-artifact@v4
with:
name: artifact
path: dist

- name: Publish
uses: pypa/gh-action-pypi-publish@release/v1
with:
Expand Down
10 changes: 4 additions & 6 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- run: python -m pip install pre-commit
- run: python -m pre_commit --version
- run: python -m pre_commit install
- run: python -m pre_commit run --all-files
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- run: pipx install pre-commit
- run: pre-commit run --all-files
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -18,13 +18,13 @@ repos:
- id: detect-private-key
- id: debug-statements
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.3.0
hooks:
- id: codespell
args:
- --ignore-words-list=reacher,ure,referenc,wile,mor,ser,esr,nowe
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.1.1
hooks:
- id: flake8
args:
Expand All @@ -35,16 +35,16 @@ repos:
- --show-source
- --statistics
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.18.0
hooks:
- id: pyupgrade
args: ["--py37-plus"]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/python/black
rev: 23.1.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle
Expand Down
2 changes: 1 addition & 1 deletion examples/envelope_minecart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.multi_policy.envelope.envelope import Envelope

Expand Down
2 changes: 1 addition & 1 deletion examples/eupg_fishwood.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import mo_gymnasium as mo_gym
import numpy as np
import torch as th
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import eval_mo_reward_conditioned
from morl_baselines.single_policy.esr.eupg import EUPG
Expand Down
2 changes: 1 addition & 1 deletion examples/mo_q_learning_DST.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import eval_mo
from morl_baselines.common.scalarization import tchebicheff
Expand Down
2 changes: 1 addition & 1 deletion examples/mp_mo_q_learning_DST.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.scalarization import tchebicheff
from morl_baselines.multi_policy.multi_policy_moqlearning.mp_mo_q_learning import (
Expand Down
2 changes: 1 addition & 1 deletion examples/pcn_minecart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.multi_policy.pcn.pcn import PCN

Expand Down
2 changes: 1 addition & 1 deletion examples/pgmorl_halfcheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
algo.train(
total_timesteps=int(5e6),
eval_env=make_env(env_id, 42, 0, "PGMORL_eval_env", gamma=0.995)(),
ref_point=np.array([0.0, -5.0]),
ref_point=np.array([-100.0, -100.0]),
known_pareto_front=None,
)
env = make_env(env_id, 422, 1, "PGMORL_test", gamma=0.995)() # idx != 0 to avoid taking videos
Expand Down
13 changes: 7 additions & 6 deletions experiments/benchmark/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import numpy as np
import requests
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FlattenObservation
from gymnasium.wrappers.record_video import RecordVideo
from mo_gymnasium.utils import MORecordEpisodeStatistics
from gymnasium.wrappers import FlattenObservation, RecordVideo
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import seed_everything
from morl_baselines.common.experiments import (
Expand Down Expand Up @@ -90,13 +89,15 @@ def autotag() -> str:
git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
try:
# try finding the pull request number on github
prs = requests.get(f"https://api.github.com/search/issues?q=repo:LucasAlegre/morl-baselines+is:pr+{git_commit}")
prs = requests.get(
f"https://api.github.com/search/issues?q=repo:LucasAlegre/morl-baselines+is:pr+{git_commit}" # noqa
)
if prs.status_code == 200:
prs = prs.json()
if len(prs["items"]) > 0:
pr = prs["items"][0]
pr_number = pr["number"]
wandb_tag += f",pr-{pr_number}"
wandb_tag += f",pr-{pr_number}" # noqa
print(f"identified github pull request: {pr_number}")
except Exception as e:
print(e)
Expand Down Expand Up @@ -165,7 +166,7 @@ def wrap_mario(env):
TimeLimit,
)
from mo_gymnasium.envs.mario.joypad_space import JoypadSpace
from mo_gymnasium.utils import MOMaxAndSkipObservation
from mo_gymnasium.wrappers import MOMaxAndSkipObservation

env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = MOMaxAndSkipObservation(env, skip=4)
Expand Down
2 changes: 1 addition & 1 deletion experiments/hyperparameter_search/launch_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import wandb
import yaml
from mo_gymnasium.utils import MORecordEpisodeStatistics
from mo_gymnasium.wrappers import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import seed_everything
from morl_baselines.common.experiments import (
Expand Down
3 changes: 1 addition & 2 deletions morl_baselines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""MORL-Baselines contains various MORL algorithms and utility functions."""


__version__ = "1.0.0"
__version__ = "1.1.0"
1 change: 1 addition & 0 deletions morl_baselines/common/buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Replay buffer for multi-objective reinforcement learning."""

import numpy as np
import torch as th

Expand Down
8 changes: 6 additions & 2 deletions morl_baselines/common/diverse_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Diverse Experience Replay Buffer. Code extracted from https://github.com/axelabels/DynMORL."""

from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -154,7 +155,7 @@ def update(self, idx: int, p, tree_id=None):
Keyword Arguments:
tree_id {object} -- Tree to be updated (default: {None})
"""
if type(p) == dict:
if isinstance(p, dict):
for k in p:
self.update(idx, p[k], k)
return
Expand Down Expand Up @@ -476,7 +477,10 @@ def get_data(self, include_indices: bool = False):
Returns:
The data
"""
all_data = list(np.arange(self.capacity) + self.capacity - 1), list(self.tree.data)
all_data = (
list(np.arange(self.capacity) + self.capacity - 1),
list(self.tree.data),
)
indices = []
data = []
for i, d in zip(all_data[0], all_data[1]):
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities related to evaluation."""

import os
import random
from typing import List, Optional, Tuple
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common experiment utilities."""

import argparse

from morl_baselines.multi_policy.capql.capql import CAPQL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Probabilistic ensemble of neural networks."""

import os

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions morl_baselines/common/model_based/tabular_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tabular dynamics model S_{t+1}, R_t ~ m(.,.|s,a) ."""

import random

import numpy as np
Expand Down
53 changes: 44 additions & 9 deletions morl_baselines/common/model_based/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility functions for the model."""

from typing import Tuple

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -34,7 +35,7 @@ def termination_fn_dst(obs, act, next_obs):


def termination_fn_mountaincar(obs, act, next_obs):
"""Termination function of mountin car."""
"""Termination function of mountain car."""
assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
position = next_obs[:, 0]
velocity = next_obs[:, 1]
Expand Down Expand Up @@ -147,16 +148,29 @@ def step(
var_obs = var_obs[0]
var_rewards = var_rewards[0]

info = {"uncertainty": uncertainties, "var_obs": var_obs, "var_rewards": var_rewards}
info = {
"uncertainty": uncertainties,
"var_obs": var_obs,
"var_rewards": var_rewards,
}

# info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev}
return next_obs, rewards, terminals, info


def visualize_eval(
agent, env, model=None, w=None, horizon=10, init_obs=None, compound=True, deterministic=False, show=False, filename=None
agent,
env,
model=None,
w=None,
horizon=10,
init_obs=None,
compound=True,
deterministic=False,
show=False,
filename=None,
):
"""Generates a plot of the evolution of the state, reward and model predicitions ove time.
"""Generates a plot of the evolution of the state, reward and model predictions over time.
Args:
agent: agent to be evaluated
Expand Down Expand Up @@ -213,10 +227,16 @@ def visualize_eval(
acts = F.one_hot(acts, num_classes=env.action_space.n).squeeze(1)
for step in range(len(real_obs)):
if compound or step == 0:
obs, r, done, info = model_env.step(th.tensor(obs).to(agent.device), acts[step], deterministic=deterministic)
obs, r, done, info = model_env.step(
th.tensor(obs).to(agent.device),
acts[step],
deterministic=deterministic,
)
else:
obs, r, done, info = model_env.step(
th.tensor(real_obs[step - 1]).to(agent.device), acts[step], deterministic=deterministic
th.tensor(real_obs[step - 1]).to(agent.device),
acts[step],
deterministic=deterministic,
)
model_obs.append(obs.copy())
model_obs_stds.append(np.sqrt(info["var_obs"].copy()))
Expand All @@ -240,11 +260,26 @@ def visualize_eval(
axs[i].set_ylabel(f"Reward {i - obs_dim}")
axs[i].grid(alpha=0.25)
if w is not None:
axs[i].plot(x, [real_vec_rewards[step][i - obs_dim] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_vec_rewards[step][i - obs_dim] for step in x],
label="Environment",
color="black",
)
else:
axs[i].plot(x, [real_rewards[step] for step in x], label="Environment", color="black")
axs[i].plot(
x,
[real_rewards[step] for step in x],
label="Environment",
color="black",
)
if model is not None:
axs[i].plot(x, [model_rewards[step][i - obs_dim] for step in x], label="Model", color="blue")
axs[i].plot(
x,
[model_rewards[step][i - obs_dim] for step in x],
label="Model",
color="blue",
)
axs[i].fill_between(
x,
[model_rewards[step][i - obs_dim] + model_rewards_stds[step][i - obs_dim] for step in x],
Expand Down
3 changes: 2 additions & 1 deletion morl_baselines/common/morl_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""MORL algorithm base classes."""

import os
import time
from abc import ABC, abstractmethod
Expand All @@ -11,7 +12,7 @@
import torch.nn
import wandb
from gymnasium import spaces
from mo_gymnasium.utils import MOSyncVectorEnv
from mo_gymnasium.wrappers.vector import MOSyncVectorEnv

from morl_baselines.common.evaluation import (
eval_mo_reward_conditioned,
Expand Down
Loading

0 comments on commit 76147d3

Please # to comment.