Skip to content

Commit

Permalink
Add missing tile recipes and various tile recipe changes (#3942)
Browse files Browse the repository at this point in the history
* add missing tile recipes

* Fix tiling XAI out of range (#3943)

- Fix tile merge XAI out of range

* update xai tile merge

* update rtdetr

* update tile recipes

* update rtdetr tile postprocess

* update rtdetr recipes and tile recipes

* update tile recipes

* fix rtdetr unittest

* update recipes

* refactor tile unit test

* address pr reviews

* remove unnecessary files

* update color channel

* fix image channel passing

* include tiling in cli integration test

* remove transform_bbox

---------

Co-authored-by: Vladislav Sovrasov <sovrasov.vlad@gmail.com>
  • Loading branch information
eugene123tw and sovrasov authored Sep 12, 2024
1 parent 8f96f27 commit 0f87c86
Show file tree
Hide file tree
Showing 24 changed files with 542 additions and 244 deletions.
19 changes: 12 additions & 7 deletions src/otx/algo/detection/base_models/detection_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,24 @@ def export(
if explain_mode:
msg = "Explain mode is not supported for DETR models yet."
raise NotImplementedError(msg)
return self.postprocess(self._forward_features(batch_inputs), deploy_mode=True)

return self.postprocess(
self._forward_features(batch_inputs),
[meta["img_shape"] for meta in batch_img_metas],
deploy_mode=True,
)

def postprocess(
self,
outputs: dict[str, Tensor],
original_size: tuple[int, int] | None = None,
original_sizes: list[tuple[int, int]],
deploy_mode: bool = False,
) -> dict[str, Tensor] | tuple[list[Tensor], list[Tensor], list[Tensor]]:
"""Post-processes the model outputs.
Args:
outputs (dict[str, Tensor]): The model outputs.
original_size (tuple[int, int], optional): The original size of the input images. Defaults to None.
original_sizes (list[tuple[int, int]]): The original image sizes.
deploy_mode (bool, optional): Whether to run in deploy mode. Defaults to False.
Returns:
Expand All @@ -120,9 +125,9 @@ def postprocess(

# convert bbox to xyxy and rescale back to original size (resize in OTX)
bbox_pred = box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
if not deploy_mode and original_size is not None:
original_size_tensor = torch.tensor(original_size).to(bbox_pred.device)
bbox_pred *= original_size_tensor.repeat(1, 2).unsqueeze(1)
if not deploy_mode:
original_size_tensor = torch.tensor(original_sizes).to(bbox_pred.device)
bbox_pred *= original_size_tensor.flip(1).repeat(1, 2).unsqueeze(1)

# perform scores computation and gather topk results
scores = nn.functional.sigmoid(logits)
Expand All @@ -136,7 +141,7 @@ def postprocess(

scores_list, boxes_list, labels_list = [], [], []

for sc, bb, ll in zip(scores, boxes, labels):
for sc, bb, ll, original_size in zip(scores, boxes, labels, original_sizes):
scores_list.append(sc)
boxes_list.append(
BoundingBoxes(bb, format="xyxy", canvas_size=original_size),
Expand Down
18 changes: 10 additions & 8 deletions src/otx/algo/detection/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@ def _customize_inputs(
# prepare bboxes for the model
for bb, ll in zip(entity.bboxes, entity.labels):
# convert to cxcywh if needed
converted_bboxes = (
box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
)
# normalize the bboxes
scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
converted_bboxes.device,
)
if len(scaled_bboxes := bb):
converted_bboxes = (
box_convert(bb, in_fmt="xyxy", out_fmt="cxcywh") if bb.format == BoundingBoxFormat.XYXY else bb
)
# normalize the bboxes
scaled_bboxes = converted_bboxes / torch.tensor(bb.canvas_size[::-1]).tile(2)[None].to(
converted_bboxes.device,
)
targets.append({"boxes": scaled_bboxes, "labels": ll})

return {
Expand Down Expand Up @@ -109,7 +110,8 @@ def _customize_outputs(
raise TypeError(msg)
return losses

scores, bboxes, labels = self.model.postprocess(outputs, [img_info.img_shape for img_info in inputs.imgs_info])
original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)

return DetBatchPredEntity(
batch_size=len(outputs),
Expand Down
7 changes: 5 additions & 2 deletions src/otx/algo/detection/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from otx.algo.common.losses import CrossEntropyLoss, L1Loss
from otx.algo.detection.backbones import CSPDarknet
Expand Down Expand Up @@ -76,13 +76,16 @@ def _exporter(self) -> OTXModelExporter:
raise ValueError(msg)

swap_rgb = not isinstance(self, YOLOXTINY) # only YOLOX-TINY uses RGB
resize_mode: Literal["standard", "fit_to_window_letterbox"] = "fit_to_window_letterbox"
if self.tile_config.enable_tiler:
resize_mode = "standard"

return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, *self.input_size),
mean=self.mean,
std=self.std,
resize_mode="fit_to_window_letterbox",
resize_mode=resize_mode,
pad_value=114,
swap_rgb=swap_rgb,
via_onnx=True,
Expand Down
3 changes: 3 additions & 0 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
dataset.mem_cache_handler,
dataset.mem_cache_img_max_size,
dataset.max_refetch,
dataset.image_color_channel,
dataset.stack_images,
dataset.to_tv_image,
)
self.tile_config = tile_config
self._dataset = dataset
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/data/entity/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]:
labels=[[] for _ in range(self.batch_size)],
),
)
return list(zip(batch_tile_attr_list, batch_data_entities))
return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

@classmethod
def collate_fn(cls, batch_entities: list[TileDetDataEntity]) -> TileBatchDetDataEntity:
Expand Down Expand Up @@ -218,7 +218,7 @@ def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]:
)
for i in range(0, len(tiles), self.batch_size)
]
return list(zip(batch_tile_attr_list, batch_data_entities))
return list(zip(batch_tile_attr_list, batch_data_entities, strict=True))

@classmethod
def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity:
Expand Down
16 changes: 9 additions & 7 deletions src/otx/core/utils/tile_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
img_infos (list[ImageInfo]): Original image information before tiling.
num_classes (int): Number of classes.
tile_config (TileConfig): Tile configuration.
explain_mode (bool): Whether or not tiles have explain features. Default: False.
explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
"""

def __init__(
Expand Down Expand Up @@ -119,8 +119,8 @@ def merge(
img_ids = []
explain_mode = self.explain_mode

for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
batch_size = tile_preds.batch_size
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
batch_size = len(tile_attrs)
saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
Expand All @@ -131,6 +131,7 @@ def merge(
tile_preds.scores,
saliency_maps,
feature_vectors,
strict=True,
):
offset_x, offset_y, _, _ = tile_attr["roi"]
tile_bboxes[:, 0::2] += offset_x
Expand All @@ -156,7 +157,7 @@ def merge(

return [
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
for img_id, image_info in zip(img_ids, self.img_infos)
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
]

def _merge_entities(
Expand Down Expand Up @@ -319,8 +320,8 @@ def merge(
img_ids = []
explain_mode = self.explain_mode

for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
tile_attrs,
tile_preds.imgs_info,
Expand All @@ -329,6 +330,7 @@ def merge(
tile_preds.scores,
tile_preds.masks,
feature_vectors,
strict=True,
):
keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
keep_indices = keep_indices.nonzero(as_tuple=True)[0]
Expand Down Expand Up @@ -363,7 +365,7 @@ def merge(

return [
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
for img_id, image_info in zip(img_ids, self.img_infos)
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
]

def _merge_entities(
Expand Down
82 changes: 82 additions & 0 deletions src/otx/recipe/_base_/data/detection_tile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
task: DETECTION
input_size:
- 800
- 800
mem_cache_size: 1GB
mem_cache_img_max_size: null
image_color_channel: RGB
stack_images: true
data_format: coco_instances
unannotated_items_ratio: 0.0
tile_config:
enable_tiler: true
enable_adaptive_tiling: true
train_subset:
subset_name: train
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

val_subset:
subset_name: val
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

test_subset:
subset_name: test
transform_lib_type: TORCHVISION
batch_size: 1
num_workers: 2
to_tv_image: false
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale: $(input_size)
keep_ratio: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler
6 changes: 1 addition & 5 deletions src/otx/recipe/detection/atss_mobilenetv2_tile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ engine:

callback_monitor: val/map_50

data: ../_base_/data/detection.yaml
data: ../_base_/data/detection_tile.yaml
overrides:
gradient_clip_val: 35.0
data:
tile_config:
enable_tiler: true
enable_adaptive_tiling: true

train_subset:
batch_size: 8
sampler:
Expand Down
51 changes: 51 additions & 0 deletions src/otx/recipe/detection/atss_resnext101_tile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
model:
class_path: otx.algo.detection.atss.ResNeXt101ATSS
init_args:
label_info: 80

optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.004
momentum: 0.9
weight_decay: 0.0001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 3
main_scheduler_callable:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: DETECTION
device: auto

callback_monitor: val/map_50

data: ../_base_/data/detection_tile.yaml
overrides:
gradient_clip_val: 35.0
callbacks:
- class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling
init_args:
max_interval: 5
decay: -0.025
min_lrschedule_patience: 3

data:
train_subset:
batch_size: 4
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 4

test_subset:
batch_size: 4
2 changes: 0 additions & 2 deletions src/otx/recipe/detection/rtdetr_101.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ overrides:
init_args:
scale: $(input_size)
keep_ratio: false
transform_bbox: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
Expand All @@ -103,7 +102,6 @@ overrides:
init_args:
scale: $(input_size)
keep_ratio: false
transform_bbox: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
Expand Down
Loading

0 comments on commit 0f87c86

Please # to comment.