
Commit

[Enhancement]: Clarify the return value of get_rewrite_outputs (open-mmlab#129)

* Modify function and its call

* fix typo
SingleZombie authored Oct 15, 2021
1 parent 2bd6752 commit 26cf2bd
Showing 3 changed files with 60 additions and 39 deletions.
14 changes: 10 additions & 4 deletions mmdeploy/utils/test.py
@@ -174,6 +174,9 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
     Returns:
         Any: The outputs of model, decided by the backend wrapper.
+        bool: A flag indicate the type of outputs. If the flag is True, then
+            the outputs are backend output, otherwise they are outputs of wrapped
+            pytorch model.
     """
     onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
     pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
@@ -201,7 +204,7 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
         # convert to engine
         import mmdeploy.apis.tensorrt as trt_apis
         if not trt_apis.is_available():
-            return ctx_outputs
+            return ctx_outputs, False
         trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
         trt_apis.onnx2tensorrt(
             '',
@@ -214,7 +217,7 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
     elif backend == Backend.ONNXRUNTIME:
         import mmdeploy.apis.onnxruntime as ort_apis
         if not ort_apis.is_available():
-            return ctx_outputs
+            return ctx_outputs, False
         backend_model = ort_apis.ORTWrapper(onnx_file_path, 0, None)
         feature_list = []
         backend_feats = {}
@@ -239,8 +242,11 @@ def get_rewrite_outputs(wrapped_model: nn.Module, model_inputs: dict,
             else:
                 backend_feats[str(i)] = feature_list[i]
     elif backend == Backend.NCNN:
-        return ctx_outputs
+        return ctx_outputs, False
+    else:
+        raise NotImplementedError(
+            f'Unimplemented backend type: {backend.value}')
 
     with torch.no_grad():
         backend_outputs = backend_model.forward(backend_feats)
-    return backend_outputs
+    return backend_outputs, True
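
After this change get_rewrite_outputs always returns a two-element tuple: the outputs themselves and a bool flag that is True only when a backend actually produced them. When the requested backend is unavailable (or, as with ncnn here, not yet supported), the flag is False and the outputs of the wrapped PyTorch model are returned instead. A minimal caller-side sketch of the new contract, mirroring the updated tests below; the import path and the check_rewrite helper name are assumptions for illustration, not part of this commit:

import numpy as np

# Assumed import path, based on the file layout shown in this commit.
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs


def check_rewrite(func, model_inputs, deploy_cfg):
    """Hypothetical helper: compare a rewritten op against its PyTorch result."""
    original_outputs = func(**model_inputs)
    wrapped_func = WrapFunction(func)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_func, model_inputs=model_inputs, deploy_cfg=deploy_cfg)

    if is_backend_output:
        # A backend produced the outputs: compare them with the PyTorch reference.
        model_output = original_outputs.squeeze().cpu().numpy()
        rewrite_output = rewrite_outputs[0].squeeze()
        assert np.allclose(
            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
    else:
        # The backend was unavailable: we only know the PyTorch fallback ran.
        assert rewrite_outputs is not None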
30 changes: 19 additions & 11 deletions tests/test_mmdet/test_mmdet_core.py
@@ -57,7 +57,7 @@ def test_multiclass_nms_static():
         multiclass_nms,
         max_output_boxes_per_class=max_output_boxes_per_class,
         keep_top_k=keep_top_k)
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'boxes': boxes,
@@ -89,19 +89,23 @@ def delta2bbox(*args, **kwargs):
     deltas = torch.rand(1, 5, 4)
     original_outputs = delta2bbox(rois, deltas)
 
-    # wrap function to nn.Module, enable torch.onn.export
+    # wrap function to nn.Module, enable torch.onnx.export
     wrapped_func = WrapFunction(delta2bbox)
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
         wrapped_func,
         model_inputs={
             'rois': rois,
             'deltas': deltas
         },
         deploy_cfg=deploy_cfg)
 
-    model_output = original_outputs.squeeze().cpu().numpy()
-    rewrite_output = rewrite_outputs[0].squeeze()
-    assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+    if is_backend_output:
+        model_output = original_outputs.squeeze().cpu().numpy()
+        rewrite_output = rewrite_outputs[0].squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+    else:
+        assert rewrite_outputs is not None
 
 
 @pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
@@ -124,19 +128,23 @@ def tblr2bboxes(*args, **kwargs):
     tblr = torch.rand(1, 5, 4)
     original_outputs = tblr2bboxes(priors, tblr)
 
-    # wrap function to nn.Module, enable torch.onn.export
+    # wrap function to nn.Module, enable torch.onnx.export
     wrapped_func = WrapFunction(tblr2bboxes)
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
         wrapped_func,
         model_inputs={
             'priors': priors,
             'tblr': tblr
         },
         deploy_cfg=deploy_cfg)
 
-    model_output = original_outputs.squeeze().cpu().numpy()
-    rewrite_output = rewrite_outputs[0].squeeze()
-    assert np.allclose(model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+    if is_backend_output:
+        model_output = original_outputs.squeeze().cpu().numpy()
+        rewrite_output = rewrite_outputs[0].squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+    else:
+        assert rewrite_outputs is not None
 
 
 def test_distance2bbox():
55 changes: 31 additions & 24 deletions tests/test_mmdet/test_mmdet_models.py
@@ -134,21 +134,25 @@ def test_anchor_head_get_bboxes(backend_type):
         'cls_scores': cls_score,
         'bbox_preds': bboxes,
     }
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
 
-    for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
-        model_output = model_output.squeeze().cpu().numpy()
-        rewrite_output = rewrite_output.squeeze()
-        # hard code to make two tensors with the same shape
-        # rewrite and original codes applied different nms strategy
-        assert np.allclose(
-            model_output[:rewrite_output.shape[0]],
-            rewrite_output,
-            rtol=1e-03,
-            atol=1e-05)
+    if is_backend_output:
+        for model_output, rewrite_output in zip(model_outputs[0],
+                                                rewrite_outputs):
+            model_output = model_output.squeeze().cpu().numpy()
+            rewrite_output = rewrite_output.squeeze()
+            # hard code to make two tensors with the same shape
+            # rewrite and original codes applied different nms strategy
+            assert np.allclose(
+                model_output[:rewrite_output.shape[0]],
+                rewrite_output,
+                rtol=1e-03,
+                atol=1e-05)
+    else:
+        assert rewrite_outputs is not None
 
 
 @pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
@@ -215,21 +219,24 @@ def test_get_bboxes_of_fcos_head(backend_type):
         'bbox_preds': bboxes,
         'centernesses': centernesses
     }
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
 
-    for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
-        model_output = model_output.squeeze().cpu().numpy()
-        rewrite_output = rewrite_output.squeeze()
-        # hard code to make two tensors with the same shape
-        # rewrite and original codes applied different nms strategy
-        assert np.allclose(
-            model_output[:rewrite_output.shape[0]],
-            rewrite_output,
-            rtol=1e-03,
-            atol=1e-05)
+    if is_backend_output:
+        for model_output, rewrite_output in zip(model_outputs[0],
+                                                rewrite_outputs):
+            model_output = model_output.squeeze().cpu().numpy()
+            rewrite_output = rewrite_output.squeeze()
+            # hard code to make two tensors with the same shape
+            # rewrite and original codes applied different nms strategy
+            assert np.allclose(
+                model_output[:rewrite_output.shape[0]],
+                rewrite_output,
+                rtol=1e-03,
+                atol=1e-05)
+    else:
+        assert rewrite_outputs is not None
 
 
 def _replace_r50_with_r18(model):
@@ -273,7 +280,7 @@ def test_forward_of_base_detector_and_visualize(model_cfg_path):
 
     img = torch.randn(1, 3, 64, 64)
     rewrite_inputs = {'img': img}
-    rewrite_outputs = get_rewrite_outputs(
+    rewrite_outputs, _ = get_rewrite_outputs(
         wrapped_model=model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
