Skip to content

Commit fe9ed2a

Browse files
committed
update args
1 parent 7060904 commit fe9ed2a

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

src/deepsparse/v2/operators/operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def has_output_schema(cls) -> bool:
5656
def __call__(
5757
self,
5858
*args,
59-
context: Optional[Context] = None,
59+
context: Context,
6060
**kwargs,
6161
) -> Any:
6262
"""

src/deepsparse/v2/pipeline.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Any, Dict, List, Union
16+
from typing import Any, Dict, List, Optional, Union
1717

1818
from deepsparse.v2.operators import Operator
1919
from deepsparse.v2.routers import Router
@@ -51,13 +51,13 @@ def __init__(
5151
# SchedulerGroup handles running all schedulers in order of priority
5252
self._scheduler_group = SchedulerGroup(self.schedulers)
5353

54-
def run(self, op_input: Any, context: Context):
54+
def run(self, inp: Any, context: Optional[Context]):
5555
"""
5656
Run through the operators using the provided router and scheduler. Update the
5757
context to reflect each step of the router. The input to a given operator is the
5858
output of the previous operator.
5959
60-
:param op_input: input to the operator. expected to be of any type that is
60+
:param inp: input to the operator. expected to be of any type that is
6161
expected by the operator.
6262
:param context: context to store the current the inputs, outputs, and operator
6363
for each step of the router.
@@ -69,7 +69,7 @@ def run(self, op_input: Any, context: Context):
6969
operator = self.ops[next_step]
7070

7171
output_future = self._scheduler_group.submit(
72-
operator=operator, operator_input=op_input, context=context
72+
operator=operator, operator_input=inp, context=context
7373
)
7474

7575
# wait for future to resolve
@@ -78,12 +78,12 @@ def run(self, op_input: Any, context: Context):
7878
# update context
7979
context.update(
8080
operator=operator,
81-
input=op_input,
81+
input=inp,
8282
output=operator_output,
8383
)
8484

8585
next_step = self.router.next(next_step, self.ops)
86-
op_input = operator_output
86+
inp = operator_output
8787
return operator_output, context
8888

8989
def __call__(self, *args, return_context: bool = False, **kwargs):
@@ -104,8 +104,7 @@ def __call__(self, *args, return_context: bool = False, **kwargs):
104104
)
105105

106106
pipeline_input = kwargs or args[0]
107-
context = Context()
108-
pipeline_output, context = self.run(op_input=pipeline_input, context=context)
107+
pipeline_output, context = self.run(inp=pipeline_input, context=Context())
109108

110109
if return_context:
111110
return pipeline_output, context

0 commit comments

Comments
 (0)