From 54422a7b64b10f1b73c70fd54fe559ea71b5bf23 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 26 Mar 2024 13:36:54 +0900 Subject: [PATCH] Update visual prompting refactoring to develop (#3193) * Update * Enable using `print_config` * Add visual prompting tutorial * Update unit tests * Update * precommit * Updates for unit test * Updates for integration tests * Fix ruff errors * Update docs * Fix --- .../algo/visual_prompting/segment_anything.py | 22 +- .../zero_shot_segment_anything.py | 190 ++++++------ src/otx/core/model/visual_prompting.py | 20 +- .../encoders/test_sam_image_encoder.py | 6 + tests/unit/core/conftest.py | 16 +- tests/unit/core/exporter/__init__.py | 2 + .../core/exporter/test_visual_prompting.py | 85 ++++++ .../unit/core/model/test_visual_prompting.py | 275 +++++++++++++++++- 8 files changed, 508 insertions(+), 108 deletions(-) create mode 100644 tests/unit/core/exporter/__init__.py create mode 100644 tests/unit/core/exporter/test_visual_prompting.py diff --git a/src/otx/algo/visual_prompting/segment_anything.py b/src/otx/algo/visual_prompting/segment_anything.py index 029c90e862e..3b3ad48e122 100644 --- a/src/otx/algo/visual_prompting/segment_anything.py +++ b/src/otx/algo/visual_prompting/segment_anything.py @@ -496,9 +496,25 @@ def __init__( scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable, metric: MetricCallable = VisualPromptingMetricCallable, torch_compile: bool = False, - **kwargs, - ): - self.config = {"backbone": backbone, **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], **kwargs} + freeze_image_encoder: bool = True, + freeze_prompt_encoder: bool = True, + freeze_mask_decoder: bool = False, + use_stability_score: bool = False, + return_single_mask: bool = True, + return_extra_metrics: bool = False, + stability_score_offset: float = 1.0, + ) -> None: + self.config = { + "backbone": backbone, + "freeze_image_encoder": freeze_image_encoder, + "freeze_prompt_encoder": freeze_prompt_encoder, + "freeze_mask_decoder": freeze_mask_decoder, + "use_stability_score": use_stability_score, + "return_single_mask": return_single_mask, + "return_extra_metrics": return_extra_metrics, + "stability_score_offset": stability_score_offset, + **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], + } super().__init__( num_classes=num_classes, optimizer=optimizer, diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 81189203c42..91f6bf07275 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -19,12 +19,9 @@ from torch import LongTensor, Tensor, nn from torch.nn import functional as F # noqa: N812 from torchvision import tv_tensors -from torchvision.tv_tensors import BoundingBoxes, Image +from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor -from otx.algo.visual_prompting.segment_anything import ( - DEFAULT_CONFIG_SEGMENT_ANYTHING, - SegmentAnything, -) +from otx.algo.visual_prompting.segment_anything import DEFAULT_CONFIG_SEGMENT_ANYTHING, SegmentAnything from otx.core.data.entity.base import OTXBatchLossEntity, Points from otx.core.data.entity.visual_prompting import ( ZeroShotVisualPromptingBatchDataEntity, @@ -230,8 +227,8 @@ def expand_reference_info(self, reference_feats: Tensor, new_largest_label: int) @torch.no_grad() def learn( self, - images: list[tv_tensors.Image], - processed_prompts: list[dict[int, list[tv_tensors.TVTensor]]], + images: list[Image], + processed_prompts: list[dict[int, list[TVTensor]]], reference_feats: Tensor, used_indices: Tensor, ori_shapes: list[Tensor], @@ -244,8 +241,8 @@ def learn( Currently, single batch is only supported. Args: - images (list[tv_tensors.Image]): List of given images for reference features. - processed_prompts (dict[int, list[tv_tensors.TVTensor]]): The class-wise prompts + images (list[Image]): List of given images for reference features. + processed_prompts (dict[int, list[TVTensor]]): The class-wise prompts processed at OTXZeroShotSegmentAnything._gather_prompts_with_labels. reference_feats (Tensor): Reference features for target prediction. used_indices (Tensor): To check which indices of reference features are validate. @@ -269,7 +266,7 @@ def learn( # TODO (sungchul): ensemble multi reference features (current : use merged masks) ref_mask = torch.zeros(*map(int, ori_shape), dtype=torch.uint8, device=image.device) for input_prompt in input_prompts: - if isinstance(input_prompt, tv_tensors.Mask): + if isinstance(input_prompt, Mask): # directly use annotation information as a mask ref_mask[input_prompt == 1] += 1 # TODO (sungchul): check if the mask is bool or int else: @@ -321,7 +318,7 @@ def learn( @torch.no_grad() def infer( self, - images: list[tv_tensors.Image], + images: list[Image], reference_feats: Tensor, used_indices: Tensor, ori_shapes: list[Tensor], @@ -334,7 +331,7 @@ def infer( Get target results by using reference features and target images' features. Args: - images (list[tv_tensors.Image]): Given images for target results. + images (list[Image]): Given images for target results. reference_feats (Tensor): Reference features for target prediction. used_indices (Tensor): To check which indices of reference features are validate. ori_shapes (list[Tensor]): Original image size. @@ -455,66 +452,73 @@ def _predict_masks( masks: Tensor logits: Tensor scores: Tensor - num_iter = 3 if is_cascade else 1 - for i in range(num_iter): - if i == 0: - # First-step prediction - mask_input = torch.zeros( - 1, - 1, - *(x * 4 for x in image_embeddings.shape[2:]), - device=image_embeddings.device, - ) - has_mask_input = self.has_mask_inputs[0].to(mask_input.device) - - elif i == 1: - # Cascaded Post-refinement-1 - # TODO (sungchul2): Fix the following ruff errors, ticket no. 135852 - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:473:21: F821 Undefined name `masks` - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:474:21: F821 Undefined name `logits` - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:475:21: F821 Undefined name `scores` - mask_input, best_masks = self._decide_cascade_results( - masks, # noqa: F821 - logits, # noqa: F821 - scores, # noqa: F821 - is_single=True, - ) - if best_masks.sum() == 0: - return best_masks - - has_mask_input = self.has_mask_inputs[1].to(mask_input.device) - - elif i == 2: - # Cascaded Post-refinement-2 - # TODO (sungchul2): Fix the following ruff errors, ticket no. 135852 - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:475:21: F821 Undefined name `masks` - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:476:21: F821 Undefined name `logits` - # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:477:21: F821 Undefined name `scores` - mask_input, best_masks = self._decide_cascade_results(masks, logits, scores) # noqa: F821 - if best_masks.sum() == 0: - return best_masks - - has_mask_input = self.has_mask_inputs[1].to(mask_input.device) - coords = torch.nonzero(best_masks) - y, x = coords[:, 0], coords[:, 1] - box_coords = self._preprocess_coords( - torch.tensor([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=torch.float32, device=coords.device), - ori_shape, - self.image_size, + + # First-step prediction + mask_input = torch.zeros( + 1, + 1, + *(x * 4 for x in image_embeddings.shape[2:]), + device=image_embeddings.device, + ) + has_mask_input = self.has_mask_inputs[0].to(mask_input.device) + high_res_masks, scores, logits = self( + mode=mode, + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + mask_input=mask_input, + has_mask_input=has_mask_input, + ori_shape=ori_shape, + ) + masks = high_res_masks > self.mask_threshold + + if is_cascade: + for i in range(2): + if i == 0: + # Cascaded Post-refinement-1 + mask_input, best_masks = self._decide_cascade_results( + masks, + logits, + scores, + is_single=True, + ) + if best_masks.sum() == 0: + return best_masks + + has_mask_input = self.has_mask_inputs[1].to(mask_input.device) + + else: + # Cascaded Post-refinement-2 + mask_input, best_masks = self._decide_cascade_results(masks, logits, scores) + if best_masks.sum() == 0: + return best_masks + + has_mask_input = self.has_mask_inputs[1].to(mask_input.device) + coords = torch.nonzero(best_masks) + y, x = coords[:, 0], coords[:, 1] + box_coords = self._preprocess_coords( + torch.tensor( + [[[x.min(), y.min()], [x.max(), y.max()]]], + dtype=torch.float32, + device=coords.device, + ), + ori_shape, + self.image_size, + ) + point_coords = torch.cat((point_coords, box_coords), dim=1) + point_labels = torch.cat((point_labels, self.point_labels_box.to(point_labels.device)), dim=1) + + high_res_masks, scores, logits = self( + mode=mode, + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + mask_input=mask_input, + has_mask_input=has_mask_input, + ori_shape=ori_shape, ) - point_coords = torch.cat((point_coords, box_coords), dim=1) - point_labels = torch.cat((point_labels, self.point_labels_box.to(point_labels.device)), dim=1) + masks = high_res_masks > self.mask_threshold - high_res_masks, scores, logits = self( - mode=mode, - image_embeddings=image_embeddings, - point_coords=point_coords, - point_labels=point_labels, - mask_input=mask_input, - has_mask_input=has_mask_input, - ori_shape=ori_shape, - ) - masks = high_res_masks > self.mask_threshold _, best_masks = self._decide_cascade_results(masks, logits, scores) return best_masks @@ -623,17 +627,37 @@ def __init__( self, backbone: Literal["tiny_vit", "vit_b"], num_classes: int = 0, - root_reference_info: Path | str = "vpm_zsl_reference_infos", - save_outputs: bool = True, - pixel_mean: list[float] | None = [123.675, 116.28, 103.53], # noqa: B006 - pixel_std: list[float] | None = [58.395, 57.12, 57.375], # noqa: B006 optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable, scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable, metric: MetricCallable = VisualPromptingMetricCallable, torch_compile: bool = False, - **kwargs, - ): - self.config = {"backbone": backbone, **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], **kwargs} + root_reference_info: Path | str = "vpm_zsl_reference_infos", + save_outputs: bool = True, + pixel_mean: list[float] | None = [123.675, 116.28, 103.53], # noqa: B006 + pixel_std: list[float] | None = [58.395, 57.12, 57.375], # noqa: B006 + freeze_image_encoder: bool = True, + freeze_prompt_encoder: bool = True, + freeze_mask_decoder: bool = True, + default_threshold_reference: float = 0.3, + default_threshold_target: float = 0.65, + use_stability_score: bool = False, + return_single_mask: bool = False, + return_extra_metrics: bool = False, + stability_score_offset: float = 1.0, + ) -> None: + self.config = { + "backbone": backbone, + "freeze_image_encoder": freeze_image_encoder, + "freeze_prompt_encoder": freeze_prompt_encoder, + "freeze_mask_decoder": freeze_mask_decoder, + "default_threshold_reference": default_threshold_reference, + "default_threshold_target": default_threshold_target, + "use_stability_score": use_stability_score, + "return_single_mask": return_single_mask, + "return_extra_metrics": return_extra_metrics, + "stability_score_offset": stability_score_offset, + **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], + } super().__init__( num_classes=num_classes, optimizer=optimizer, @@ -729,7 +753,7 @@ def _customize_outputs( # type: ignore[override] self.used_indices = outputs[0].get("used_indices") return outputs - masks: list[tv_tensors.Mask] = [] + masks: list[Mask] = [] prompts: list[Points] = [] scores: list[Tensor] = [] labels: list[LongTensor] = [] @@ -737,7 +761,7 @@ def _customize_outputs( # type: ignore[override] for label, predicted_mask in predicted_masks.items(): if len(predicted_mask) == 0: continue - masks.append(tv_tensors.Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32)) + masks.append(Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32)) prompts.append( Points( torch.stack([p[:2] for p in used_points[label]], dim=0), @@ -761,11 +785,11 @@ def _customize_outputs( # type: ignore[override] def _gather_prompts_with_labels( self, - prompts: list[list[tv_tensors.TVTensor]], + prompts: list[list[TVTensor]], labels: list[Tensor], - ) -> list[dict[int, list[tv_tensors.TVTensor]]]: + ) -> list[dict[int, list[TVTensor]]]: """Gather prompts according to labels.""" - total_processed_prompts: list[dict[int, list[tv_tensors.TVTensor]]] = [] + total_processed_prompts: list[dict[int, list[TVTensor]]] = [] for prompt, label in zip(prompts, labels): processed_prompts = defaultdict(list) for _prompt, _label in zip(prompt, label): # type: ignore[arg-type] @@ -774,7 +798,7 @@ def _gather_prompts_with_labels( total_processed_prompts.append(sorted_processed_prompts) return total_processed_prompts - def apply_image(self, image: tv_tensors.Image | np.ndarray, target_length: int = 1024) -> tv_tensors.Image: + def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image: """Preprocess image to be used in the model.""" h, w = image.shape[-2:] target_size = self.get_preprocess_shape(h, w, target_length) diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 876b38c3e88..f5df0bcc28c 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -112,13 +112,13 @@ def _inference_step( ] _target = converted_entities["target"] _metric.update(preds=_preds, target=_target) - elif _name in ["IoU", "F1", "Dice"]: + elif _name in ["iou", "f1-score", "dice"]: # BinaryJaccardIndex, BinaryF1Score, Dice for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]): _metric.update(cvt_preds["masks"], cvt_target["masks"]) -def _inference_step_for_zeroshot( +def _inference_step_for_zero_shot( model: OTXZeroShotVisualPromptingModel | OVZeroShotVisualPromptingModel, metric: MetricCollection, inputs: ZeroShotVisualPromptingBatchDataEntity, @@ -160,7 +160,7 @@ def _inference_step_for_zeroshot( _preds.append(_preds[idx] if idx < len(_preds) else pad_prediction) _metric.update(preds=_preds, target=_target) - elif _name in ["IoU", "F1", "Dice"]: + elif _name in ["iou", "f1-score", "dice"]: # BinaryJaccardIndex, BinaryF1Score, Dice for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]): _metric.update( @@ -441,7 +441,7 @@ def test_step( Raises: TypeError: If the predictions are not of type VisualPromptingBatchPredEntity. """ - _inference_step_for_zeroshot(model=self, metric=self.metric, inputs=inputs) + _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) def _convert_pred_entity_to_compute_metric( self, @@ -789,6 +789,7 @@ def learn( inputs: ZeroShotVisualPromptingBatchDataEntity, reset_feat: bool = False, default_threshold_reference: float = 0.3, + is_cascade: bool = False, ) -> tuple[dict[str, np.ndarray], list[np.ndarray]]: """`Learn` for reference features.""" if reset_feat or self.reference_feats is None: @@ -815,7 +816,7 @@ def learn( if "point_coords" in inputs_decoder: # bboxes and points inputs_decoder.update(image_embeddings) - prediction = self._predict_masks(inputs_decoder, original_shape, is_cascade=False) + prediction = self._predict_masks(inputs_decoder, original_shape, is_cascade=is_cascade) masks = prediction["upscaled_masks"] else: log.warning("annotation and polygon will be supported.") @@ -847,7 +848,7 @@ def infer( inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: np.ndarray, used_indices: np.ndarray, - is_cascade: bool = False, + is_cascade: bool = True, threshold: float = 0.0, num_bg_points: int = 1, default_threshold_target: float = 0.65, @@ -1087,9 +1088,10 @@ def _predict_masks( has_mask_input = self.has_mask_inputs[1] y, x = np.nonzero(masks) box_coords = self.model["decoder"].apply_coords( - np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32), - original_size[0], + np.array([[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32), + original_size, ) + box_coords = np.expand_dims(box_coords, axis=0) inputs.update( { "point_coords": np.concatenate((inputs["point_coords"], box_coords), axis=1), @@ -1419,7 +1421,7 @@ def test_step( Raises: TypeError: If the predictions are not of type VisualPromptingBatchPredEntity. """ - _inference_step_for_zeroshot(model=self, metric=self.metric, inputs=inputs) + _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) def _convert_pred_entity_to_compute_metric( self, diff --git a/tests/unit/algo/visual_prompting/encoders/test_sam_image_encoder.py b/tests/unit/algo/visual_prompting/encoders/test_sam_image_encoder.py index f1f87b81ffa..1a0f064e8a6 100644 --- a/tests/unit/algo/visual_prompting/encoders/test_sam_image_encoder.py +++ b/tests/unit/algo/visual_prompting/encoders/test_sam_image_encoder.py @@ -10,6 +10,7 @@ class TestSAMImageEncoder: ("backbone", "expected"), [ ("tiny_vit", "TinyViT"), + ("vit_b", "ViT"), ], ) def test_new(self, backbone: str, expected: str) -> None: @@ -17,3 +18,8 @@ def test_new(self, backbone: str, expected: str) -> None: sam_image_encoder = SAMImageEncoder(backbone=backbone) assert sam_image_encoder.__class__.__name__ == expected + + def test_new_unsupported_backbone(self) -> None: + """Test __new__ for unsupported backbone.""" + with pytest.raises(ValueError): # noqa: PT011 + SAMImageEncoder(backbone="unsupported_backbone") diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index bf19c9e1197..d20d49ddb45 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -193,7 +193,7 @@ def fxt_vpm_data_entity() -> ( tuple[VisualPromptingDataEntity, VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity] ): img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) + fake_image = tv_tensors.Image(torch.ones(img_size)) fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) fake_bboxes = tv_tensors.BoundingBoxes( [[0, 0, 1, 1]], @@ -202,9 +202,10 @@ def fxt_vpm_data_entity() -> ( dtype=torch.float32, ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) + fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) fake_labels = {"bboxes": torch.as_tensor([1], dtype=torch.int64)} fake_polygons = [None] + fake_scores = torch.tensor([[1.0]]) # define data entity single_data_entity = VisualPromptingDataEntity( image=fake_image, @@ -234,7 +235,7 @@ def fxt_vpm_data_entity() -> ( polygons=[fake_polygons], bboxes=[fake_bboxes], points=[fake_points], - scores=[], + scores=[fake_scores], ) return single_data_entity, batch_data_entity, batch_pred_data_entity @@ -249,7 +250,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( ] ): img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) + fake_image = tv_tensors.Image(torch.ones(img_size)) fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) fake_bboxes = tv_tensors.BoundingBoxes( [[0, 0, 1, 1]], @@ -258,9 +259,10 @@ def fxt_zero_shot_vpm_data_entity() -> ( dtype=torch.float32, ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) - fake_labels = torch.as_tensor([1], dtype=torch.int64) + fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) + fake_labels = torch.as_tensor([[1]], dtype=torch.int64) fake_polygons = [None] + fake_scores = torch.tensor([[1.0]]) # define data entity single_data_entity = ZeroShotVisualPromptingDataEntity( image=fake_image, @@ -287,7 +289,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( labels=[fake_labels], polygons=[fake_polygons], prompts=[[fake_bboxes, fake_points]], - scores=[], + scores=[fake_scores], ) return single_data_entity, batch_data_entity, batch_pred_data_entity diff --git a/tests/unit/core/exporter/__init__.py b/tests/unit/core/exporter/__init__.py new file mode 100644 index 00000000000..916f3a44b27 --- /dev/null +++ b/tests/unit/core/exporter/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/core/exporter/test_visual_prompting.py b/tests/unit/core/exporter/test_visual_prompting.py new file mode 100644 index 00000000000..e0249ab8f16 --- /dev/null +++ b/tests/unit/core/exporter/test_visual_prompting.py @@ -0,0 +1,85 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests of visual prompting exporter.""" + +import pytest +from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter +from otx.core.types.export import OTXExportFormatType +from torch import nn + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.image_encoder = nn.Identity() + self.embed_dim = 2 + self.image_embedding_size = 4 + + def forward(self, x): + return x + + +class TestOTXVisualPromptingModelExporter: + @pytest.fixture() + def otx_visual_prompting_model_exporter(self) -> OTXVisualPromptingModelExporter: + return OTXVisualPromptingModelExporter(input_size=(10, 10), via_onnx=True) + + def test_export_openvino(self, mocker, tmpdir, otx_visual_prompting_model_exporter) -> None: + """Test export for OPENVINO.""" + mocker_torch_onnx_export = mocker.patch("torch.onnx.export") + mocker_onnx_load = mocker.patch("onnx.load") + mocker_onnx_save = mocker.patch("onnx.save") + mocker_postprocess_onnx_model = mocker.patch.object( + otx_visual_prompting_model_exporter, + "_postprocess_onnx_model", + ) + mocker_openvino_convert_model = mocker.patch("openvino.convert_model") + mocker_postprocess_openvino_model = mocker.patch.object( + otx_visual_prompting_model_exporter, + "_postprocess_openvino_model", + ) + mocker_openvino_save_model = mocker.patch("openvino.save_model") + + otx_visual_prompting_model_exporter.export( + model=MockModel(), + output_dir=tmpdir, + export_format=OTXExportFormatType.OPENVINO, + ) + + mocker_torch_onnx_export.assert_called() + mocker_onnx_load.assert_called() + mocker_onnx_save.assert_called() + mocker_postprocess_onnx_model.assert_called() + mocker_openvino_convert_model.assert_called() + mocker_postprocess_openvino_model.assert_called() + mocker_openvino_save_model.assert_called() + + def test_export_onnx(self, mocker, tmpdir, otx_visual_prompting_model_exporter) -> None: + """Test export for ONNX.""" + mocker_torch_onnx_export = mocker.patch("torch.onnx.export") + mocker_onnx_load = mocker.patch("onnx.load") + mocker_onnx_save = mocker.patch("onnx.save") + mocker_postprocess_onnx_model = mocker.patch.object( + otx_visual_prompting_model_exporter, + "_postprocess_onnx_model", + ) + + otx_visual_prompting_model_exporter.export( + model=MockModel(), + output_dir=tmpdir, + export_format=OTXExportFormatType.ONNX, + ) + + mocker_torch_onnx_export.assert_called() + mocker_onnx_load.assert_called() + mocker_onnx_save.assert_called() + mocker_postprocess_onnx_model.assert_called() + + def test_export_exportable_code(self, tmpdir, otx_visual_prompting_model_exporter) -> None: + """Test export for EXPORTABLE_CODE.""" + with pytest.raises(NotImplementedError): + otx_visual_prompting_model_exporter.export( + model=MockModel(), + output_dir=tmpdir, + export_format=OTXExportFormatType.EXPORTABLE_CODE, + ) diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 18789d55278..d0b83d8a82b 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -11,22 +11,129 @@ import numpy as np import pytest import torch -from otx.core.data.entity.visual_prompting import VisualPromptingBatchPredEntity +from otx.core.data.entity.base import Points +from otx.core.data.entity.visual_prompting import ( + VisualPromptingBatchPredEntity, + ZeroShotVisualPromptingBatchDataEntity, + ZeroShotVisualPromptingBatchPredEntity, +) from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter from otx.core.model.visual_prompting import ( OTXVisualPromptingModel, + OTXZeroShotVisualPromptingModel, OVVisualPromptingModel, OVZeroShotVisualPromptingModel, + _inference_step, + _inference_step_for_zero_shot, ) from torchvision import tv_tensors -class TestOTXVisualPromptingModel: - @pytest.fixture() - def otx_visual_prompting_model(self, mocker) -> OTXVisualPromptingModel: - mocker.patch.object(OTXVisualPromptingModel, "_create_model") - return OTXVisualPromptingModel(num_classes=1) +@pytest.fixture() +def otx_visual_prompting_model(mocker) -> OTXVisualPromptingModel: + mocker.patch.object(OTXVisualPromptingModel, "_create_model") + return OTXVisualPromptingModel(num_classes=1) + + +@pytest.fixture() +def otx_zero_shot_visual_prompting_model(mocker) -> OTXZeroShotVisualPromptingModel: + mocker.patch.object(OTXZeroShotVisualPromptingModel, "_create_model") + return OTXZeroShotVisualPromptingModel(num_classes=1) + + +def test_inference_step(mocker, otx_visual_prompting_model, fxt_vpm_data_entity) -> None: + """Test _inference_step.""" + otx_visual_prompting_model.configure_metric() + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_vpm_data_entity[2]) + mocker_updates = {} + for k, v in otx_visual_prompting_model.metric.items(): + mocker_updates[k] = mocker.patch.object(v, "update") + + _inference_step(otx_visual_prompting_model, otx_visual_prompting_model.metric, fxt_vpm_data_entity[1]) + + for v in mocker_updates.values(): + v.assert_called_once() + + +def test_inference_step_for_zero_shot(mocker, otx_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: + """Test _inference_step_for_zero_shot.""" + otx_visual_prompting_model.configure_metric() + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + mocker_updates = {} + for k, v in otx_visual_prompting_model.metric.items(): + mocker_updates[k] = mocker.patch.object(v, "update") + _inference_step_for_zero_shot( + otx_visual_prompting_model, + otx_visual_prompting_model.metric, + fxt_zero_shot_vpm_data_entity[1], + ) + + for v in mocker_updates.values(): + v.assert_called_once() + + +def test_inference_step_for_zero_shot_with_more_preds( + mocker, + otx_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, +) -> None: + """Test _inference_step_for_zero_shot with more preds.""" + otx_visual_prompting_model.configure_metric() + preds = {} + for k, v in fxt_zero_shot_vpm_data_entity[2].__dict__.items(): + if k in ["batch_size", "polygons"]: + preds[k] = v + else: + preds[k] = v * 2 + mocker.patch.object( + otx_visual_prompting_model, + "forward", + return_value=ZeroShotVisualPromptingBatchPredEntity(**preds), + ) + mocker_updates = {} + for k, v in otx_visual_prompting_model.metric.items(): + mocker_updates[k] = mocker.patch.object(v, "update") + + _inference_step_for_zero_shot( + otx_visual_prompting_model, + otx_visual_prompting_model.metric, + fxt_zero_shot_vpm_data_entity[1], + ) + + for v in mocker_updates.values(): + v.assert_called_once() + + +def test_inference_step_for_zero_shot_with_more_target( + mocker, + otx_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, +) -> None: + """Test _inference_step_for_zero_shot with more target.""" + otx_visual_prompting_model.configure_metric() + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + mocker_updates = {} + for k, v in otx_visual_prompting_model.metric.items(): + mocker_updates[k] = mocker.patch.object(v, "update") + target = {} + for k, v in fxt_zero_shot_vpm_data_entity[1].__dict__.items(): + if k in ["batch_size"]: + target[k] = v + else: + target[k] = v * 2 + + _inference_step_for_zero_shot( + otx_visual_prompting_model, + otx_visual_prompting_model.metric, + ZeroShotVisualPromptingBatchDataEntity(**target), + ) + + for v in mocker_updates.values(): + v.assert_called_once() + + +class TestOTXVisualPromptingModel: def test_exporter(self, otx_visual_prompting_model) -> None: """Test _exporter.""" assert isinstance(otx_visual_prompting_model._exporter, OTXVisualPromptingModelExporter) @@ -65,6 +172,96 @@ def test_optimization_config(self, otx_visual_prompting_model) -> None: } +class TestOTXZeroShotVisualPromptingModel: + def test_exporter(self, otx_zero_shot_visual_prompting_model) -> None: + """Test _exporter.""" + assert isinstance(otx_zero_shot_visual_prompting_model._exporter, OTXVisualPromptingModelExporter) + + def test_export_parameters(self, otx_zero_shot_visual_prompting_model) -> None: + """Test _export_parameters.""" + otx_zero_shot_visual_prompting_model.model.image_size = 1024 + + export_parameters = otx_zero_shot_visual_prompting_model._export_parameters + + assert export_parameters["input_size"] == (1, 3, 1024, 1024) + assert export_parameters["resize_mode"] == "fit_to_window" + assert export_parameters["mean"] == (123.675, 116.28, 103.53) + assert export_parameters["std"] == (58.395, 57.12, 57.375) + + def test_optimization_config(self, otx_zero_shot_visual_prompting_model) -> None: + """Test _optimization_config.""" + optimization_config = otx_zero_shot_visual_prompting_model._optimization_config + + assert optimization_config == { + "model_type": "transformer", + "advanced_parameters": { + "activations_range_estimator_params": { + "min": { + "statistics_type": "QUANTILE", + "aggregator_type": "MIN", + "quantile_outlier_prob": "1e-4", + }, + "max": { + "statistics_type": "QUANTILE", + "aggregator_type": "MAX", + "quantile_outlier_prob": "1e-4", + }, + }, + }, + } + + def test_on_test_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None: + """Test on_test_start.""" + otx_zero_shot_visual_prompting_model.load_latest_reference_info = Mock(return_value=False) + otx_zero_shot_visual_prompting_model.trainer = Mock() + mocker_run = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer.fit_loop, "run") + mocker_setup_data = mocker.patch.object( + otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, + "setup_data", + ) + mocker_reset = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, "reset") + + otx_zero_shot_visual_prompting_model.on_test_start() + + mocker_run.assert_called_once() + mocker_setup_data.assert_called_once() + mocker_reset.assert_called_once() + + def test_on_predict_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None: + """Test on_predict_start.""" + otx_zero_shot_visual_prompting_model.load_latest_reference_info = Mock(return_value=False) + otx_zero_shot_visual_prompting_model.trainer = Mock() + mocker_run = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer.fit_loop, "run") + mocker_setup_data = mocker.patch.object( + otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, + "setup_data", + ) + mocker_reset = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, "reset") + + otx_zero_shot_visual_prompting_model.on_predict_start() + + mocker_run.assert_called_once() + mocker_setup_data.assert_called_once() + mocker_reset.assert_called_once() + + def test_on_train_epoch_end(self, mocker, tmpdir, otx_zero_shot_visual_prompting_model) -> None: + """Test on_train_epoch_end.""" + otx_zero_shot_visual_prompting_model.save_outputs = True + otx_zero_shot_visual_prompting_model.root_reference_info = tmpdir + otx_zero_shot_visual_prompting_model.reference_feats = torch.tensor(1) + otx_zero_shot_visual_prompting_model.used_indices = torch.tensor(1) + mocker_mkdir = mocker.patch("otx.core.model.visual_prompting.Path.mkdir") + mocker.patch("otx.core.model.visual_prompting.Path.open") + mocker_torch_save = mocker.patch("otx.core.model.visual_prompting.torch.save") + mocker_pickle_dump = mocker.patch("otx.core.model.visual_prompting.pickle.dump") + + otx_zero_shot_visual_prompting_model.on_train_epoch_end() + + mocker_mkdir.assert_called_once() + mocker_torch_save.assert_called_once() + mocker_pickle_dump.assert_called_once() + + class TestOVVisualPromptingModel: @pytest.fixture() def set_ov_visual_prompting_model(self, mocker): @@ -172,6 +369,26 @@ def ov_zero_shot_visual_prompting_model(self, mocker) -> OVZeroShotVisualPrompti mocker.patch.object(OVZeroShotVisualPromptingModel, "initialize_reference_info") return OVZeroShotVisualPromptingModel(num_classes=0, model_name="exported_model_decoder.xml") + @pytest.mark.parametrize("training", [True, False]) + def test_forward( + self, + mocker, + ov_zero_shot_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, + training: bool, + ) -> None: + """Test forward.""" + ov_zero_shot_visual_prompting_model.training = training + ov_zero_shot_visual_prompting_model.reference_feats = "reference_feats" + ov_zero_shot_visual_prompting_model.used_indices = "used_indices" + mocker_fn = mocker.patch.object(ov_zero_shot_visual_prompting_model, "learn" if training else "infer") + mocker_customize_outputs = mocker.patch.object(ov_zero_shot_visual_prompting_model, "_customize_outputs") + + ov_zero_shot_visual_prompting_model.forward(fxt_zero_shot_vpm_data_entity[1]) + + mocker_fn.assert_called_once() + mocker_customize_outputs.assert_called_once() + def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test learn.""" ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) @@ -304,6 +521,52 @@ def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ for pm, _ in zip(predicted_mask, used_points[label]): assert pm.shape == (1024, 1024) + def test_customize_outputs_training( + self, + ov_zero_shot_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, + ) -> None: + ov_zero_shot_visual_prompting_model.training = True + + outputs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] + + result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1]) + + assert result == outputs + + def test_customize_outputs_inference( + self, + ov_zero_shot_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, + ) -> None: + ov_zero_shot_visual_prompting_model.training = False + + outputs = [ + ({1: [[1, 2, 3], [4, 5, 6]]}, {1: [[7, 8, 9], [10, 11, 12]]}), + ({2: [[13, 14, 15], [16, 17, 18]]}, {2: [[19, 20, 21], [22, 23, 24]]}), + ] + + result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1]) + + assert isinstance(result, ZeroShotVisualPromptingBatchPredEntity) + assert result.batch_size == len(outputs) + assert result.images == fxt_zero_shot_vpm_data_entity[1].images + assert result.imgs_info == fxt_zero_shot_vpm_data_entity[1].imgs_info + + assert isinstance(result.masks, list) + assert all(isinstance(mask, tv_tensors.Mask) for mask in result.masks) + + assert isinstance(result.prompts, list) + assert all(isinstance(prompt, Points) for prompt in result.prompts) + + assert isinstance(result.scores, list) + assert all(isinstance(score, torch.Tensor) for score in result.scores) + + assert isinstance(result.labels, list) + assert all(isinstance(label, torch.LongTensor) for label in result.labels) + + assert result.polygons == [] + def test_gather_prompts_with_labels(self, ov_zero_shot_visual_prompting_model) -> None: """Test _gather_prompts_with_labels.""" batch_prompts = [