-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
feat: improve test coverage and clear redundant code. #238
Conversation
if self.bias is not None: | ||
return torch.add(w_times_x, self.bias[:, None, :]) # w times x + b | ||
return w_times_x # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original code exists some error.
If self.bias
is None
, then return torch.add(w_times_x, self.bias[:, None, :])
will take a slice of None
. So I use an if else
here.
Then, as mypy
thinks that the code will not be reached, I wonder line 165 may be used in future model-based code, so I just disable mypy
here.
|
||
def _log_when_not_update(self) -> None: | ||
"""Log when not update.""" | ||
self._logger.store( | ||
**{ | ||
'Loss/Loss_reward_critic': 0.0, | ||
'Loss/Loss_pi': 0.0, | ||
'Value/reward_critic': 0.0, | ||
}, | ||
) | ||
if self._cfgs.algo_cfgs.use_cost: | ||
self._logger.store( | ||
**{ | ||
'Loss/Loss_cost_critic': 0.0, | ||
'Value/cost_critic': 0.0, | ||
}, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply unused, so I delete, the same as others.
'SafetyPointGoal1-v0_data_test': OfflineMeta( | ||
url='https://drive.google.com/file/d/1JPJ127bWM_Tdej0AEGoFAqFFG9mWtzsN/view?usp=share_link', | ||
sha256sum='417b580cd4ef8f05a66d54c5d996b35a23a0e6c8ff8bae06807313a638df2dc6', | ||
episode_length=1, | ||
), | ||
'SafetyPointGoal1-v0_data_init_test': OfflineMeta( | ||
url='https://drive.google.com/file/d/1WlfkoUvWuFUYVMlGwi_EdGO914oWndpV/view?usp=share_link', | ||
sha256sum='fce6cc1fd0c294a8b66397f2f5276c9e7055821ded1f3a6e58e491eb342b1fbe', | ||
episode_length=1, | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test dataset
to test offline dataset. Each dataset is only 0.1MB.
with open( | ||
os.path.join( | ||
f'{custom_cfgs["logger_cfgs"]["log_dir"]}', | ||
terminal_log_name, | ||
), | ||
'w', | ||
encoding='utf-8', | ||
) | ||
# pylint: disable-next=consider-using-with | ||
sys.stderr = open( # noqa: SIM115 | ||
os.path.join(f'{custom_cfgs["logger_cfgs"]["log_dir"]}', error_log_name), | ||
'w', | ||
encoding='utf-8', | ||
) | ||
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs) | ||
reward, cost, ep_len = agent.learn() | ||
) as f_out: | ||
sys.stdout = f_out | ||
with open( | ||
os.path.join( | ||
f'{custom_cfgs["logger_cfgs"]["log_dir"]}', | ||
error_log_name, | ||
), | ||
'w', | ||
encoding='utf-8', | ||
) as f_error: | ||
sys.stderr = f_error | ||
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs) | ||
reward, cost, ep_len = agent.learn() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use with open
instead of previous version to meet pylint
requirements.
@@ -225,6 +224,7 @@ def test_cem_based(algo): | |||
} | |||
agent = omnisafe.Agent(algo, env_id, custom_cfgs=custom_cfgs) | |||
agent.learn() | |||
agent.render() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As model-based algorithms have many branches when evaluating, I render all of them to make a trust-worthy test.
except yaml.YAMLError as exc: | ||
raise AssertionError(f'{path} error: {exc}') from exc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, here will raise FileNotFoundError
instead of yaml.YAMLError
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Description
Motivation and Context
Close issue #237
Types of changes
What types of changes does your code introduce? Put an
x
in all the boxes that apply:Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!
make format
. (required)make lint
. (required)make test
pass. (required)