Skip to content

Commit

Permalink
Fix torch compile to work with segnext (#4000)
Browse files Browse the repository at this point in the history
* Fix compile with segnext

* Fix compile with segnext2

* Add unit-tests
  • Loading branch information
harimkang authored Oct 8, 2024
1 parent db2eaa0 commit 7fa81d3
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/otx/algo/segmentation/backbones/mscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def __init__(

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""Forward function."""
b = x.shape[0]
outs = []

for i in range(self.num_stages):
Expand All @@ -424,7 +423,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
for blk in block:
x = blk(x, h, w)
x = norm(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
x = x.reshape(x.shape[0], h, w, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)

return outs
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/algo/segmentation/test_dino_v2_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#

import pytest
import torch
from otx.algo.segmentation.dino_v2_seg import DinoV2Seg
from otx.core.exporter.base import OTXModelExporter
from torch._dynamo.testing import CompileCounter


class TestDinoV2Seg:
Expand All @@ -26,3 +28,22 @@ def test_optimization_config(self, fxt_dino_v2_seg):
assert isinstance(config, dict)
assert "model_type" in config
assert config["model_type"] == "transformer"

@pytest.mark.parametrize(
"model",
[
DinoV2Seg(model_name="dinov2_vits14", label_info=3),
],
)
def test_compiled_model(self, model):
# Set Compile Counter
torch._dynamo.reset()
cnt = CompileCounter()

# Set model compile setting
model.model = torch.compile(model.model, backend=cnt)

# Prepare inputs
x = torch.randn(1, 3, 560, 560)
model.model(x)
assert cnt.frame_count == 1
14 changes: 14 additions & 0 deletions tests/unit/algo/segmentation/test_huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity
from torch._dynamo.testing import CompileCounter

SKIP_TRANSFORMERS_TEST = False
try:
Expand Down Expand Up @@ -92,3 +93,16 @@ def test_set_input_size(self, mock_pretrainedconfig, mock_automodel):
)

assert mock_automodel.from_pretrained.call_args.kwargs["image_size"] == input_size[-1]

def test_compiled_model(self, fxt_seg_model):
# Set Compile Counter
torch._dynamo.reset()
cnt = CompileCounter()

# Set model compile setting
fxt_seg_model.model = torch.compile(fxt_seg_model.model, backend=cnt)

# Prepare inputs
x = torch.randn(1, 3, *fxt_seg_model.input_size)
fxt_seg_model.model(x)
assert cnt.frame_count == 1
23 changes: 23 additions & 0 deletions tests/unit/algo/segmentation/test_segnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@


import pytest
import torch
from otx.algo.segmentation.segnext import SegNext
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from torch._dynamo.testing import CompileCounter


class TestSegNext:
Expand All @@ -30,3 +32,24 @@ def test_optimization_config(self, fxt_segnext):
assert isinstance(config["ignored_scope"]["patterns"], list)
assert "types" in config["ignored_scope"]
assert isinstance(config["ignored_scope"]["types"], list)

@pytest.mark.parametrize(
"model",
[
SegNext(model_name="segnext_tiny", label_info=3),
SegNext(model_name="segnext_small", label_info=3),
SegNext(model_name="segnext_base", label_info=3),
],
)
def test_compiled_model(self, model):
# Set Compile Counter
torch._dynamo.reset()
cnt = CompileCounter()

# Set model compile setting
model.model = torch.compile(model.model, backend=cnt)

# Prepare inputs
x = torch.randn(1, 3, *model.input_size)
model.model(x)
assert cnt.frame_count == 1

0 comments on commit 7fa81d3

Please # to comment.