
[RLlib] New API Stack: Action masking example issues in release 2.34 - 'super' object has no attribute '_compute_values' #47361

Closed
PhilippWillms opened this issue Aug 27, 2024 · 7 comments · Fixed by #47817
Assignees: simonsays1980
Labels: bug (Something that is supposed to be working; but isn't), P2 (Important issue, but not time-critical), rllib (RLlib related issues), rllib-docs-or-examples (Issues related to RLlib documentation or rllib/examples)

Comments

@PhilippWillms

What happened + What you expected to happen

To the best of my knowledge, the repro script below uses the version of the action_masking_rlm.py example file that shipped with release 2.34. However, I adjusted the resources, learners, and env_runners config to my requirements.

My assumption was that the config did not properly recognize that I want to use the GPU of the "main" / local cluster.

Complete error stack trace:

"name": "AttributeError",
"message": "'super' object has no attribute '_compute_values'",
"stack": "---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[4], line 8
6 results = []
7 for i in tqdm(range(1,iteration_count+1)):
----> 8 res = trainer.train()
9 if i % checkpoint_at_every_iter == 0:
10 path_to_checkpoint = trainer.save()

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\tune\trainable\trainable.py:331, in Trainable.train(self)
329 except Exception as e:
330 skipped = skip_exceptions(e)
--> 331 raise skipped from exception_cause(skipped)
333 assert isinstance(result, dict), "step() needs to return a dict."
335 # We do not modify internal state nor update this result if duplicate.

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\tune\trainable\trainable.py:328, in Trainable.train(self)
326 start = time.time()
327 try:
--> 328 result = self.step()
329 except Exception as e:
330 skipped = skip_exceptions(e)

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\algorithm.py:956, in Algorithm.step(self)
948 # Parallel eval + training: Kick off evaluation-loop and parallel train() call.
949 elif self.config._run_training_always_in_thread or (
950 evaluate_this_iter and self.config.evaluation_parallel_to_training
951 ):
952 (
953 train_results,
954 eval_results,
955 train_iter_ctx,
--> 956 ) = self._run_one_training_iteration_and_evaluation_in_parallel()
958 # - No evaluation necessary, just run the next training iteration.
959 # - We have to evaluate in this training iteration, but no parallelism ->
960 # evaluate after the training iteration is entirely done.
961 else:
962 train_results, train_iter_ctx = self._run_one_training_iteration()

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\algorithm.py:3572, in Algorithm._run_one_training_iteration_and_evaluation_in_parallel(self)
3568 evaluation_results = self._run_one_evaluation(
3569 parallel_train_future=parallel_train_future
3570 )
3571 # Collect the training results from the future.
-> 3572 train_results, train_iter_ctx = parallel_train_future.result()
3574 return train_results, evaluation_results, train_iter_ctx

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\concurrent\futures\_base.py:456, in Future.result(self, timeout)
454 raise CancelledError()
455 elif self._state == FINISHED:
--> 456 return self.__get_result()
457 else:
458 raise TimeoutError()

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\concurrent\futures\_base.py:401, in Future.__get_result(self)
399 if self._exception:
400 try:
--> 401 raise self._exception
402 finally:
403 # Break a reference cycle with the exception in self._exception
404 self = None

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\concurrent\futures\thread.py:58, in _WorkItem.run(self)
55 return
57 try:
---> 58 result = self.fn(*self.args, **self.kwargs)
59 except BaseException as exc:
60 self.future.set_exception(exc)

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\algorithm.py:3559, in Algorithm._run_one_training_iteration_and_evaluation_in_parallel.<locals>.<lambda>()
3546 """Runs one training iteration and one evaluation step in parallel.
3547
3548 First starts the training iteration (via self._run_one_training_iteration())
(...)
3555 the TrainIterCtx object returned by the training call.
3556 """
3557 with concurrent.futures.ThreadPoolExecutor() as executor:
3558 parallel_train_future = executor.submit(
-> 3559 lambda: self._run_one_training_iteration()
3560 )
3561 evaluation_results = {}
3562 # If the debug setting _run_training_always_in_thread is used, do NOT
3563 # evaluate, no matter what the settings are,

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\algorithm.py:3492, in Algorithm._run_one_training_iteration(self)
3488 # Try to train one step.
3489 with self._timers[TRAINING_STEP_TIMER]:
3490 # TODO (sven): Should we reduce the different
3491 # training_step_results over time with MetricsLogger.
-> 3492 training_step_results = self.training_step()
3494 if training_step_results:
3495 results = training_step_results

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo.py:422, in PPO.training_step(self)
418 @override(Algorithm)
419 def training_step(self):
420 # New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
421 if self.config.enable_env_runner_and_connector_v2:
--> 422 return self._training_step_new_api_stack()
423 # Old and hybrid API stacks (Policy, RolloutWorker, Connector, maybe RLModule,
424 # maybe Learner).
425 else:
426 return self._training_step_old_and_hybrid_api_stacks()

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo.py:478, in PPO._training_step_new_api_stack(self)
476 # Perform a learner update step on the collected episodes.
477 with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
--> 478 learner_results = self.learner_group.update_from_episodes(
479 episodes=episodes,
480 timesteps={
481 NUM_ENV_STEPS_SAMPLED_LIFETIME: (
482 self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
483 ),
484 },
485 minibatch_size=(
486 self.config.mini_batch_size_per_learner
487 or self.config.sgd_minibatch_size
488 ),
489 num_iters=self.config.num_sgd_iter,
490 )
491 self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
492 self.metrics.log_dict(
493 {
494 NUM_ENV_STEPS_TRAINED_LIFETIME: self.metrics.peek(
(...)
501 reduce="sum",
502 )

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\core\learner\learner_group.py:342, in LearnerGroup.update_from_episodes(self, episodes, timesteps, async_update, return_state, minibatch_size, num_iters, reduce_fn, **kwargs)
330 if reduce_fn != DEPRECATED_VALUE:
331 deprecation_warning(
332 old="LearnerGroup.update_from_episodes(reduce_fn=..)",
333 new="Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
(...)
339 error=True,
340 )
--> 342 return self._update(
343 episodes=episodes,
344 timesteps=timesteps,
345 async_update=async_update,
346 return_state=return_state,
347 minibatch_size=minibatch_size,
348 num_iters=num_iters,
349 **kwargs,
350 )

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\core\learner\learner_group.py:421, in LearnerGroup._update(self, batch, episodes, timesteps, async_update, return_state, minibatch_size, num_iters, **kwargs)
414 if async_update:
415 raise ValueError(
416 "Cannot call update_from_batch(async_update=True) when running in"
417 " local mode! Try setting config.num_learners > 0."
418 )
420 results = [
--> 421 _learner_update(
422 _learner=self._learner,
423 _batch_shard=batch,
424 _episodes_shard=episodes,
425 _timesteps=timesteps,
426 _return_state=return_state,
427 **kwargs,
428 )
429 ]
430 # One or more remote Learners: Shard batch/episodes into equal pieces (roughly
431 # equal if multi-agent AND episodes) and send each Learner worker one of these
432 # shards.
(...)
439 # Then again, we might move into a world where Learner always
440 # receives Episodes, never batches.
441 if isinstance(batch, list) and isinstance(batch[0], ray.data.DataIterator):

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\core\learner\learner_group.py:396, in LearnerGroup._update.<locals>._learner_update(_learner, _batch_shard, _episodes_shard, _timesteps, _return_state, _min_total_mini_batches, **_kwargs)
388 result = _learner.update_from_batch(
389 batch=_batch_shard,
390 timesteps=_timesteps,
(...)
393 **_kwargs,
394 )
395 else:
--> 396 result = _learner.update_from_episodes(
397 episodes=_episodes_shard,
398 timesteps=_timesteps,
399 minibatch_size=minibatch_size,
400 num_iters=num_iters,
401 min_total_mini_batches=_min_total_mini_batches,
402 **_kwargs,
403 )
404 if _return_state:
405 result["_rl_module_state_after_update"] = _learner.get_state(
406 components=COMPONENT_RL_MODULE, inference_only=True
407 )[COMPONENT_RL_MODULE]

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\core\learner\learner.py:1027, in Learner.update_from_episodes(self, episodes, timesteps, minibatch_size, num_iters, min_total_mini_batches, reduce_fn)
1017 if reduce_fn != DEPRECATED_VALUE:
1018 deprecation_warning(
1019 old="Learner.update_from_episodes(reduce_fn=..)",
1020 new="Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
(...)
1025 error=True,
1026 )
-> 1027 return self._update_from_batch_or_episodes(
1028 episodes=episodes,
1029 timesteps=timesteps,
1030 minibatch_size=minibatch_size,
1031 num_iters=num_iters,
1032 min_total_mini_batches=min_total_mini_batches,
1033 )

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_learner.py:71, in PPOLearner._update_from_batch_or_episodes(self, batch, episodes, **kwargs)
60 @override(Learner)
61 def _update_from_batch_or_episodes(
62 self,
(...)
68 # First perform GAE computation on the entirety of the given train data (all
69 # episodes).
70 if self.config.enable_env_runner_and_connector_v2:
---> 71 batch, episodes = self._compute_gae_from_episodes(episodes=episodes)
73 # Now that GAE (advantages and value targets) have been added to the train
74 # batch, we can proceed normally (calling super method) with the update step.
75 return super()._update_from_batch_or_episodes(
76 batch=batch,
77 episodes=episodes,
78 **kwargs,
79 )

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_learner.py:129, in PPOLearner._compute_gae_from_episodes(self, episodes)
122 batch_for_vf = self._learner_connector(
123 rl_module=self.module,
124 data={},
125 episodes=episodes,
126 shared_data={},
127 )
128 # Perform the value model's forward pass.
--> 129 vf_preds = convert_to_numpy(self._compute_values(batch_for_vf))
131 for module_id, module_vf_preds in vf_preds.items():
132 # Collect new (single-agent) episode lengths.
133 episode_lens_plus_1 = [
134 len(e)
135 for e in sa_episodes_list
136 if e.module_id is None or e.module_id == module_id
137 ]

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_learner.py:267, in PPOLearner._compute_values(self, batch_for_vf)
248 @OverrideToImplementCustomLogic
249 def _compute_values(
250 self,
251 batch_for_vf: Dict[str, Any],
252 ) -> Union[TensorType, Dict[str, Any]]:
253 """Computes the value function predictions for the module being optimized.
254
255 This method must be overridden by multiagent-specific algorithm learners to
(...)
265 tensors.
266 """
--> 267 return {
268 module_id: self.module[module_id].unwrapped().compute_values(module_batch)
269 for module_id, module_batch in batch_for_vf.items()
270 if self.should_module_be_updated(module_id, batch_for_vf)
271 }

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\ppo\ppo_learner.py:268, in <dictcomp>(.0)
248 @OverrideToImplementCustomLogic
249 def _compute_values(
250 self,
251 batch_for_vf: Dict[str, Any],
252 ) -> Union[TensorType, Dict[str, Any]]:
253 """Computes the value function predictions for the module being optimized.
254
255 This method must be overridden by multiagent-specific algorithm learners to
(...)
265 tensors.
266 """
267 return {
--> 268 module_id: self.module[module_id].unwrapped().compute_values(module_batch)
269 for module_id, module_batch in batch_for_vf.items()
270 if self.should_module_be_updated(module_id, batch_for_vf)
271 }

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\examples\rl_modules\classes\action_masking_rlm.py:108, in ActionMaskingTorchRLModule.compute_values(self, batch)
106 _, batch = self._preprocess_batch(batch)
107 # Call the super's method to compute values for GAE.
--> 108 return super()._compute_values(batch)

AttributeError: 'super' object has no attribute '_compute_values'"

Versions / Dependencies

ray==2.34
torch==2.3.1+cu118
gymnasium==0.28.1
Windows 11

Reproduction script

import ray
from gymnasium.spaces import Box, Discrete
from tqdm import tqdm

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment(
        env=ActionMaskEnv,
        env_config={
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5,)),
        },
    )
    .framework("torch")
    .resources(num_gpus=1)
    .learners(num_learners=0, num_cpus_per_learner=0, num_gpus_per_learner=1)
    .env_runners(
        num_env_runners=4, 
        num_cpus_per_env_runner=2,
        batch_mode="complete_episodes",
    )
    .rl_module(
        model_config_dict={
            "post_fcnet_hiddens": [64, 64],
            "post_fcnet_activation": "relu",
        },
        rl_module_spec=SingleAgentRLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
        ),
    )
    .evaluation(
        evaluation_num_env_runners=1,
        evaluation_interval=1,  # Every how many iterations to run one round of evaluation. 0 -> no evaluation.        
        evaluation_parallel_to_training=True,   # Run evaluation parallel to training to speed up the example.
    ) 
)

ray.shutdown()
# Initialize Ray and Build Agent
ray.init(num_cpus=12, num_gpus=1, include_dashboard=True)

checkpoint_at_every_iter = 5
iteration_count = 20

trainer = config.build()
 
results = []
for i in tqdm(range(1,iteration_count+1)):
    res = trainer.train()
    if i % checkpoint_at_every_iter == 0:
        path_to_checkpoint = trainer.save()        
        print(f"Checkpoint saved at {path_to_checkpoint}")
    results.append(res)

Issue Severity

High: It blocks me from completing my task.

@PhilippWillms PhilippWillms added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Aug 27, 2024
@PhilippWillms
Author

The same underlying error message occurs at a different level with this config:

.framework("torch")
    .resources(num_gpus=1)  #, num_cpus_for_main_process=1 should be relevant for tune
    .learners(num_learners=0, num_gpus_per_learner=1)
    .env_runners(
        num_env_runners=4, 
        num_cpus_per_env_runner=1,
        batch_mode="complete_episodes",
    )
    .rl_module(
        model_config_dict={
            "post_fcnet_hiddens": [64, 64],
            "post_fcnet_activation": "relu",
        },
        rl_module_spec=SingleAgentRLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
        ),
    )
    .evaluation(
        evaluation_num_env_runners=1,
        evaluation_interval=1,  
        evaluation_parallel_to_training=True,
    ) 

File c:\Users\Philipp\anaconda3\envs\py311-raynew\Lib\site-packages\ray\rllib\algorithms\algorithm.py:956, in Algorithm.step(self)
948 # Parallel eval + training: Kick off evaluation-loop and parallel train() call.
...
106 _, batch = self._preprocess_batch(batch)
107 # Call the super's method to compute values for GAE.
--> 108 return super()._compute_values(batch)

AttributeError: 'super' object has no attribute '_compute_values'

@PhilippWillms PhilippWillms changed the title [RLlib] New API Stack: Action masking example issues in release 2.34 - resource configs [RLlib] New API Stack: Action masking example issues in release 2.34 - 'super' object has no attribute '_compute_values' Aug 27, 2024
@PhilippWillms
Author

PhilippWillms commented Aug 27, 2024

Also occurs in the nightly build, downloaded at 08:30 p.m. CEST.

@simonsays1980, @sven1977: Happens both with the trainer API (i.e. config.build().train()) and when running via Tune.
It also happens with the shipped example, when running it from the CLI via python action_masking_rlm.py.

@PhilippWillms
Author

Also happens if evaluation_parallel_to_training=True is NOT set in the config. The issue always occurs at the first evaluation step.

@anyscalesam anyscalesam added the rllib RLlib related issues label Sep 3, 2024
@grizzlybearg

Modify your class's _compute_values method to compute_values.
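
A hypothetical, untested sketch of that workaround on the user side: subclass the shipped example module and override compute_values() so it calls the parent's public method instead of the removed _compute_values(). The PatchedActionMaskingTorchRLModule name is made up here, and the sketch assumes the example's parent PPOTorchRLModule exposes the public compute_values() and that the example keeps its _preprocess_batch helper.

from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)


class PatchedActionMaskingTorchRLModule(ActionMaskingTorchRLModule):
    # Hypothetical drop-in replacement until the example file itself is fixed.

    def compute_values(self, batch):
        # Strip the action mask out of the observations before the value pass
        # (uses the shipped example's own _preprocess_batch helper).
        _, batch = self._preprocess_batch(batch)
        # Bypass the example's broken super()._compute_values() call and invoke
        # the public compute_values() of the next class in the MRO
        # (PPOTorchRLModule in release 2.34).
        return super(ActionMaskingTorchRLModule, self).compute_values(batch)

Passing this class as module_class in the SingleAgentRLModuleSpec from the repro script above should avoid the AttributeError until the upstream fix lands.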

@simonsays1980 simonsays1980 added rllib-gpu-multi-gpu RLlib issues that's related to running on one or multiple GPUs rllib-docs-or-examples Issues related to RLlib documentation or rllib/examples and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) rllib-gpu-multi-gpu RLlib issues that's related to running on one or multiple GPUs labels Sep 11, 2024
@simonsays1980 simonsays1980 self-assigned this Sep 11, 2024
@simonsays1980 simonsays1980 added the P2 Important issue, but not time-critical label Sep 11, 2024
@PhilippWillms
Author

PhilippWillms commented Sep 13, 2024

@grizzlybearg: That sounds easy to implement for my own custom RLModules, but it should also be changed in the central repo managed by the Ray team.

@indigenica

indigenica commented Sep 16, 2024

I had the same problem with action masking in the new API stack with RLlib 2.35; the solution proposed by @grizzlybearg works for me too. Thanks!

Modify your class's _compute_values method to compute_values.

in examples/rl_modules/classes/action_masking_rlm.py

@simonsays1980
Collaborator

@PhilippWillms Nice catch! Thanks a ton! Fixed in the related PR - waiting for tests to pass.
