-from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple

 import torch

@@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
     # is the dynamic batch dimension. Otherwise, we use the additional
     # inputs to determine the batch dimension.
     if additional_inputs is None:
+        batch_dims = None
+        if not isinstance(inputs, torch.Tensor) and len(inputs) > 1:
+            bs = inputs[0].size(0)
+            batch_dims = None
+            if not all(x.size(0) == bs for x in inputs):
+                batch_dims = InputTensorSpec.find_batch_size_dim(inputs)
         return InputTensorSpec.from_tensors_with_dynamic_batch_size(
             inputs,
             (
@@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
                 lower_setting.max_batch_size,
             ),
             lower_setting.opt_profile_replica,
+            batch_dims,
         )
     else:
         batch_dims = []
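
The two hunks above gate the new behavior in `generate_input_specs`: when `additional_inputs` is None, per-tensor batch dims are inferred only if the inputs disagree on `size(0)`; otherwise `batch_dims` stays None and the dim-0 default applies. A minimal sketch of that gate, where `needs_per_tensor_batch_dims` is a hypothetical helper name used for illustration only (not from this commit):

```python
import torch


def needs_per_tensor_batch_dims(inputs) -> bool:
    # Mirrors the new gate: True only when the inputs disagree on size(0).
    if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
        return False
    bs = inputs[0].size(0)
    return not all(x.size(0) == bs for x in inputs)


print(needs_per_tensor_batch_dims([torch.randn(8, 4), torch.randn(8, 2)]))  # False: dim-0 default
print(needs_per_tensor_batch_dims([torch.randn(8, 4), torch.randn(4, 8)]))  # True: infer per tensor
```
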
@@ -147,25 +154,69 @@ def from_tensors_with_dynamic_batch_size(
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
         if batch_dims is None:
-            batch_dims = [0] * len(tensors)
+            batch_dims = cls.find_batch_size_dim(tensors)

         input_specs = []
         batch_size = tensors[0].size(batch_dims[0])

         for i, tensor in enumerate(tensors):
             batch_dim = batch_dims[i]
-            assert batch_size == tensor.size(
-                batch_dim
-            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
-            shape = list(tensor.shape)
-            shape[batch_dim] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
-            input_specs.append(
-                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
-            )
+            if batch_dim == -1:
+                input_specs.append(cls.from_tensor(tensor))
+            else:
+                shape = list(tensor.shape)
+                assert batch_size == tensor.size(
+                    batch_dim
+                ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
+                shape[batch_dim] = -1
+                shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+                input_specs.append(
+                    cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
+                )

         return input_specs

+    @classmethod
+    # pyre-ignore [2]: Parameter `inputs` must have a type other than `Any`
+    def find_batch_size_dim(cls, inputs: Any) -> List[int]:
+        if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
+            return [0]
+        shapes = [i.shape for i in inputs]
+        frequency_map = {}
+        first_dims = set()
+        for shape in shapes:
+            if len(shape) < 2:
+                # Bypass rank-1 tensors; in MRS models they carry no batch_size info.
+                continue
+            first_dims.add(shape[0])
+            # Dedup dimension values within a single tensor's shape.
+            shape = set(shape)
+            for i in shape:
+                frequency_map[i] = frequency_map.get(i, 0) + 1
+
+        if len(first_dims) == 1:
+            # The first dim is the same in every input: use it as the batch size.
+            batch_size = first_dims.pop()
+        elif frequency_map:
+            # First dims differ: use the most frequent dim value as the batch size.
+            sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1])
+            batch_size = sorted_frequency[0][0]
+        else:
+            # No dims were collected (e.g. only rank-1 inputs): no batch size.
+            batch_size = -1
+
+        bs_dim = []
+        for i in inputs:
+            # Default batch dim is -1, indicating this tensor has no batch_size dim.
+            dim = -1
+            for index, val in enumerate(i.shape):
+                if val == batch_size:
+                    dim = index
+                    break
+            bs_dim.append(dim)
+
+        return bs_dim
+
     def to_random_tensor(self, id=1):
         shape = tuple(self.shape)
         if len(get_dynamic_dims(shape)):
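
The new heuristic can be exercised in isolation. Below, `infer_batch_dims` is a hypothetical standalone replica of `find_batch_size_dim`, written for illustration rather than taken from this commit (it accepts a sequence of tensors only, and breaks frequency ties with `max` instead of a sort):

```python
from typing import List, Sequence

import torch


def infer_batch_dims(tensors: Sequence[torch.Tensor]) -> List[int]:
    # Standalone replica of InputTensorSpec.find_batch_size_dim above.
    if len(tensors) <= 1:
        return [0]
    frequency_map: dict = {}
    first_dims = set()
    for t in tensors:
        if t.dim() < 2:
            # Rank-1 tensors carry no batch-size info; skip them.
            continue
        first_dims.add(t.shape[0])
        for d in set(t.shape):  # dedup dim values within one shape
            frequency_map[d] = frequency_map.get(d, 0) + 1
    if len(first_dims) == 1:
        batch_size = first_dims.pop()  # unanimous first dim
    elif frequency_map:
        batch_size = max(frequency_map, key=frequency_map.get)  # most frequent dim
    else:
        batch_size = -1  # only rank-1 inputs: no batch dim anywhere
    # Each tensor's batch dim is its first axis matching batch_size,
    # or -1, meaning the tensor is treated as fully static.
    return [
        next((i for i, d in enumerate(t.shape) if d == batch_size), -1)
        for t in tensors
    ]


# (8, 16) and (8, 32) share first dim 8; the rank-1 tensor gets -1, and
# from_tensors_with_dynamic_batch_size builds it via from_tensor() instead
# of a dynamic shape range.
print(infer_batch_dims([torch.randn(8, 16), torch.randn(8, 32), torch.randn(4)]))
# -> [0, 0, -1]
```

For a tensor of shape (8, 16) with batch dim 0 and a batch_size_range of (0, 32, 32), the `shape_ranges` built in the loop above expand to `((0, 16), (32, 16), (32, 16))`, repeated `opt_profile_replica` times. Tensors whose batch dim is -1 skip that construction entirely, so a mixed-rank input set degrades gracefully instead of failing the batch-size assert.
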