diff --git a/language/llama2-70b/ci_file/qllama2_70b_forward_test.py b/language/llama2-70b/ci_file/qllama2_70b_forward_test.py
index 15d368c444..8581976cf1 100644
--- a/language/llama2-70b/ci_file/qllama2_70b_forward_test.py
+++ b/language/llama2-70b/ci_file/qllama2_70b_forward_test.py
@@ -7,6 +7,7 @@
 import furiosa_llm_models
 import joblib
 import model_compressor
+import model_compressor_impl
 import torch
 import yaml
 from torch.nn.functional import pad
@@ -114,7 +115,7 @@ def obtain_traced_model_dict(model):
        prefill_model,
        prefill_input_names,
        prefill_concrete_args,
-    ) = model_compressor.helper.llama_custom_symbolic_trace(
+    ) = model_compressor_impl.helper.llama_custom_symbolic_trace(
        model,
        input_names=["input_ids", "attention_mask", "position_ids"],
        disable_check=True,
@@ -123,7 +124,7 @@
        decode_model,
        decode_input_names,
        decode_concrete_args,
-    ) = model_compressor.helper.llama_custom_symbolic_trace(
+    ) = model_compressor_impl.helper.llama_custom_symbolic_trace(
        model,
        input_names=[
            "input_ids",
@@ -228,7 +229,7 @@ def get_generator_for_golden_model(
        "decode_model": quant_golden_models["decode"],
    }
 
-    return model_compressor.helper.QuantCausalLM(
+    return model_compressor_impl.helper.QuantCausalLM(
        quant_golden_models, golden_model_type, golden_input_names, golden_concrete_args
    )
 
@@ -315,7 +316,7 @@ def perform_generation(
    generation_output_dictionary = dict()
    with torch.no_grad():
        for idx, test_data in enumerate(test_data_list):
-            if type(generator) == model_compressor.helper.QuantCausalLM:
+            if type(generator) == model_compressor_impl.helper.QuantCausalLM:
                output = generator.generate(**test_data, **gen_kwargs)
            elif type(generator) == MLPerfSubmissionGreedySearch:
                # mlperf submission