Skip to content

Commit

Permalink
Clean up semantic segmentation tests (huggingface#16801)
Browse files Browse the repository at this point in the history
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
  • Loading branch information
2 people authored and elusenji committed Jun 12, 2022
1 parent 597772a commit 6d9291b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 14 deletions.
8 changes: 1 addition & 7 deletions tests/beit/test_modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,7 @@ def test_training(self):
# we don't test BeitForMaskedImageModeling
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()

model = model_class(config)
model.to(torch_device)
model.train()
Expand Down
8 changes: 1 addition & 7 deletions tests/segformer/test_modeling_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,7 @@ def test_training(self):
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
if model_class.__name__ == "SegformerForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()

model = model_class(config)
model.to(torch_device)
model.train()
Expand Down

0 comments on commit 6d9291b

Please # to comment.