Refactor SAC policy with performance optimizations and multi-camera support

- Introduced Ensemble and CriticHead classes for more efficient critic network handling (a minimal sketch of the idea follows the changed-files summary below)
- Added support for multiple camera inputs in the observation encoder
- Optimized image encoding by batching all camera images into a single forward pass
- Updated the ManiSkill environment configuration with a reduced image size and rescaled action bounds
- Compiled the critic networks for improved performance
- Simplified normalization and ensemble handling in the critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
AdilZouitine committed Feb 20, 2025
1 parent ff47c0b commit ff82367
Showing 4 changed files with 155 additions and 95 deletions.
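For context, the core change vectorizes the critic ensemble: instead of looping over separate MLP critics, the parameters of all heads are stacked with tensordict's from_modules and evaluated in a single torch.vmap call. Below is a minimal sketch of that pattern, assuming torch >= 2.0 and the tensordict package; the module sizes and names are illustrative, not the ones used by the actual policy.

import torch
import torch.nn as nn
from tensordict import from_modules

# Five toy critic heads with identical architecture (sizes are illustrative).
critics = [nn.Sequential(nn.Linear(8, 32), nn.SiLU(), nn.Linear(32, 1)) for _ in range(5)]

# Stack all per-critic parameters into one structure so a single vectorized
# call replaces a Python loop over ensemble members.
params = from_modules(*critics, as_module=True)
base = critics[0]  # provides the architecture; its own weights are swapped out below

def call_one(p, x):
    # Temporarily bind one critic's parameters to the shared module and run it.
    with p.to_module(base):
        return base(x)

obs_action = torch.randn(16, 8)
q_values = torch.vmap(call_one, (0, None), randomness="different")(params, obs_action)
print(q_values.shape)  # torch.Size([5, 16, 1])

# The commit additionally wraps the resulting ensemble in torch.compile.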
200 changes: 128 additions & 72 deletions lerobot/common/policies/sac/modeling_sac.py
@@ -17,10 +17,12 @@

# TODO: (1) better device management

from copy import deepcopy
from typing import Callable, Optional, Tuple

import einops
import numpy as np
from tensordict import from_modules
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
@@ -85,9 +87,9 @@ def __init__(

self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
network_list=nn.ModuleList(
ensemble=Ensemble(
[
MLP(
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
@@ -99,9 +101,9 @@

self.critic_target = CriticEnsemble(
encoder=encoder_critic,
network_list=nn.ModuleList(
ensemble=Ensemble(
[
MLP(
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
@@ -113,6 +115,9 @@

self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)

self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
@@ -274,6 +279,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)


class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))


class CriticEnsemble(nn.Module):
"""
┌──────────────────┬─────────────────────────────────────────────────────────┐
@@ -316,45 +350,21 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network_list: nn.ModuleList,
ensemble: "Ensemble[CriticHead]",
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.network_list = network_list
self.ensemble = ensemble
self.init_final = init_final
self.output_normalization = output_normalization

self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)

self.parameters_to_optimize += list(self.network_list.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break

# Output layer
self.output_layers = []
if init_final is not None:
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
else:
self.output_layers = []
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)

self.output_layers = nn.ModuleList(self.output_layers)
self.parameters_to_optimize += list(self.output_layers.parameters())
self.parameters_to_optimize += list(self.ensemble.parameters())

def forward(
self,
@@ -373,12 +383,8 @@ def forward(
obs_enc = observations if self.encoder is None else self.encoder(observations)

inputs = torch.cat([obs_enc, actions], dim=-1)
list_q_values = []
for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
x = network(inputs)
value = output_layer(x)
list_q_values.append(value.squeeze(-1))
return torch.stack(list_q_values)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
return q_values.squeeze(-1) # [num_critics, B]


class Policy(nn.Module):
@@ -510,6 +516,7 @@ def __init__(self, config: SACConfig, input_normalizer: nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
@@ -546,14 +553,13 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Concatenate all images along the channel dimension.
image_keys = [k for k in obs_dict if k.startswith("observation.image")]
for image_key in image_keys:
enc_feat = self.image_enc_layers(obs_dict[image_key])

# if not self.has_pretrained_vision_encoder:
# enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
feat.append(enc_feat)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)

if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
@@ -671,6 +677,34 @@ def forward(self, x):
return x


class Ensemble(nn.Module):
"""
Vectorized ensemble of modules.
"""

def __init__(self, modules, **kwargs):
super().__init__()
# combine_state_for_ensemble causes graph breaks
self.params = from_modules(*modules, as_module=True)
with self.params[0].data.to("meta").to_module(modules[0]):
self.module = deepcopy(modules[0])
self._repr = str(modules[0])
self._n = len(modules)

def __len__(self):
return self._n

def _call(self, params, *args, **kwargs):
with params.to_module(self.module):
return self.module(*args, **kwargs)

def forward(self, *args, **kwargs):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)

def __repr__(self):
return f"Vectorized {len(self)}x " + self._repr


# TODO (azouitine): I think in our case this function is not useful; we should remove it
# after some investigation
# borrowed from tdmpc
@@ -711,46 +745,68 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:

config = SACConfig()
config.num_critics = 10
encoder = SACObservationEncoder(config)
actor_encoder = SACObservationEncoder(config)
encoder = torch.compile(encoder)
config.vision_encoder_name = None
encoder = SACObservationEncoder(config, nn.Identity())
# actor_encoder = SACObservationEncoder(config)
# encoder = torch.compile(encoder)
critic_ensemble = CriticEnsemble(
encoder=encoder,
network_list=nn.ModuleList(
ensemble=Ensemble(
[
MLP(
CriticHead(
input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
output_normalization=nn.Identity(),
)
actor = Policy(
encoder=actor_encoder,
network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
encoder = encoder.to("cuda:0")
critic_ensemble = torch.compile(critic_ensemble)
# actor = Policy(
# encoder=actor_encoder,
# network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
# action_dim=config.output_shapes["action"][0],
# encoder_is_shared=config.shared_encoder,
# **config.policy_kwargs,
# )
# encoder = encoder.to("cuda:0")
# critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0")
actor = torch.compile(actor)
actor = actor.to("cuda:0")
# actor = torch.compile(actor)
# actor = actor.to("cuda:0")
obs_dict = {
"observation.image": torch.randn(1, 3, 84, 84),
"observation.state": torch.randn(1, 4),
"observation.image": torch.randn(8, 3, 84, 84),
"observation.state": torch.randn(8, 4),
}
actions = torch.randn(1, 2).to("cuda:0")
obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
print("compiling...")
# q_value = critic_ensemble(obs_dict, actions)
action = actor(obs_dict)
print("compiled")
actions = torch.randn(8, 2).to("cuda:0")
# obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
# print("compiling...")
q_value = critic_ensemble(obs_dict, actions)
print(q_value.size())
# action = actor(obs_dict)
# print("compiled")
# start = time.perf_counter()
# for _ in range(1000):
# # features = encoder(obs_dict)
# action = actor(obs_dict)
# # q_value = critic_ensemble(obs_dict, actions)
# print("Time taken:", time.perf_counter() - start)
# Compare the performance of the vectorized ensemble vs a for loop over individual critic heads
ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
ensemble = ensemble.to("cuda:0")
critic = CriticHead(256, [256, 256])
critic = critic.to("cuda:0")
data_ensemble = torch.randn(8, 256).to("cuda:0")
ensemble = torch.compile(ensemble)
# critic = torch.compile(critic)
print(ensemble(data_ensemble).size())
print(critic(data_ensemble).size())
start = time.perf_counter()
for _ in range(1000):
ensemble(data_ensemble)
print("Time taken:", time.perf_counter() - start)
start = time.perf_counter()
for _ in range(1000):
# features = encoder(obs_dict)
action = actor(obs_dict)
# q_value = critic_ensemble(obs_dict, actions)
for i in range(2):
critic(data_ensemble)
print("Time taken:", time.perf_counter() - start)
4 changes: 2 additions & 2 deletions lerobot/configs/env/maniskill_example.yaml
@@ -5,14 +5,14 @@ fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 128
image_size: 64
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 128
render_size: 64
device: cuda

reward_classifier:
42 changes: 23 additions & 19 deletions lerobot/configs/policy/sac_maniskill.yaml
@@ -59,32 +59,36 @@ policy:
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 128, 128]
observation.image: [3, 64, 64]
observation.image.2: [3, 64, 64]
output_shapes:
action: [7]

camera_number: 2

# Normalization / Unnormalization
input_normalization_modes:
observation.state: min_max
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]

max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
0.4001]
input_normalization_modes: null
# input_normalization_modes:
# observation.state: min_max
input_normalization_params: null
# observation.state:
# min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]

# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
# 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
# 0.4001]

output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
output_normalization_shapes:
action: [7]

@@ -94,8 +98,8 @@ policy:
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 2 #10
num_subsample_critics: null
num_critics: 10 #10
num_subsample_critics: 2
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
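The action bounds under output_normalization_params were tightened from ±10 to ±1. Assuming min_max normalization maps to and from the [-1, 1] interval (as elsewhere in LeRobot), the inverse mapping applied when emitting actions would look roughly like the sketch below; with the new ±1 bounds it is effectively the identity, so the squashed policy output is used directly, whereas the old ±10 bounds scaled every delta-pose action by 10. The helper name is illustrative, not the actual LeRobot normalization code.

import torch

action_min = torch.full((7,), -1.0)
action_max = torch.full((7,), 1.0)

def unnormalize_min_max(normalized: torch.Tensor) -> torch.Tensor:
    # Map a normalized action in [-1, 1] onto the configured [min, max] range.
    return (normalized + 1.0) / 2.0 * (action_max - action_min) + action_min

policy_output = torch.tanh(torch.randn(7))   # squashed action in [-1, 1]
env_action = unnormalize_min_max(policy_output)
print(env_action.min().item() >= -1.0, env_action.max().item() <= 1.0)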
