Wandb #1

Merged: 6 commits, Nov 3, 2024
157 changes: 157 additions & 0 deletions make_dmc.py
@@ -0,0 +1,157 @@
#adapted from https://github.com/imgeorgiev/dmc2gymnasium

import logging
import numpy as np

from dm_control import suite
from dm_env import specs
from gymnasium.core import Env
from gymnasium.spaces import Box
from gymnasium import spaces
from gymnasium.wrappers import FlattenObservation, RescaleAction

def _spec_to_box(spec, dtype=np.float32):
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = int(np.prod(s.shape))
        if type(s) == specs.Array:
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        elif type(s) == specs.BoundedArray:
            zeros = np.zeros(dim, dtype=np.float32)
            return s.minimum + zeros, s.maximum + zeros
        else:
            raise TypeError(f"Unrecognized spec type: {type(s)}")

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0).astype(dtype)
    high = np.concatenate(maxs, axis=0).astype(dtype)
    assert low.shape == high.shape
    return Box(low, high, dtype=dtype)


def _flatten_obs(obs, dtype=np.float32):
    # Concatenate every entry of the DMC observation dict into a single flat vector.
    obs_pieces = []
    for v in obs.values():
        flat = np.array([v]) if np.isscalar(v) else v.ravel()
        obs_pieces.append(flat)
    return np.concatenate(obs_pieces, axis=0).astype(dtype)


class DMCGym(Env):
    def __init__(
        self,
        env_name,
        task_kwargs=None,
        environment_kwargs=None,
        # rendering="egl",
        render_height=64,
        render_width=64,
        render_camera_id=0,
        action_repeat=1,
    ):
        # Avoid mutable default arguments.
        task_kwargs = task_kwargs or {}
        environment_kwargs = environment_kwargs or {}
        domain, task = env_name.split("-", 1)
        self._env = suite.load(
            domain,
            task,
            task_kwargs,
            environment_kwargs,
        )

        # placeholder to allow built-in gymnasium rendering
        self.render_mode = "rgb_array"
        self.render_height = render_height
        self.render_width = render_width
        self.render_camera_id = render_camera_id

        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32,
        )

        self._observation_space = _spec_to_box(self._env.observation_spec().values())
        self._action_space = _spec_to_box([self._env.action_spec()])
        self.action_repeat = action_repeat

        # set seed if provided with task_kwargs
        if "random" in task_kwargs:
            seed = task_kwargs["random"]
            self._observation_space.seed(seed)
            self._action_space.seed(seed)

    def __getattr__(self, name):
        """Added so that attributes of the underlying dm_control env can be accessed directly."""
        return getattr(self._env, name)

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def action_space(self):
        return self._action_space

    @property
    def reward_range(self):
        """DMC always has a per-step reward range of (0, 1)."""
        return 0, 1

    def _convert_action(self, action):
        # Affine map from the normalized [-1, 1] action space to the true DMC bounds:
        # first rescale to [0, 1], then stretch and shift to [low, high].
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    def step(self, action):
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        # Guard against float32 rounding pushing the action just outside the true bounds.
        action = np.clip(action, self._true_action_space.low, self._true_action_space.high)
        assert self._true_action_space.contains(action)
        reward = 0
        info = {}
        for _ in range(self.action_repeat):
            timestep = self._env.step(action)
            observation = _flatten_obs(timestep.observation)
            reward += timestep.reward
            termination = False  # we never reach a goal
            truncation = timestep.last()
            if truncation:
                return observation, reward, termination, truncation, info
        return observation, reward, termination, truncation, info

    def reset(self, seed=None, options=None):
        if seed is not None:
            if not isinstance(seed, np.random.RandomState):
                seed = np.random.RandomState(seed)
            self._env.task._random = seed
        if options:
            logging.warning("Currently doing nothing with options={:}".format(options))
        timestep = self._env.reset()
        observation = _flatten_obs(timestep.observation)
        info = {}
        return observation, info

    def render(self, height=None, width=None, camera_id=None):
        height = height or self.render_height
        width = width or self.render_width
        camera_id = camera_id or self.render_camera_id
        return self._env.physics.render(height=height, width=width, camera_id=camera_id)


def make_env_dmc(env_name: str, action_repeat: int = 1) -> Env:
    env = DMCGym(env_name=env_name, action_repeat=action_repeat)
    env = RescaleAction(env, -1.0, 1.0)
    env = FlattenObservation(env)
    return env
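
if __name__ == "__main__":
    # Smoke-test sketch (illustrative only, not part of this PR): build a DMC env
    # from a "domain-task" name and run a single transition. The task name
    # "cheetah-run" and the seed are arbitrary choices.
    env = make_env_dmc("cheetah-run", action_repeat=2)
    obs, info = env.reset(seed=0)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    print(obs.shape, reward, terminated, truncated)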


3 changes: 3 additions & 0 deletions sbx/__init__.py
@@ -7,6 +7,8 @@
from sbx.sac import SAC
from sbx.td3 import TD3
from sbx.tqc import TQC
from sbx.bro import BRO


# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -30,4 +32,5 @@ def DroQ(*args, **kwargs):
"SAC",
"TD3",
"TQC",
"BRO",
]
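
With BRO exported from sbx, a minimal end-to-end sketch could look like the following (assumptions: BRO follows the same SB3-style constructor and learn() API as the other sbx algorithms, the DMC helper from make_dmc.py above is importable, and the task name and step count are arbitrary):

from sbx import BRO
from make_dmc import make_env_dmc

env = make_env_dmc("cheetah-run", action_repeat=2)
model = BRO("MlpPolicy", env)
model.learn(total_timesteps=10_000)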
133 changes: 124 additions & 9 deletions sbx/bro/bro.py
@@ -19,15 +19,15 @@
from sbx.bro.policies import BROPolicy



class EntropyCoef(nn.Module):
    ent_coef_init: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)))
        return jnp.exp(log_ent_coef)



class ConstantEntropyCoef(nn.Module):
    ent_coef_init: float = 1.0

@@ -38,7 +38,27 @@ def __call__(self) -> float:
        self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init))
        return self.ent_coef_init


@jax.jit
def _get_stats(
    actor_state: RLTrainState,
    qf_state: RLTrainState,
    ent_coef_state: TrainState,
    observations: jax.Array,
    key: jax.Array,
):
    # Diagnostic quantities for logging: mean Q-value, mean |action|,
    # entropy coefficient and policy entropy on a batch of observations.
    key, dropout_key, noise_key = jax.random.split(key, 3)
    dist = actor_state.apply_fn(actor_state.params, observations)
    actor_actions = dist.sample(seed=noise_key)
    log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
    qf_pi = qf_state.apply_fn(
        qf_state.params,
        observations,
        actor_actions,
        rngs={"dropout": dropout_key},
    )
    ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
    return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean()

class BRO(OffPolicyAlgorithmJax):
    policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": BROPolicy,
@@ -109,18 +129,75 @@ def __init__(
        self.policy_delay = policy_delay
        self.ent_coef_init = ent_coef
        self.target_entropy = target_entropy
        self.init_key = jax.random.PRNGKey(seed)

        self.n_quantiles = n_quantiles
        # Quantile midpoints of n_quantiles equal bins in [0, 1],
        # e.g. n_quantiles=2 gives taus [0.25, 0.75].
        taus_ = jnp.arange(0, n_quantiles + 1) / n_quantiles
        self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None]

        self.distributional = self.n_quantiles > 1
        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        super()._setup_model()

        if not hasattr(self, "policy") or self.policy is None:
            self.policy = self.policy_class(  # type: ignore[assignment]
                self.observation_space,
                self.action_space,
                self.lr_schedule,
                self.n_quantiles,
                **self.policy_kwargs,
            )

        assert isinstance(self.qf_learning_rate, float)
        self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate)

        self.key, ent_key = jax.random.split(self.key, 2)

        self.actor = self.policy.actor  # type: ignore[assignment]
        self.qf = self.policy.qf  # type: ignore[assignment]

        # The entropy coefficient or entropy can be learned automatically
        # see Automating Entropy Adjustment for Maximum Entropy RL section
        # of https://arxiv.org/abs/1812.05905
        if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"):
            # Default initial value of ent_coef when learned
            ent_coef_init = 1.0
            if "_" in self.ent_coef_init:
                ent_coef_init = float(self.ent_coef_init.split("_")[1])
            assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"

            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
            self.ent_coef = EntropyCoef(ent_coef_init)
        else:
            # This will throw an error if a malformed string (different from 'auto') is passed
            assert isinstance(
                self.ent_coef_init, float
            ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}"
            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)  # type: ignore[assignment]

        self.ent_coef_state = TrainState.create(
            apply_fn=self.ent_coef.apply,
            params=self.ent_coef.init(ent_key)["params"],
            tx=optax.adam(
                learning_rate=self.learning_rate, b1=0.5
            ),
        )

        # Target entropy is used when learning the entropy coefficient
        if self.target_entropy == "auto":
            # automatically set target entropy if needed
            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2  # type: ignore
        else:
            # Force conversion
            # this will also throw an error for unexpected string
            self.target_entropy = float(self.target_entropy)

    def reset(self):
        if not hasattr(self, "policy") or self.policy is None:
            self.policy = self.policy_class(  # type: ignore[assignment]
                self.observation_space,
@@ -132,7 +209,7 @@ def _setup_model(self) -> None:

        assert isinstance(self.qf_learning_rate, float)

        self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
        self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate)

        self.key, ent_key = jax.random.split(self.key, 2)

@@ -163,7 +240,7 @@ def _setup_model(self) -> None:
            apply_fn=self.ent_coef.apply,
            params=self.ent_coef.init(ent_key)["params"],
            tx=optax.adam(
                learning_rate=self.learning_rate,
                learning_rate=self.learning_rate, b1=0.5
            ),
        )

@@ -238,10 +315,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
            self.key,
        )
        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/actor_loss", actor_loss_value.item())
        self.logger.record("train/critic_loss", qf_loss_value.item())
        self.logger.record("train/ent_coef", ent_coef_value.item())
        return {
            'actor_loss': actor_loss_value.item(),
            'critic_loss': qf_loss_value.item(),
            'ent_loss': ent_coef_value.item(),
        }
        # self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        # self.logger.record("train/actor_loss", actor_loss_value.item())
        # self.logger.record("train/critic_loss", qf_loss_value.item())
        # self.logger.record("train/ent_coef", ent_coef_value.item())

    @staticmethod
    @jax.jit
@@ -425,7 +507,40 @@ def update_actor_and_temperature(
        )
        ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy)
        return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key

    def get_stats(self, batch_size):
        data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

        if isinstance(data.observations, dict):
            keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
            obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
            next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
        else:
            obs = data.observations.numpy()
            next_obs = data.next_observations.numpy()

        # Convert to numpy
        data = ReplayBufferSamplesNp(  # type: ignore[assignment]
            obs,
            data.actions.numpy(),
            next_obs,
            data.dones.numpy().flatten(),
            data.rewards.numpy().flatten(),
        )
        q, a, temp, ent = _get_stats(
            self.policy.actor_state,
            self.policy.qf_state,
            self.ent_coef_state,
            obs,
            self.key,
        )

        return {
            'q': q.mean().item(),
            'a': a.item(),
            'temp': temp.item(),
            'entropy': ent.item()}
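
    # Usage note (a sketch, not shown in this diff): get_stats() complements the
    # loss dict now returned by train(); a caller could log both externally, e.g.
    # wandb.log(model.get_stats(batch_size=256)), in line with this PR's Wandb focus.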

    @classmethod
    @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"])
    def _train(