@@ -799,33 +799,34 @@ def generate(self, prompt, params):
        answer = self.tokenizer.decode(sample[0][len(inputs[0]):], skip_special_tokens=True)
        return answer, self.info

-def interview_run(runtime, generate, interview, params_json, output_template, batch=False):
+def interview_run(runtime, generate, interview, params_json, output_template, batch=False, quiet=False):
    if batch:
-        print(f"Running batch of {len(interview)} prompts")
+        if not quiet: print(f"Running batch of {len(interview)} prompts")
        prompts = [q['prompt'] for q in interview]
        answers, model_info = generate(prompts, params=params_json)
-        print('Local model info:', model_info)
+        if not quiet: print('Local model info:', model_info)
    else:
        answers = []
        model_info = None
        for idx, question in enumerate(interview):
-            print(f"{idx+1}/{len(interview)} {question['name']} {question['language']}")
+            if not quiet: print(f"{idx+1}/{len(interview)} {question['name']} {question['language']}")

            # generate the answer
            result, info = generate(question['prompt'], params=params_json)

            # save for later
            if model_info is None:
                model_info = info
-                print('Local model info:', model_info)
+                if not quiet: print('Local model info:', model_info)

            # optional output template
            answer = output_template.render(**question, Answer=result) if output_template else result
            answers.append(answer)

-            print()
-            print(answer)
-            print()
+            if not quiet:
+                print()
+                print(answer)
+                print()

    results = []
    for idx, question in enumerate(interview):
@@ -877,11 +878,74 @@ def download_safetensors(model_name, revision=None):
            print('Download problem: ', e)
            continue
        break
+
+#################################
+##  Mistral-Inference Adapter  ##
+#################################
+def is_torchrun() -> bool:
+    required_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"]
+    return all(var in os.environ for var in required_vars)
+
+class InterviewMistral:
+    def __init__(self, model_name, model_info={}, gpu_split=None, token_healing=False, cache_8bit=False):
+        self.model_name = model_name
+        self.gpu_split = gpu_split
+        self.info = model_info
+
+        self.tokenizer = None
+        self.model = None
+        self.batch = False
+
+        self.info['model_name'] = self.model_name.split('/')[-1]

+    def load(self):
+        print("Starting load..")
+        config_path = self.model_name # hf_hub_download(repo_id=self.model_name, revision=self.info.get('revision',None), filename="params.json")
+
+        from mistral_inference.model import Transformer
+        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+        import torch
+
+        if is_torchrun():
+            torch.distributed.init_process_group()
+            torch.cuda.set_device(torch.distributed.get_rank())
+            num_pipeline_ranks = torch.distributed.get_world_size()
+        else:
+            num_pipeline_ranks = 1
+
+        print("Loading model...")
+        dtype = torch.bfloat16 if self.info.get('bf16', False) else torch.float16
+        self.tokenizer = MistralTokenizer.from_file(f"{config_path}/tokenizer.model.v3")
+        self.model = Transformer.from_folder(config_path, num_pipeline_ranks=num_pipeline_ranks, dtype=dtype)
+
+    def generate(self, prompt, params):
+        from mistral_inference.generate import generate
+        from mistral_common.protocol.instruct.messages import UserMessage
+        from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
+        eos_id = self.tokenizer.instruct_tokenizer.tokenizer.eos_id
+        if self.info.get('eos_token_id'):
+            eos_id = self.info.get('eos_token_id')
+            # print("overide stop_token:", eos_id)
+
+        self.info['sampling_params'] = {
+            'max_tokens': params.get('max_new_tokens'),
+            'temperature': params.get('temperature', 0.0)
+        }
+
+        completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
+        tokens = self.tokenizer.encode_chat_completion(completion_request).tokens
+        out_tokens, _ = generate([tokens], self.model, eos_id=eos_id, **self.info['sampling_params'])
+        result = self.tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
+
+        return result, self.info
+
+
def main(input: str, params: str, model_name: str, runtime: str, info: str = "{}", iterations: int = 1, quant: str = "", gpusplit: str = "", templateout: str = "", revision: str = "", stop: str = "", completion: bool = False):
    from prepare import save_interview

-    download_safetensors(model_name, revision if revision else None)
+    if runtime != 'mistral':
+        download_safetensors(model_name, revision if revision else None)

    gpu_split = gpusplit if gpusplit != '' else None
    model_info = json.loads(info) if isinstance(info, str) else info
@@ -894,8 +958,8 @@ def main(input: str, params: str, model_name: str, runtime: str, info: str = "{}
        ga['stop_seq'] += ["\n#","\n//","\n\n\n\n"]
    if stop != '':
        ga['stop_seq'] += stop
-    model_info['generate_args'] = ga
-
+    model_info['generate_args'] = ga
+
    if runtime == 'transformers':
        if quant:
            quant_id = None
@@ -924,6 +988,8 @@ def main(input: str, params: str, model_name: str, runtime: str, info: str = "{}
        model = InterviewHQQ(model_name, model_info, gpu_split=gpu_split)
    elif runtime == 'ctranslate2':
        model = InterviewCtranslate2(model_name, model_info, gpu_split=gpu_split)
+    elif runtime == 'mistral':
+        model = InterviewMistral(model_name, model_info)
    else:
        raise Exception('Unknown runtime '+runtime)

@@ -934,6 +1000,11 @@ def main(input: str, params: str, model_name: str, runtime: str, info: str = "{}
        for input_file in input.split(','):
            tasks.append((param_file, input_file))

+    should_save = True
+    if is_torchrun():
+        import torch
+        should_save = torch.distributed.get_rank() == 0
+
    for param_file, input_pairs in tasks:
        insplit = input_pairs.split(':')
        input_file = insplit[0]
@@ -944,10 +1015,12 @@ def main(input: str, params: str, model_name: str, runtime: str, info: str = "{}
        params_json = json.load(open(param_file,'r'))

        for iter in range(iterations):
-            print("Starting", model_name, "iter=", iter, "param_file=", param_file, "input_file=", input_file, "templateout_file=", templateout_file)
-            results, remote_info = interview_run(runtime, model.generate, interview, params_json, output_template, batch=model.batch)
-            save_interview(input_file, templateout_file if templateout_file else 'none', param_file, remote_info['model_name'], results)
-
+            if should_save:
+                print("Starting", model_name, "iter=", iter, "param_file=", param_file, "input_file=", input_file, "templateout_file=", templateout_file)
+            results, remote_info = interview_run(runtime, model.generate, interview, params_json, output_template, batch=model.batch, quiet=not should_save)
+            if should_save:
+                save_interview(input_file, templateout_file if templateout_file else 'none', param_file, remote_info['model_name'], results)
+
if __name__ == "__main__":
    import fire
    fire.Fire(main)
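
For reference, a minimal standalone sketch of the mistral-inference call sequence that the new InterviewMistral adapter wraps, using the same import paths as the diff above. The model directory and prompt are placeholders, and it assumes a single process (no torchrun) with a raw consolidated checkpoint containing params.json and tokenizer.model.v3:

import torch
from mistral_inference.model import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

MODEL_DIR = "/path/to/mistral-model"  # placeholder: folder holding params.json and tokenizer.model.v3

# Load tokenizer and model on a single pipeline rank with fp16 weights.
tokenizer = MistralTokenizer.from_file(f"{MODEL_DIR}/tokenizer.model.v3")
model = Transformer.from_folder(MODEL_DIR, num_pipeline_ranks=1, dtype=torch.float16)

# Encode a chat-style prompt and decode greedily, mirroring InterviewMistral.generate().
request = ChatCompletionRequest(messages=[UserMessage(content="Write a hello world program in Python.")])
tokens = tokenizer.encode_chat_completion(request).tokens
out_tokens, _ = generate([tokens], model,
                         eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
                         max_tokens=512, temperature=0.0)
print(tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0]))

When the script is launched under torchrun with more than one process, is_torchrun() instead initializes torch.distributed, loads the model with num_pipeline_ranks equal to the world size, and only rank 0 prints progress and saves results.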