Skip to content

Commit 4e5f2af

Browse files
dsikka and dbogunowicz
authored and committed
[Pipeline Refactor] Additional functionality, engine operator, linear router and image classification pipeline/operators/example (#1325)
* initial functionality and working example with image classification
* remove testing image
* update args
* initial functionality and working example with image classification
* remove testing image
* pr comments
* defines schemas for operators and test
* add image classification test, PR comments
* fix input/output handling in pipeline and operator base classes to be more generic; remove context
* add additional operator input message
* typo fix
1 parent aa18bac commit 4e5f2af

16 files changed

+709
-204
lines changed

src/deepsparse/v2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from .pipeline import *
1817
from .operators import *
18+
from .pipeline import *
1919
from .routers import *
2020
from .schedulers import *
2121
from .utils import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# flake8: noqa
16+
from .postprocess_operator import *
17+
from .preprocess_operator import *
18+
19+
20+
from .pipeline import * # isort:skip
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import warnings
17+
from typing import Dict, Optional, Tuple, Union
18+
19+
from deepsparse.v2.image_classification.postprocess_operator import (
20+
ImageClassificationPostProcess,
21+
)
22+
from deepsparse.v2.image_classification.preprocess_operator import (
23+
ImageClassificationPreProcess,
24+
)
25+
from deepsparse.v2.operators.engine_operator import EngineOperator
26+
from deepsparse.v2.pipeline import Pipeline
27+
from deepsparse.v2.routers.router import LinearRouter
28+
from deepsparse.v2.schedulers.scheduler import OperatorScheduler
29+
30+
31+
_LOGGER = logging.getLogger(__name__)
32+
33+
__all__ = ["ImageClassificationPipeline"]
34+
35+
36+
class ImageClassificationPipeline(Pipeline):
    """
    Image classification pipeline made of three operators that are executed in
    order through a LinearRouter: ImageClassificationPreProcess, EngineOperator
    and ImageClassificationPostProcess.

    :param model_path: path of the model; always forwarded to the EngineOperator
        as its `model_path` keyword argument
    :param engine_kwargs: optional keyword arguments for the EngineOperator. The
        dictionary passed in is never modified; a `model_path` entry that is
        missing or differs from `model_path` is overridden (with a warning)
    :param class_names: forwarded to ImageClassificationPostProcess
    :param image_size: forwarded to ImageClassificationPreProcess
    :param top_k: forwarded to ImageClassificationPostProcess
    """

    def __init__(
        self,
        model_path: str,
        engine_kwargs: Optional[Dict] = None,
        class_names: Union[None, str, Dict[str, str]] = None,
        image_size: Optional[Tuple[int]] = None,
        top_k: int = 1,
    ):
        # copy so that the caller's dictionary is left untouched
        engine_kwargs = dict(engine_kwargs) if engine_kwargs else {}
        if engine_kwargs and engine_kwargs.get("model_path") != model_path:
            warnings.warn(f"Updating engine_kwargs to include {model_path}")
        # the explicit model_path argument always wins; before this fix the
        # warning above was emitted but the value was never actually written
        engine_kwargs["model_path"] = model_path

        engine = EngineOperator(**engine_kwargs)
        preprocess = ImageClassificationPreProcess(
            model_path=engine.model_path, image_size=image_size
        )
        postprocess = ImageClassificationPostProcess(
            top_k=top_k, class_names=class_names
        )

        ops = [preprocess, engine, postprocess]
        router = LinearRouter(end_route=len(ops))
        scheduler = [OperatorScheduler()]
        super().__init__(ops=ops, router=router, schedulers=scheduler)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from typing import Dict, List, Union
17+
18+
import numpy
19+
from pydantic import BaseModel, Field
20+
21+
from deepsparse.v2.operators import Operator
22+
23+
24+
class ImageClassificationOutput(BaseModel):
    """
    Schema of the result produced for an image classification request:
    the predicted labels together with their scores
    """

    # each entry is a single label, or a list of labels for one prediction
    labels: List[
        Union[
            int,
            str,
            List[int],
            List[str],
        ]
    ] = Field(description="List of labels, one for each prediction")

    # each entry is a single score, or a list of scores for one prediction
    scores: List[
        Union[
            float,
            List[float],
        ]
    ] = Field(description="List of scores, one for each prediction")
35+
36+
37+
__all__ = ["ImageClassificationPostProcess"]
38+
39+
40+
class ImageClassificationPostProcess(Operator):
    """
    Image Classification post-processing Operator. This Operator is responsible for
    processing outputs from the engine and returning the classification results to
    the user, using the ImageClassificationOutput structure.

    :param top_k: number of highest scoring classes to report per prediction
    :param class_names: optional mapping from class index to class name, given
        either as a dictionary or as a path to a `.json` file holding the
        mapping. Any other value disables the mapping and raw indices are
        returned
    """

    input_schema = None
    output_schema = ImageClassificationOutput

    def __init__(
        self, top_k: int = 1, class_names: Union[None, str, Dict[str, str]] = None
    ):
        self.top_k = top_k
        if isinstance(class_names, str) and class_names.endswith(".json"):
            # context manager guarantees the file handle is released
            with open(class_names) as class_names_file:
                self._class_names = json.load(class_names_file)
        elif isinstance(class_names, dict):
            self._class_names = class_names
        else:
            self._class_names = None

    def run(self, inp: "EngineOperatorOutputs", **kwargs) -> Dict:  # noqa: F821
        """
        Select the `top_k` classes of every prediction in the first engine output

        :param inp: engine operator outputs; only `inp.engine_outputs[0]` is read
            and each of its items is treated as one array of class scores
        :return: dictionary with the keys `scores` and `labels`. When exactly one
            prediction was processed, both values are un-nested to that single
            prediction
        """
        labels, scores = [], []
        for prediction_batch in inp.engine_outputs[0]:
            # indices of the top_k largest scores, best first
            label = (-prediction_batch).argsort()[: self.top_k]
            score = prediction_batch[label]
            labels.append(label)
            scores.append(score.tolist())

        # `labels` is checked for emptiness because numpy.vectorize refuses
        # size 0 inputs and `labels[0]` below would raise an IndexError
        if self._class_names is not None and labels:
            labels = numpy.vectorize(self._lookup_class_name)(labels)
            labels = labels.tolist()

        if labels and isinstance(labels[0], numpy.ndarray):
            labels = [label.tolist() for label in labels]

        if len(labels) == 1:
            labels = labels[0]
            scores = scores[0]

        return {"scores": scores, "labels": labels}

    def _lookup_class_name(self, label):
        """
        Resolve a class index to its name. Mappings loaded from JSON always have
        string keys while the indices computed in `run` are integers, so the
        string form of the index is tried when the index itself is not a key.

        :param label: class index to resolve
        :return: the entry stored for the index in the class names mapping
        """
        try:
            return self._class_names[label]
        except KeyError:
            return self._class_names[str(label)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, List, Optional, Tuple
16+
17+
import numpy
18+
import onnx
19+
from PIL import Image
20+
from torchvision import transforms
21+
22+
from deepsparse.image_classification.constants import (
23+
IMAGENET_RGB_MEANS,
24+
IMAGENET_RGB_STDS,
25+
)
26+
from deepsparse.pipelines.computer_vision import ComputerVisionSchema
27+
from deepsparse.v2.operators import Operator
28+
29+
30+
class ImageClassificationInput(ComputerVisionSchema):
    """
    Schema of the inputs accepted for image classification; adds no fields of
    its own on top of ComputerVisionSchema
    """
34+
35+
36+
__all__ = ["ImageClassificationPreProcess"]
37+
38+
39+
class ImageClassificationPreProcess(Operator):
    """
    Image Classification pre-processing operator. This Operator is expected to process
    the user inputs and prepare them for the engine. Inputs to this Operator are
    expected to follow the ImageClassificationInput schema.

    :param model_path: path to the onnx model; only read when `image_size` is not
        given, to infer the size from the first graph input
    :param image_size: size the images are center-cropped to. Images are first
        resized to 256/224 of this size
    """

    input_schema = ImageClassificationInput
    output_schema = None

    def __init__(self, model_path: str, image_size: Optional[Tuple[int, ...]] = None):
        self.model_path = model_path
        self._image_size = image_size or self._infer_image_size()
        non_rand_resize_scale = 256.0 / 224.0  # standard used
        resize_size = tuple(
            round(non_rand_resize_scale * size) for size in self._image_size
        )
        self._pre_normalization_transforms = transforms.Compose(
            [
                transforms.Resize(resize_size),
                transforms.CenterCrop(self._image_size),
            ]
        )

    def run(self, inp: ImageClassificationInput, **kwargs) -> Dict:
        """
        Pre-Process the Inputs for DeepSparse Engine

        :param inp: input model; `inp.images` may be a numpy array holding the
            whole batch, a single string or an iterable of images accepted by
            `_preprocess_image`. The input model and its arrays are not modified
        :return: dictionary whose `engine_inputs` entry is a list holding the
            normalized float32 batch
        """
        images = inp.images
        if isinstance(images, numpy.ndarray):
            image_batch = images
        else:
            if isinstance(images, str):
                # a local list is used so the caller's input model stays as is
                images = [images]

            # build batch
            image_batch = numpy.stack(
                [self._preprocess_image(image) for image in images], axis=0
            )

        original_dtype = image_batch.dtype
        # numpy.array always copies here. numpy.ascontiguousarray would hand back
        # the caller's own array when it is already contiguous float32, and the
        # in-place normalization below would then overwrite the user's data
        image_batch = numpy.array(image_batch, dtype=numpy.float32, order="C")

        if original_dtype == numpy.uint8:
            image_batch /= 255
            # normalize entire batch
            image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1))
            image_batch /= numpy.asarray(IMAGENET_RGB_STDS).reshape((-1, 3, 1, 1))

        return {"engine_inputs": [image_batch]}

    def _preprocess_image(self, image) -> numpy.ndarray:
        """
        Resize and center-crop a single image and return it channel first

        :param image: a raw image given as a list, a string file path, a numpy
            array or a PIL image
        :return: the image as a numpy array with the channel dimension first. A
            raw list that converts to a float32 array is returned as that array
            without any further processing
        :raises ValueError: when the image is of none of the supported types
        """
        if isinstance(image, list):
            # image given as raw list
            image = numpy.asarray(image)
            if image.dtype == numpy.float32:
                # image is already processed, append and continue
                return image
            # assume raw image input
            image = self._array_to_pil(image)
        elif isinstance(image, str):
            # load image from string filepath; convert() reads the pixel data,
            # so the underlying file can be closed right away
            with Image.open(image) as image_file:
                image = image_file.convert("RGB")
        elif isinstance(image, numpy.ndarray):
            image = self._array_to_pil(image)

        if not isinstance(image, Image.Image):
            raise ValueError(
                f"inputs to {self.__class__.__name__} must be a string image "
                "file path(s), a list representing a raw image, "
                "PIL.Image.Image object(s), or a numpy array representing "
                f"the entire pre-processed batch. Found {type(image)}"
            )

        # apply resize and center crop
        image = self._pre_normalization_transforms(image)
        image_numpy = numpy.array(image)
        image.close()

        # make channel first dimension
        return image_numpy.transpose(2, 0, 1)

    @staticmethod
    def _array_to_pil(image: numpy.ndarray) -> "Image.Image":
        """
        Put a raw image array in PIL format for torchvision processing

        :param image: raw image array; it is cast to uint8
        :return: PIL image built from the array
        """
        image = image.astype(numpy.uint8)
        if image.shape[0] < image.shape[-1]:
            # first dimension smaller than the last one: treated as channel
            # first, put channel last
            image = numpy.einsum("cwh->whc", image)
        return Image.fromarray(image)

    def _infer_image_size(self) -> Tuple[int, ...]:
        """
        Infer and return the expected shape of the input tensor

        :return: The expected shape of the input tensor from onnx graph
        """
        onnx_model = onnx.load(self.model_path)
        input_dims = onnx_model.graph.input[0].type.tensor_type.shape.dim
        # NOTE(review): presumably the input is laid out as
        # (batch, channels, height, width) and dims 2 and 3 are static; a
        # symbolic dimension would likely yield 0 here - confirm
        return (input_dims[2].dim_value, input_dims[3].dim_value)

0 commit comments

Comments
 (0)