From d3e0024e2677b42977369cba672eccc9cc010b9c Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 10 Oct 2024 20:49:24 +0200 Subject: [PATCH] [RLlib] Add framework-check to `MultiRLModule.add_module()`. (#47973) Signed-off-by: ujjawal-khare --- rllib/core/rl_module/multi_rl_module.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index 43eddb909dea..b140d9d13a96 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -288,6 +288,19 @@ def add_module( # has `inference_only=False`. if not module.inference_only: self.inference_only = False + + # Check framework of incoming RLModule against `self.framework`. + if module.framework is not None: + if self.framework is None: + self.framework = module.framework + elif module.framework != self.framework: + raise ValueError( + f"Framework ({module.framework}) of incoming RLModule does NOT " + f"match framework ({self.framework}) of MultiRLModule! If the " + f"added module should not be trained, try setting its framework " + f"to None." + ) + self._rl_modules[module_id] = module # Update our RLModuleSpecs dict, such that - if written to disk - # it'll allow for proper restoring this instance through `.from_checkpoint()`.