diff --git a/tests/beit/test_modeling_beit.py b/tests/beit/test_modeling_beit.py index 3f375d3a31006a..59776f83553bfb 100644 --- a/tests/beit/test_modeling_beit.py +++ b/tests/beit/test_modeling_beit.py @@ -244,13 +244,7 @@ def test_training(self): # we don't test BeitForMaskedImageModeling if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]: continue - # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING - # this can then be incorporated into _prepare_for_class in test_modeling_common.py - elif model_class.__name__ == "BeitForSemanticSegmentation": - batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape - inputs_dict["labels"] = torch.zeros( - [self.model_tester.batch_size, height, width], device=torch_device - ).long() + model = model_class(config) model.to(torch_device) model.train() diff --git a/tests/segformer/test_modeling_segformer.py b/tests/segformer/test_modeling_segformer.py index 3c3f6ee5b4367d..668298507871e3 100644 --- a/tests/segformer/test_modeling_segformer.py +++ b/tests/segformer/test_modeling_segformer.py @@ -316,13 +316,7 @@ def test_training(self): for model_class in self.all_model_classes: if model_class in get_values(MODEL_MAPPING): continue - # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING - # this can then be incorporated into _prepare_for_class in test_modeling_common.py - if model_class.__name__ == "SegformerForSemanticSegmentation": - batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape - inputs_dict["labels"] = torch.zeros( - [self.model_tester.batch_size, height, width], device=torch_device - ).long() + model = model_class(config) model.to(torch_device) model.train()