Included PyTorch model tests for CFRL #799
Conversation
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## master #799 +/- ##
==========================================
+ Coverage 76.53% 79.52% +2.98%
==========================================
Files 72 73 +1
Lines 8224 8477 +253
==========================================
+ Hits 6294 6741 +447
+ Misses 1930 1736 -194
Flags with carried forward coverage won't be shown.
@RobertSamoilescu there are two binary files called …
def __init__(self, input_dim: int, output_dim: int):
    super().__init__()
    self.fc1 = nn.Linear(input_dim, output_dim, bias=False)
    self.to(self.device)
What's the reason for doing this here (in contrast with other tests, where `cpu()` was called explicitly after model instantiation)?
See the comment above. It can probably be deleted; it was only added for some consistency. I guess I will just leave it there?
I would say it's better to remove it unless there's an explicit reason to have it here; otherwise it's another potential source of confusion for developers :)
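For context, a minimal sketch of the alternative pattern mentioned above (calling `cpu()` explicitly after instantiation rather than `self.to(self.device)` inside `__init__`); the class name and dimensions are only illustrative:

```python
# Sketch of the explicit-placement pattern referenced above: the test moves the
# model to CPU at the call site instead of the constructor doing it via
# self.to(self.device). UnimodalModel and the dims here are illustrative.
model = UnimodalModel(input_dim=4, output_dim=2)
model.cpu()
```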
@contextlib.contextmanager
def reset_model(model: Model):
    model._reset_loss()
    model._reset_metrics()
    yield
    model._reset_loss()
    model._reset_metrics()
Just to confirm my understanding: when used in a `with` statement, this would basically reset the losses and metrics, hand control over to the user inside the `with` block, and reset everything again after exiting?
Yes. That's correct.
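For illustration, a minimal usage sketch of the context manager in a test; the fixture name and the `train_step` signature are assumptions based on the snippets below:

```python
# Hypothetical test using reset_model: losses/metrics are reset on entry, the
# test body runs, and everything is reset again on exit, so a shared fixture
# does not leak state into other tests.
def test_train_step_statistics(unimodal_model, x, y):
    with reset_model(unimodal_model):
        stats = unimodal_model.train_step(x, y)  # assumed signature
        assert isinstance(stats, dict)
```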
""" Test if the train step return the appropriate statistics. """ | ||
# copy the model `state_dict` as it will be modified by the `train_step`. Not really necessary now, | ||
# but to avoid any errors in the future if it is the case. | ||
state_dict_cpy = deepcopy(unimodal_model.state_dict()) |
Is the concern here because of fixture sharing between test functions? Could it be done in a more idiomatic way using "yield" fixtures and/or scoping the model fixture at a test-function level instead of module level?
Yes, the concern is that the model is shared between tests. To be honest, I am not sure how to rewrite it with `yield`. Changing the scope to function level would do the job, but what I tried to do is avoid creating a new object over and over for each test. If we decide to move the fixture to function level, then there will be no need for the context manager to reset the metrics and the loss. Not entirely sure what the best approach is... Any suggestions?
I think initializing a model isn't a huge overhead, so function scope makes sense and gives us better guarantees of test isolation.
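For reference, a minimal sketch of a function-scoped `yield` fixture (pytest's default scope is already `function`, so no `scope` argument is needed; `UnimodalModel` is taken from the diff below and the dimensions are illustrative):

```python
import pytest

@pytest.fixture
def unimodal_model():
    # A fresh model per test function; with this scope the reset_model context
    # manager and the state_dict copy above become unnecessary.
    model = UnimodalModel(input_dim=4, output_dim=2)
    yield model
    # Teardown could go here if needed (nothing to clean up for a plain model).
```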
assert (param.grad is None) or (torch.allclose(param, torch.zeros_like(param)))

# compute prediction
unimodal_model.eval()
Should this be reverted before exiting the test (also see previous comment on fixtures)?
It doesn't really matter for these tests, since the model behaves the same in train and eval mode. Now that I think about it, it probably makes sense to set the scope to function level...
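For completeness, if the fixture stayed module-scoped, reverting the mode could look roughly like this (a sketch; `X` is a placeholder input, and with a function-scoped fixture this becomes unnecessary):

```python
# Sketch: restore train mode before leaving the test so a shared (module-scoped)
# fixture is not left in eval mode for subsequent tests.
unimodal_model.eval()
try:
    prediction = unimodal_model(X)  # X is a placeholder input tensor
finally:
    unimodal_model.train()
```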
model1 = UnimodalModel(input_dim=input_dim, output_dim=output_dim)
model2 = UnimodalModel(input_dim=input_dim, output_dim=output_dim)

with tempfile.TemporaryDirectory() as temp_dir:
pytest actually ships with a `tmp_path` fixture which can be used here, saving some code and imports :) https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html
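For illustration, the save/load setup above could look roughly like this with `tmp_path` (the `torch.save`/`torch.load` calls on the `state_dict` are an assumption; the PR may use a different persistence API):

```python
import torch

def test_save_load(tmp_path):
    model1 = UnimodalModel(input_dim=4, output_dim=2)
    model2 = UnimodalModel(input_dim=4, output_dim=2)

    ckpt = tmp_path / "model.pt"  # tmp_path is a per-test pathlib.Path
    torch.save(model1.state_dict(), ckpt)
    model2.load_state_dict(torch.load(ckpt))
```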
Thanks, nice tests! Left a few comments, particularly around fixture management.
Looks good, just one more comment.
* Wrote actor-critic tests.
* Included autoencoder tests.
* Implemented cfrl_models tests.
* Included metrics tests.
* Validate prediction labels tests -- in progress.
* Finalized all tests, before refactoring.
* isort, included docs, consistent use of quotes.
* pytest error checks, removed docstrings.
* Refactored test_model.
* Minor refactoring.
* Solved flake8 issues.
* Included pytest-mock in requirements/dev.txt.
* Improved tests for train_step, test_step, fit, and evaluate.
* Addressed comments. TODO: decide fixture scope.
* Moved models' scope to function.
* Removed global variables.
This PR includes the PyTorch model tests for CFRL to increase coverage (i.e. addresses #760).