diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 43eddb909dea0..b140d9d13a96c 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -288,6 +288,19 @@ def add_module( # has `inference_only=False`. if not module.inference_only: self.inference_only = False + + # Check framework of incoming RLModule against `self.framework`. + if module.framework is not None: + if self.framework is None: + self.framework = module.framework + elif module.framework != self.framework: + raise ValueError( + f"Framework ({module.framework}) of incoming RLModule does NOT " + f"match framework ({self.framework}) of MultiRLModule! If the " + f"added module should not be trained, try setting its framework " + f"to None." + ) + self._rl_modules[module_id] = module # Update our RLModuleSpecs dict, such that - if written to disk - # it'll allow for proper restoring this instance through `.from_checkpoint()`.