AssertionError when converting a traced torchvision model: type_inference: axis=0, i=1: 256 != is452 #1795

Open
ivyas21 opened this issue Mar 7, 2023 · 1 comment
Labels: bug (type: Unexpected behaviour that should be corrected), triaged (status: Reviewed and examined, release has been assigned if applicable)

Comments


ivyas21 commented Mar 7, 2023

When converting a traced torchvision model after applying roi_align from #1509, ct.convert fails with AssertionError: type_inference: axis=0, i=1: 256 != is452

Stack Trace

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_31355/3386583322.py in <module>
      5     traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
      6 
----> 7 detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
      8 detector_mlmodel.save("segmenter.mlmodel")

/opt/conda/lib/python3.7/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
    454         package_dir=package_dir,
    455         debug=debug,
--> 456         specification_version=specification_version,
    457     )
    458 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    185         See `coremltools.converters.convert`
    186     """
--> 187     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    188 
    189 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    214                             convert_to,
    215                             registry,
--> 216                             **kwargs
    217                          )
    218 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    279     frontend_converter = frontend_converter_type()
    280 
--> 281     prog = frontend_converter(model, **kwargs)
    282 
    283     if convert_to.lower() != "neuralnetwork":

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    107         from .frontend.torch import load
    108 
--> 109         return load(*args, **kwargs)
    110 
    111 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
     55     inputs = _convert_to_torch_inputtype(inputs)
     56     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57     return _perform_torch_convert(converter, debug)
     58 
     59 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
     94 def _perform_torch_convert(converter, debug):
     95     try:
---> 96         prog = converter.convert()
     97     except RuntimeError as e:
     98         if debug and "convert function" in str(e):

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    279 
    280             # Add the rest of the operations
--> 281             convert_nodes(self.context, self.graph)
    282 
    283             graph_outputs = [self.context[name] for name in self.graph.outputs]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     87 
     88         context.prepare_for_conversion(node)
---> 89         add_op(context, node)
     90 
     91         # We've generated all the outputs the graph needs, terminate conversion.

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in scatter(context, node)
   5228         mode = 'update'
   5229 
-> 5230     _scatter(context, inputs, mode, node.name)
   5231 
   5232 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in _scatter(context, inputs, mode, name)
   5209     if types.is_scalar(updates.sym_type):
   5210         updates = mb.fill(shape=indices.shape, value=updates.val, name=name)
-> 5211     result = mb.scatter_along_axis(data=data, indices=indices, updates=updates,axis=axis, mode=mode, name=name)
   5212     context.add(result)
   5213 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
    174                     op_cls_to_add = op_reg[op_type]
    175 
--> 176                 return cls._add_op(op_cls_to_add, **kwargs)
    177 
    178             setattr(Builder, op_type, add_op)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    180         curr_block()._insert_op_before(new_op, before_op=before_op)
    181         new_op.build_nested_blocks()
--> 182         new_op.type_value_inference()
    183         if len(new_op.outputs) == 1:
    184             return new_op.outputs[0]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/operation.py in type_value_inference(self, overwrite_output)
    251         existing _output_vars
    252         """
--> 253         output_types = self.type_inference()
    254         if not isinstance(output_types, tuple):
    255             output_types = (output_types,)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py in type_inference(self)
    431         for i in range(self.data.rank):
    432             if i != axis:
--> 433                 assert self.data.shape[i] == self.indices.shape[i], f'type_inference: axis={axis}, i={i}: {self.data.shape[i]} != {self.indices.shape[i]}'
    434 
    435         return self.data.sym_type

AssertionError: type_inference: axis=0, i=1: 256 != is452
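For context, the failing frame is scatter_along_axis, which the torch frontend calls from its scatter handler (with mode='update' in the path shown). The assertion requires data and indices to have identical sizes in every dimension except axis. Below is a simplified NumPy reference for that update-mode behaviour, built on np.put_along_axis; it is not coremltools code, and the helper name scatter_along_axis_update is hypothetical. Note that the PyTorch documentation for Tensor.scatter_ only requires index.size(d) <= self.size(d) for d != dim, so the converter's equality check appears stricter than the source op.

```python
import numpy as np


def scatter_along_axis_update(data, indices, updates, axis):
    """Hypothetical reference for scatter_along_axis with mode="update".

    Mirrors the converter's check: every dimension except `axis` must match
    between data and indices. Returns a new array; `data` is not modified.
    """
    if data.ndim != indices.ndim or indices.shape != updates.shape:
        raise ValueError("data and indices must have the same rank, "
                         "and indices/updates the same shape")
    for i in range(data.ndim):
        if i != axis and data.shape[i] != indices.shape[i]:
            raise ValueError(f"axis={axis}, i={i}: {data.shape[i]} != {indices.shape[i]}")
    out = data.copy()
    # For axis=0 this writes out[indices[r, c], c] = updates[r, c].
    np.put_along_axis(out, indices, updates, axis=axis)
    return out


data = np.zeros((3, 4), dtype=np.float32)
indices = np.array([[2, 0, 1, 2]])  # shape (1, 4): target row for each column
updates = np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32)

print(scatter_along_axis_update(data, indices, updates, axis=0))
```

Passing indices of shape (1, 3) instead raises ValueError for i=1 (4 != 3), the concrete-number analogue of the 256 != is452 failure.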

Steps To Reproduce

import coremltools as ct
import torch, torchvision
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
import requests
from PIL import Image
import numpy as np
from typing import Dict, Tuple, Optional

# Image conversion tools:
class PILToTensor(torch.nn.Module):
    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.pil_to_tensor(image)
        return image, target

class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.convert_image_dtype(image, self.dtype)
        return image, target

# Load the torchvision model
detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

# Get a sample image
toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image))
example_image_pt = example_image_pt.unsqueeze(0)

# Run the sample through the model to demonstrate the model works
y = detector_model(example_image_pt)

# Make an adaptor to convert the model outputs to a tuple
class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""
    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']

adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

# Trace and convert the model using coremltools
model_to_trace = adapted_detector_model
with torch.inference_mode():
    out = model_to_trace(example_image_pt)
    traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
    
detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=example_image_pt.shape)])
detector_mlmodel.save("segmenter.mlmodel")
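The traceback shows the error is raised inside the converter's scatter handler. As a hedged debugging aid, the following sketch counts aten::scatter* nodes in a traced module's inlined graph using the TorchScript graph API (inlined_graph, nodes(), kind()). TinyScatter and count_node_kinds are hypothetical names for illustration only; the intended use is count_node_kinds(traced_model) on the traced detector from the script, although that alone does not say which torchvision layer emits the op.

```python
import torch


def count_node_kinds(traced_module, kind_prefix="aten::scatter"):
    """Count nodes whose kind starts with kind_prefix in a traced module.

    Uses the inlined graph so ops from submodules are included as well.
    """
    counts = {}
    for node in traced_module.inlined_graph.nodes():
        kind = node.kind()
        if kind.startswith(kind_prefix):
            counts[kind] = counts.get(kind, 0) + 1
    return counts


class TinyScatter(torch.nn.Module):
    """Hypothetical toy model containing a single aten::scatter op."""

    def forward(self, src, index):
        return torch.zeros(3, 4).scatter(0, index, src)


example_inputs = (torch.ones(1, 4), torch.zeros(1, 4, dtype=torch.long))
traced_toy = torch.jit.trace(TinyScatter(), example_inputs)
print(count_node_kinds(traced_toy))
```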

System environment:

  • coremltools version: 6.2
  • OS: Linux (Linux foohostname 4.19.0-23-cloud-amd64 #1 SMP Debian 4.19.269-1 (2022-12-20) x86_64 GNU/Linux)
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • Python: 3.7
    • PyTorch: 1.11.1+cu102
    • Other libraries installed as dependencies of coremltools:
Requirement already satisfied: coremltools==6.2 in /opt/conda/lib/python3.7/site-packages (6.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (4.64.1)
Requirement already satisfied: protobuf<=4.0.0,>=3.1.0 in /home/jupyter/.local/lib/python3.7/site-packages (from coremltools==6.2) (3.20.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (21.3)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.21.6)
Requirement already satisfied: sympy in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.10.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->coremltools==6.2) (3.0.9)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.7/site-packages (from sympy->coremltools==6.2) (1.2.1)
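To fill in the version fields above, a small helper along these lines can print everything in one go. It is a hypothetical sketch, not part of coremltools: it reads each package's __version__ attribute via importlib.import_module, which also works on the Python 3.7 interpreter used here (importlib.metadata only arrived in Python 3.8).

```python
import importlib
import platform


def collect_versions(packages=("coremltools", "torch", "torchvision", "numpy")):
    """Return a dict of interpreter, OS and package versions for a bug report."""
    info = {"python": platform.python_version(), "os": platform.platform()}
    for name in packages:
        try:
            module = importlib.import_module(name)
            # Most scientific packages expose __version__; fall back if absent.
            info[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            info[name] = "not installed"
    return info
```

Calling collect_versions() and printing each key/value pair gives a list ready to paste into the template.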

Please advise. Thank you!

ivyas21 added the bug label on Mar 7, 2023
junpeiz (Collaborator) commented Mar 8, 2023

Thank you for providing detailed steps to reproduce it! This looks like a bug in the type inference of mb.scatter_along_axis when the input has a symbolic shape.
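To illustrate why a symbolic dimension trips the check: coremltools appears to represent unknown dimensions as sympy symbols (sympy is one of its dependencies in the environment list), and a symbol such as is452 never compares equal to the integer 256. The sketch mirrors the strict comparison from the traceback and adds a hypothetical relaxed variant that only reports conflicts between two concrete sizes. Both helpers and the example shapes are illustrative assumptions; only 256 and is452 come from the error message, and the relaxed rule is not necessarily how coremltools would fix it.

```python
import sympy


def is_symbolic(dim):
    """True for a sympy dimension that is not a concrete number, e.g. Symbol("is452")."""
    return isinstance(dim, sympy.Basic) and not dim.is_number


def strict_mismatches(data_shape, indices_shape, axis):
    """Mirror of the assertion: every non-axis dimension must compare equal."""
    return [
        (i, d, k)
        for i, (d, k) in enumerate(zip(data_shape, indices_shape))
        if i != axis and d != k
    ]


def relaxed_mismatches(data_shape, indices_shape, axis):
    """Hypothetical relaxation: skip pairs where either size is symbolic."""
    return [
        (i, d, k)
        for i, d, k in strict_mismatches(data_shape, indices_shape, axis)
        if not (is_symbolic(d) or is_symbolic(k))
    ]


is452 = sympy.Symbol("is452")  # stand-in for the converter's symbolic dimension
data_shape = (4, 256)          # dim 1 (256) is from the error; dim 0 is made up
indices_shape = (4, is452)

print(strict_mismatches(data_shape, indices_shape, axis=0))
print(relaxed_mismatches(data_shape, indices_shape, axis=0))
```

Skipping the comparison when either size is symbolic trades a conversion-time error for a possible runtime shape mismatch, which is the usual trade-off with dynamic shapes.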

junpeiz added the triaged label on Mar 8, 2023
Projects: None yet
Development: No branches or pull requests
2 participants