[RLlib] Cleanup evaluation folder. (ray-project#48493)

Signed-off-by: mohitjain2504 <mohit.jain@dream11.com>
sven1977 authored and mohitjain2504 committed Nov 15, 2024
1 parent 82617ea commit 470aa52
Showing 45 changed files with 164 additions and 2,005 deletions.
28 changes: 0 additions & 28 deletions rllib/BUILD
@@ -1367,14 +1367,6 @@ py_test(
srcs = ["evaluation/tests/test_env_runner_v2.py"]
)

# @OldAPIStack
py_test(
name = "evaluation/tests/test_episode",
tags = ["team:rllib", "evaluation"],
size = "small",
srcs = ["evaluation/tests/test_episode.py"]
)

# @OldAPIStack
py_test(
name = "evaluation/tests/test_episode_v2",
@@ -3181,26 +3173,6 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=7.2"]
)

#@OldAPIStack
py_test(
name = "examples/centralized_critic_2_tf",
main = "examples/centralized_critic_2.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/centralized_critic_2.py"],
args = ["--as-test", "--framework=tf", "--stop-reward=6.0"]
)

#@OldAPIStack
py_test(
name = "examples/centralized_critic_2_torch",
main = "examples/centralized_critic_2.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/centralized_critic_2.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=6.0"]
)

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2",
main = "examples/custom_recurrent_rnn_tokenizer.py",
89 changes: 38 additions & 51 deletions rllib/algorithms/algorithm.py
@@ -62,7 +62,6 @@
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import (
collect_episodes,
summarize_episodes,
@@ -1634,10 +1633,7 @@ def restore_workers(self, workers: EnvRunnerGroup) -> None:
# worker of an EnvRunnerGroup.
state = from_worker.get_state()
# Take out (old) connector states from local worker's state.
if (
self.config.enable_connectors
and not self.config.enable_env_runner_and_connector_v2
):
if not self.config.enable_env_runner_and_connector_v2:
for pol_states in state["policy_states"].values():
pol_states.pop("connector_configs", None)
state_ref = ray.put(state)
@@ -2040,7 +2036,7 @@ def compute_single_action(
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episode: Optional[Episode] = None,
episode=None,
unsquash_action: Optional[bool] = None,
clip_action: Optional[bool] = None,
# Kwargs placeholder for future compatibility.
@@ -2127,53 +2123,44 @@ def compute_single_action(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Algorithm's local worker!"
)
local_worker = self.env_runner_group.local_env_runner

if not self.config.get("enable_connectors"):
# Check the preprocessor and preprocess, if necessary.
pp = local_worker.preprocessors[policy_id]
if pp and type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
observation = local_worker.filters[policy_id](observation, update=False)
else:
# Just preprocess observations, similar to how it used to be done before.
pp = policy.agent_connectors[ObsPreprocessorConnector]

# convert the observation to array if possible
if not isinstance(observation, (np.ndarray, dict, tuple)):
try:
observation = np.asarray(observation)
except Exception:
# Just preprocess observations, similar to how it used to be done before.
pp = policy.agent_connectors[ObsPreprocessorConnector]

# convert the observation to array if possible
if not isinstance(observation, (np.ndarray, dict, tuple)):
try:
observation = np.asarray(observation)
except Exception:
raise ValueError(
f"Observation type {type(observation)} cannot be converted to "
f"np.ndarray."
)
if pp:
assert len(pp) == 1, "Only one preprocessor should be in the pipeline"
pp = pp[0]

if not pp.is_identity():
# Note(Kourosh): This call will leave the policy's connector
# in eval mode. would that be a problem?
pp.in_eval()
if observation is not None:
_input_dict = {Columns.OBS: observation}
elif input_dict is not None:
_input_dict = {Columns.OBS: input_dict[Columns.OBS]}
else:
raise ValueError(
f"Observation type {type(observation)} cannot be converted to "
f"np.ndarray."
"Either observation or input_dict must be provided."
)
if pp:
assert len(pp) == 1, "Only one preprocessor should be in the pipeline"
pp = pp[0]

if not pp.is_identity():
# Note(Kourosh): This call will leave the policy's connector
# in eval mode. would that be a problem?
pp.in_eval()
if observation is not None:
_input_dict = {Columns.OBS: observation}
elif input_dict is not None:
_input_dict = {Columns.OBS: input_dict[Columns.OBS]}
else:
raise ValueError(
"Either observation or input_dict must be provided."
)

# TODO (Kourosh): Create a new util method for algorithm that
# computes actions based on raw inputs from env and can keep track
# of its own internal state.
acd = AgentConnectorDataType("0", "0", _input_dict)
# make sure the state is reset since we are only applying the
# preprocessor
pp.reset(env_id="0")
ac_o = pp([acd])[0]
observation = ac_o.data[Columns.OBS]
# TODO (Kourosh): Create a new util method for algorithm that
# computes actions based on raw inputs from env and can keep track
# of its own internal state.
acd = AgentConnectorDataType("0", "0", _input_dict)
# make sure the state is reset since we are only applying the
# preprocessor
pp.reset(env_id="0")
ac_o = pp([acd])[0]
observation = ac_o.data[Columns.OBS]

# Input-dict.
if input_dict is not None:
@@ -2225,7 +2212,7 @@ def compute_actions(
full_fetch: bool = False,
explore: Optional[bool] = None,
timestep: Optional[int] = None,
episodes: Optional[List[Episode]] = None,
episodes=None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
**kwargs,
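
As context for the hunks above (illustrative only, not part of this commit): `compute_single_action()` now always funnels the observation through the policy's agent-connector pipeline (incl. the `ObsPreprocessorConnector`), since the old `enable_connectors` branch is gone. A minimal usage sketch, assuming an old-API-stack PPO on CartPole-v1:

import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical setup (not from this diff): stay explicitly on the old API stack,
# where Algorithm.compute_single_action() and agent connectors are used.
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
)
algo = config.build()

env = gym.make("CartPole-v1")
obs, info = env.reset()
# The raw observation is converted to an array and pushed through the
# ObsPreprocessorConnector before the policy computes the action.
action = algo.compute_single_action(observation=obs, explore=False)
obs, reward, terminated, truncated, info = env.step(action)
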
38 changes: 13 additions & 25 deletions rllib/algorithms/algorithm_config.py
@@ -106,7 +106,7 @@
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.evaluation.episode import Episode as OldEpisode
from ray.rllib.utils.typing import EpisodeType

logger = logging.getLogger(__name__)

@@ -370,7 +370,6 @@ def __init__(self, algo_class: Optional[type] = None):
self.observation_filter = "NoFilter"
self.update_worker_filter_stats = True
self.use_worker_filter_stats = True
self.enable_connectors = True
self.sampler_perf_stats_ema_coef = None

# `self.learners()`
@@ -572,6 +571,7 @@ def __init__(self, algo_class: Optional[type] = None):
# TODO: Remove, once all deprecation_warning calls upon using these keys
# have been removed.
# === Deprecated keys ===
self.enable_connectors = DEPRECATED_VALUE
self.simple_optimizer = DEPRECATED_VALUE
self.monitor = DEPRECATED_VALUE
self.evaluation_num_episodes = DEPRECATED_VALUE
@@ -1758,7 +1758,6 @@ def env_runners(
exploration_config: Optional[dict] = NotProvided, # @OldAPIStack
create_env_on_local_worker: Optional[bool] = NotProvided, # @OldAPIStack
sample_collector: Optional[Type[SampleCollector]] = NotProvided, # @OldAPIStack
enable_connectors: Optional[bool] = NotProvided, # @OldAPIStack
remote_worker_envs: Optional[bool] = NotProvided, # @OldAPIStack
remote_env_batch_wait_ms: Optional[float] = NotProvided, # @OldAPIStack
preprocessor_pref: Optional[str] = NotProvided, # @OldAPIStack
@@ -1776,6 +1775,8 @@
worker_health_probe_timeout_s=DEPRECATED_VALUE,
worker_restore_timeout_s=DEPRECATED_VALUE,
synchronize_filter=DEPRECATED_VALUE,
# deprecated
enable_connectors=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
"""Sets the rollout worker configuration.
@@ -1822,9 +1823,6 @@ def env_runners(
because it doesn't have to sample (done by remote_workers;
worker_indices > 0) nor evaluate (done by evaluation workers;
see below).
enable_connectors: Use connector based environment runner, so that all
preprocessing of obs and postprocessing of actions are done in agent
and action connectors.
env_to_module_connector: A callable taking an Env as input arg and returning
an env-to-module ConnectorV2 (might be a pipeline) object.
module_to_env_connector: A callable taking an Env and an RLModule as input
@@ -1933,29 +1931,29 @@ def env_runners(
Returns:
This updated AlgorithmConfig object.
"""
if enable_connectors != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.env_runners(enable_connectors=...)",
error=False,
)
if num_rollout_workers != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.env_runners(num_rollout_workers)",
new="AlgorithmConfig.env_runners(num_env_runners)",
error=False,
error=True,
)
self.num_env_runners = num_rollout_workers
if num_envs_per_worker != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.env_runners(num_envs_per_worker)",
new="AlgorithmConfig.env_runners(num_envs_per_env_runner)",
error=False,
error=True,
)
self.num_envs_per_env_runner = num_envs_per_worker
if validate_workers_after_construction != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.env_runners(validate_workers_after_construction)",
new="AlgorithmConfig.env_runners(validate_env_runners_after_"
"construction)",
error=False,
)
self.validate_env_runners_after_construction = (
validate_workers_after_construction
error=True,
)

if env_runner_cls is not NotProvided:
@@ -1987,8 +1985,6 @@ def env_runners(
self.sample_collector = sample_collector
if create_env_on_local_worker is not NotProvided:
self.create_env_on_local_worker = create_env_on_local_worker
if enable_connectors is not NotProvided:
self.enable_connectors = enable_connectors
if env_to_module_connector is not NotProvided:
self._env_to_module_connector = env_to_module_connector
if module_to_env_connector is not NotProvided:
@@ -2874,7 +2870,7 @@ def multi_agent(
] = NotProvided,
policy_map_capacity: Optional[int] = NotProvided,
policy_mapping_fn: Optional[
Callable[[AgentID, "OldEpisode"], PolicyID]
Callable[[AgentID, "EpisodeType"], PolicyID]
] = NotProvided,
policies_to_train: Optional[
Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
@@ -4466,14 +4462,6 @@ def _validate_new_api_stack_settings(self):
"to False (old API stack), instead."
)

# New API stack (RLModule, Learner APIs) only works with connectors.
if not self.enable_connectors:
raise ValueError(
"The new API stack (RLModule and Learner APIs) only works with "
"connectors! Please enable connectors via "
"`config.env_runners(enable_connectors=True)`."
)

# LR-schedule checking.
Scheduler.validate(
fixed_value_or_schedule=self.lr,
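
For orientation (hedged, not part of this commit): with `enable_connectors` moved to the deprecated keys, the relevant switch is the API-stack selection itself; connectors are always active on the new stack. A configuration sketch:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # New API stack: RLModule/Learner plus EnvRunner/ConnectorV2. There is
    # nothing left for the removed `enable_connectors` flag to control here.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=2, num_envs_per_env_runner=1)
)
# Passing the removed flag only emits a deprecation warning now:
# config.env_runners(enable_connectors=True)
algo = config.build()
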
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_tf_policy.py
@@ -17,7 +17,6 @@
VTraceClipGradients,
VTraceOptimizer,
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_bootstrap_value,
compute_gae_for_sample_batch,
@@ -362,7 +361,7 @@ def postprocess_trajectory(
self,
sample_batch: SampleBatch,
other_agent_batches: Optional[SampleBatch] = None,
episode: Optional["Episode"] = None,
episode=None,
):
# Call super's postprocess_trajectory first.
# sample_batch = super().postprocess_trajectory(
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_torch_policy.py
@@ -17,7 +17,6 @@
make_time_major,
VTraceOptimizer,
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_bootstrap_value,
compute_gae_for_sample_batch,
@@ -378,7 +377,7 @@ def postprocess_trajectory(
self,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
episode: Optional["Episode"] = None,
episode=None,
):
# Call super's postprocess_trajectory first.
# sample_batch = super().postprocess_trajectory(
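
Illustrative only (not from this commit): old-API-stack policies that override `postprocess_trajectory()` simply pass `episode` through untyped now that the `Episode` import is gone. A hypothetical subclass mirroring the signature in the hunks above:

from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch


class MyAPPOTorchPolicy(APPOTorchPolicy):
    # Hypothetical subclass; not part of RLlib.
    def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None):
        # `episode` is forwarded as-is (no Episode type hint anymore).
        return compute_gae_for_sample_batch(self, sample_batch, other_agent_batches, episode)
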
15 changes: 7 additions & 8 deletions rllib/algorithms/callbacks.py
@@ -11,7 +11,6 @@
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy import Policy
@@ -227,7 +226,7 @@ def on_episode_created(
self,
*,
# TODO (sven): Deprecate Episode/EpisodeV2 with new API stack.
episode: Union[EpisodeType, Episode, EpisodeV2],
episode: Union[EpisodeType, EpisodeV2],
# TODO (sven): Deprecate this arg new API stack (in favor of `env_runner`).
worker: Optional["EnvRunner"] = None,
env_runner: Optional["EnvRunner"] = None,
@@ -284,7 +283,7 @@ def on_episode_created(
def on_episode_start(
self,
*,
episode: Union[EpisodeType, Episode, EpisodeV2],
episode: Union[EpisodeType, EpisodeV2],
env_runner: Optional["EnvRunner"] = None,
metrics_logger: Optional[MetricsLogger] = None,
env: Optional[gym.Env] = None,
@@ -326,7 +325,7 @@ def on_episode_start(
def on_episode_step(
self,
*,
episode: Union[EpisodeType, Episode, EpisodeV2],
episode: Union[EpisodeType, EpisodeV2],
env_runner: Optional["EnvRunner"] = None,
metrics_logger: Optional[MetricsLogger] = None,
env: Optional[gym.Env] = None,
@@ -369,7 +368,7 @@ def on_episode_step(
def on_episode_end(
self,
*,
episode: Union[EpisodeType, Episode, EpisodeV2],
episode: Union[EpisodeType, EpisodeV2],
env_runner: Optional["EnvRunner"] = None,
metrics_logger: Optional[MetricsLogger] = None,
env: Optional[gym.Env] = None,
@@ -473,7 +472,7 @@ def on_postprocess_trajectory(
self,
*,
worker: "EnvRunner",
episode: Episode,
episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
@@ -603,7 +602,7 @@ def __init__(self):
def on_episode_end(
self,
*,
episode: Union[EpisodeType, Episode, EpisodeV2],
episode: Union[EpisodeType, EpisodeV2],
env_runner: Optional["EnvRunner"] = None,
metrics_logger: Optional[MetricsLogger] = None,
env: Optional[gym.Env] = None,
@@ -743,7 +742,7 @@ def on_postprocess_trajectory(
self,
*,
worker: "EnvRunner",
episode: Episode,
episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
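
A hedged sketch (not part of this commit) of a user callback written against the narrowed signature above, where `episode` is an `EpisodeType` or `EpisodeV2` rather than the removed old-stack `Episode`:

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class EpisodeLengthCallback(DefaultCallbacks):
    # Hypothetical callback; argument names follow the signature shown above.
    def on_episode_end(self, *, episode, env_runner=None, metrics_logger=None, env=None, **kwargs):
        # On the new API stack, `episode` is e.g. a SingleAgentEpisode.
        if metrics_logger is not None:
            metrics_logger.log_value("episode_length", len(episode))


# Hooked into a config as usual, e.g.:
# config.callbacks(EpisodeLengthCallback)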