@@ -20,8 +20,8 @@ def __init__(self, model):
         self.model = model
 
     def forward(self, pixel_values):
-        vit_embeds = self.model.extract_feature(pixel_values)
-        return vit_embeds
+        vision_embeds = self.model.extract_feature(pixel_values)
+        return vision_embeds
 
 
 class QEffInternDecoderWrapper(nn.Module):
@@ -31,21 +31,21 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vit_embeds, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values
 
 
 class QEffInternVLModel(nn.Module):
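Note: the decoder wrapper above splices the vision embeddings into the slots occupied by the image-context token. Instead of boolean-mask assignment, it builds gather indices with a cumulative sum, which keeps tensor shapes static for export, and the outer torch.where on input_ids.shape[1] == 1 lets the same graph skip the splice during single-token decode. A small self-contained sketch of the indexing trick, with toy sizes and an arbitrary stand-in for constants.INTERN_IMG_CONTEXT_TOKEN:

import torch

IMG_CONTEXT = 7                                   # toy stand-in for constants.INTERN_IMG_CONTEXT_TOKEN
input_ids = torch.tensor([[1, 7, 7, 2]])          # one sequence with two image-context slots
vision_embeds = torch.arange(6.0).view(1, 2, 3)   # two patch embeddings of width 3
input_embeds = torch.zeros(1, 4, 3)               # text embeddings (zeros, for illustration only)

selected = input_ids.reshape(-1) == IMG_CONTEXT                      # which flattened positions are image slots
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1       # running index into the image features
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)  # batch index for the gather
image_features_expanded = vision_embeds.reshape(-1, 3).unsqueeze(0)[indices0, indices1]
merged = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
# merged[0, 1] == vision_embeds[0, 0] and merged[0, 2] == vision_embeds[0, 1];
# every non-image position keeps its original text embedding.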
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
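For context, axis dictionaries of this shape are what torch.onnx.export consumes to keep batch/sequence/patch dimensions symbolic in the exported graph. A runnable toy sketch (the module and file name are placeholders, not this repo's export path):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):                     # placeholder module, not the real wrapper
    def forward(self, input_ids, vision_embeds):
        return vision_embeds.sum() + input_ids.float().sum()

dummy = (torch.zeros(1, 8, dtype=torch.int64), torch.zeros(13, 256, 32))
torch.onnx.export(
    TinyDecoder(),
    dummy,
    "tiny_decoder.onnx",
    input_names=["input_ids", "vision_embeds"],
    output_names=["out"],
    dynamic_axes={                                # mirrors the structure built above
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "vision_embeds": {0: "num_patches"},
    },
)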
@@ -139,15 +139,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["vit_embeds"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
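To make the renamed outputs concrete, this is what the kv_offload branch above produces for a hypothetical 2-layer decoder (the layer count is illustrative only):

vision_output_names = ["vision_embeds"]
lang_output_names = ["logits"]
for i in range(2):                                # pretend num_hidden_layers == 2
    for kv in ["key", "value"]:
        lang_output_names.append(f"past_{kv}.{i}_RetainedState")
lang_output_names.insert(1, "vision_embeds_RetainedState")
# vision_output_names -> ['vision_embeds']
# lang_output_names   -> ['logits', 'vision_embeds_RetainedState',
#                         'past_key.0_RetainedState', 'past_value.0_RetainedState',
#                         'past_key.1_RetainedState', 'past_value.1_RetainedState']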
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
-        inputs_shapes["vit_embeds"] = (
+        inputs_shapes["vision_embeds"] = (
             constants.INTERN_NUM_PATCHES,
             constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
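The dummy tensors above only need the right ranks and dtypes for export tracing. A toy illustration with placeholder values (the real INTERN_* / ONNX_EXPORT_* constants live in the constants module, and the position_ids chain is truncated in this hunk, so the trailing .repeat is an assumption):

import torch

BATCH, SEQ_LEN = 1, 32                            # stand-ins for the ONNX_EXPORT_EXAMPLE_* constants
NUM_PATCHES, FEATURE_SIZE, HIDDEN = 13, 256, 4096 # stand-ins for the INTERN_* / config values

dummy = {
    "input_ids": torch.zeros((BATCH, SEQ_LEN), dtype=torch.int64),
    "vision_embeds": torch.zeros((NUM_PATCHES, FEATURE_SIZE, HIDDEN), dtype=torch.float32),
    "position_ids": torch.arange(SEQ_LEN, dtype=torch.int64).view(1, SEQ_LEN).repeat(BATCH, 1),
}
# shapes: input_ids (1, 32), vision_embeds (13, 256, 4096), position_ids (1, 32)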
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("vit_embeds")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 
     def forward(self, input_ids, pixel_values, position_ids, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-        vit_embeds = self.extract_feature(pixel_values)
+        vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(