13
13
# limitations under the License.
14
14
15
15
16
- from typing import Any , Dict , List , Union
16
+ from typing import Any , Dict , List , Optional , Union
17
17
18
18
from deepsparse .v2 .operators import Operator
19
19
from deepsparse .v2 .routers import Router
@@ -51,13 +51,13 @@ def __init__(
51
51
# SchedulerGroup handles running all schedulers in order of priority
52
52
self ._scheduler_group = SchedulerGroup (self .schedulers )
53
53
54
- def run (self , op_input : Any , context : Context ):
54
+ def run (self , inp : Any , context : Optional [ Context ] ):
55
55
"""
56
56
Run through the operators using the provided router and scheduler. Update the
57
57
context to reflect each step of the router. The input to a given operator is the
58
58
output of the previous operator.
59
59
60
- :param op_input : input to the operator. expected to be of any type that is
60
+ :param inp : input to the operator. expected to be of any type that is
61
61
expected by the operator.
62
62
:param context: context to store the current the inputs, outputs, and operator
63
63
for each step of the router.
@@ -69,7 +69,7 @@ def run(self, op_input: Any, context: Context):
69
69
operator = self .ops [next_step ]
70
70
71
71
output_future = self ._scheduler_group .submit (
72
- operator = operator , operator_input = op_input , context = context
72
+ operator = operator , operator_input = inp , context = context
73
73
)
74
74
75
75
# wait for future to resolve
@@ -78,12 +78,12 @@ def run(self, op_input: Any, context: Context):
78
78
# update context
79
79
context .update (
80
80
operator = operator ,
81
- input = op_input ,
81
+ input = inp ,
82
82
output = operator_output ,
83
83
)
84
84
85
85
next_step = self .router .next (next_step , self .ops )
86
- op_input = operator_output
86
+ inp = operator_output
87
87
return operator_output , context
88
88
89
89
def __call__ (self , * args , return_context : bool = False , ** kwargs ):
@@ -104,8 +104,7 @@ def __call__(self, *args, return_context: bool = False, **kwargs):
104
104
)
105
105
106
106
pipeline_input = kwargs or args [0 ]
107
- context = Context ()
108
- pipeline_output , context = self .run (op_input = pipeline_input , context = context )
107
+ pipeline_output , context = self .run (inp = pipeline_input , context = Context ())
109
108
110
109
if return_context :
111
110
return pipeline_output , context
0 commit comments