diff --git a/cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg b/cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg
index 8de89a14..d9dde979 100644
--- a/cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg
+++ b/cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg
@@ -11,6 +11,8 @@ use_cuda = True
 multi_gpu = False
 save_gpumem = False
 n_epochs_tr = 24
+# The last n_mdls_store models will be stored. Leave empty to store only the final model.
+n_mdls_store = 5
 
 [dataset1]
 data_name = TIMIT_tr
@@ -235,4 +237,7 @@ skip_scoring = false
 scoring_script = local/score.sh
 scoring_opts = "--min-lmwt 1 --max-lmwt 10"
 norm_vars = False
-
+# Decode with the model from epoch ep_to_decode. Note that epoch indexing starts from 0,
+# so e.g. ep_to_decode=3 decodes with the model stored after the 4th epoch.
+# Leave empty to decode with the final model.
+ep_to_decode =
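The two options added above work together: n_mdls_store controls how many per-epoch checkpoints survive training, and ep_to_decode later picks one of them for decoding. A minimal sketch of how they are read with configparser, mirroring the empty-value fallback described in the comments above (the inline config string is made up for illustration):

import configparser

config = configparser.ConfigParser()
config.read_string("""
[exp]
n_epochs_tr = 24
n_mdls_store = 5

[decoding]
ep_to_decode =
""")

# Empty n_mdls_store -> keep only the final model.
n_mdls_store = int(config['exp']['n_mdls_store']) if config['exp']['n_mdls_store'] else 0

# Empty ep_to_decode -> decode with the last epoch (indexing starts from 0).
last_epoch = int(config['exp']['n_epochs_tr']) - 1
decode_epoch = config['decoding']['ep_to_decode'] or str(last_epoch)

print(n_mdls_store, decode_epoch)  # -> 5 23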
diff --git a/run_exp.py b/run_exp.py
index 83684f2b..f846713c 100644
--- a/run_exp.py
+++ b/run_exp.py
@@ -32,7 +32,6 @@
 config = configparser.ConfigParser()
 config.read(cfg_file)
-
 # Reading and parsing optional arguments from command line (e.g.,--optimization,lr=0.002)
 [section_args,field_args,value_args]=read_args_command_line(sys.argv,config)
 
@@ -87,17 +86,15 @@
 create_lists(config)
 
 # Writing the config files
-create_configs(config)
+create_configs(config)
 
 print("- Chunk creation......OK!\n")
 
 # create res_file
 res_file_path=out_folder+'/res.res'
-res_file = open(res_file_path, "w")
+res_file = open(res_file_path, "a")
 res_file.close()
-
-
 # Learning rates and architecture-specific optimization parameters
 arch_lst=get_all_archs(config)
 lr={}
@@ -144,7 +141,6 @@
 lab_dict=[]
 arch_dict=[]
-
 # --------TRAINING LOOP--------#
 for ep in range(N_ep):
 
@@ -157,7 +153,7 @@
 for tr_data in tr_data_lst:
 
     # Compute the total number of chunks for each training epoch
-    N_ck_tr=compute_n_chunks(out_folder,tr_data,ep,N_ep_str_format,'train')
+    N_ck_tr=compute_n_chunks(out_folder,tr_data,format(ep, N_ep_str_format),'train')
     N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_tr)),1))+'d'
 
     # ***Epoch training***
@@ -180,7 +176,6 @@
     # update learning rate in the cfg file (if needed)
     change_lr_cfg(config_chunk_file,lr,ep)
-
    # if this chunk has not already been processed, do training...
    if not(os.path.exists(info_file)):
@@ -210,14 +205,18 @@
    for pt_arch in pt_files.keys():
        pt_files[pt_arch]=out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'_'+pt_arch+'.pkl'
 
-   # remove previous pkl files
+   # remove previous pkl files but store the last n_mdls_store models
+   if config['exp']['n_mdls_store']:
+       n_mdls_store = int(config['exp']['n_mdls_store'])
+   else:
+       n_mdls_store = 0
+
    if len(model_files_past.keys())>0:
        for pt_arch in pt_files.keys():
-           if os.path.exists(model_files_past[pt_arch]):
+           if os.path.exists(model_files_past[pt_arch]) and (ep <= N_ep-n_mdls_store or ck != 0):
                os.remove(model_files_past[pt_arch])
-
-   # Training Loss and Error
+   # Training Loss and Error
    tr_info_lst=sorted(glob.glob(out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'*.info'))
    [tr_loss,tr_error,tr_time]=compute_avg_performance(tr_info_lst)
 
@@ -238,7 +237,7 @@
 for valid_data in valid_data_lst:
 
    # Compute the number of chunks for each validation dataset
-   N_ck_valid=compute_n_chunks(out_folder,valid_data,ep,N_ep_str_format,'valid')
+   N_ck_valid=compute_n_chunks(out_folder,valid_data,format(ep, N_ep_str_format),'valid')
    N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_valid)),1))+'d'
 
    for ck in range(N_ck_valid):
@@ -276,11 +275,10 @@
    valid_peformance_dict[valid_data]=[valid_loss,valid_error,valid_time]
    tot_time=tot_time+valid_time
-
-   # Print results in both res_file and stdout
-   dump_epoch_results(res_file_path, ep, tr_data_lst, tr_loss_tot, tr_error_tot, tot_time, valid_data_lst, valid_peformance_dict, lr, N_ep)
+   # Print results in both res_file and stdout; do not overwrite the res.res file when rerunning decoding
+   if not(os.path.exists(out_folder+'/exp_files/final_'+pt_arch+'.pkl')):
+       dump_epoch_results(res_file_path, ep, tr_data_lst, tr_loss_tot, tr_error_tot, tot_time, valid_data_lst, valid_peformance_dict, lr, N_ep)
-
    # Check for learning rate annealing
    if ep>0:
        # computing average validation error (on all the dataset specified)
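The guard added to the os.remove call above implements the retention policy: per-chunk checkpoints are still deleted as before, but the chunk-0 models of the most recent epochs are spared so that ep_to_decode can later point at one of them. An illustrative sketch of that boolean, not repository code:

# Illustrative sketch: evaluate the retention condition used above.
def can_remove_checkpoint(ep, ck, N_ep, n_mdls_store):
    # Removable unless it is the chunk-0 checkpoint of one of the last few epochs.
    return ep <= N_ep - n_mdls_store or ck != 0

N_ep, n_mdls_store = 24, 5
kept = [ep for ep in range(N_ep) if not can_remove_checkpoint(ep, 0, N_ep, n_mdls_store)]
print(kept)  # chunk-0 checkpoints that survive -> [20, 21, 22, 23]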
@@ -302,21 +300,26 @@
 # --------FORWARD--------#
 for forward_data in forward_data_lst:
 
+   if config['decoding']['ep_to_decode']:
+       decode_epoch = config['decoding']['ep_to_decode']
+   else:
+       decode_epoch = format(ep, N_ep_str_format)
+
    # Compute the number of chunks
-   N_ck_forward=compute_n_chunks(out_folder,forward_data,ep,N_ep_str_format,'forward')
+   N_ck_forward=compute_n_chunks(out_folder,forward_data,decode_epoch,'forward')
    N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_forward)),1))+'d'
 
    for ck in range(N_ck_forward):
+
        if not is_production:
-           print('Testing %s chunk = %i / %i' %(forward_data,ck+1, N_ck_forward))
+           print('Testing %s chunk = %i / %i with model stored in epoch %s' %(forward_data,ck+1, N_ck_forward, decode_epoch))
        else:
-           print('Forwarding %s chunk = %i / %i' %(forward_data,ck+1, N_ck_forward))
+           print('Forwarding %s chunk = %i / %i with model stored in epoch %s' %(forward_data,ck+1, N_ck_forward, decode_epoch))
 
        # output file
-       info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.info'
-       config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.cfg'
-
+       info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format) + '.info'
+       config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format) + '.cfg'
 
        # Do forward if the chunk was not already processed
        if not(os.path.exists(info_file)):
@@ -329,15 +332,13 @@
            # run chunk processing
            [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file)
-
            # update the first_processed variable
            processed_first=False
 
            if not(os.path.exists(info_file)):
                sys.stderr.write("ERROR: forward chunk %i of dataset %s not done! File %s does not exist.\nSee %s \n" % (ck,forward_data,info_file,log_file))
                sys.exit(0)
-
-
+
        # update the operation counter
        op_counter+=1
 
@@ -350,18 +351,15 @@
 forward_outs=config['forward']['forward_out'].split(',')
 forward_dec_outs=list(map(strtobool,config['forward']['require_decoding'].split(',')))
 
-
 for data in forward_data_lst:
    for k in range(len(forward_outs)):
        if forward_dec_outs[k]:
-
-           print('Decoding %s output %s' %(data,forward_outs[k]))
-
-           info_file=out_folder+'/exp_files/decoding_'+data+'_'+forward_outs[k]+'.info'
-
-
+
+           print('Decoding %s output %s for model stored in epoch %s' %(data,forward_outs[k],decode_epoch))
+
+           info_file=out_folder + '/exp_files/decoding_' + data + '_' + forward_outs[k] + '_e' + decode_epoch + '.info'
 
            # create decode config file
-           config_dec_file=out_folder+'/decoding_'+data+'_'+forward_outs[k]+'.conf'
+           config_dec_file=out_folder + '/decoding_' + data + '_' + forward_outs[k] + '_e' + decode_epoch + '.conf'
 
            config_dec = configparser.ConfigParser()
            config_dec.add_section('decoding')
 
@@ -402,14 +400,18 @@
            out_folder=os.path.abspath(out_folder)
            files_dec=out_folder+'/exp_files/forward_'+data+'_ep*_ck*_'+forward_outs[k]+'_to_decode.ark'
-           out_dec_folder=out_folder+'/decode_'+data+'_'+forward_outs[k]
+           out_dec_folder=out_folder+'/decode_' + data + '_' + forward_outs[k] + '_e' + decode_epoch
 
            if not(os.path.exists(info_file)):
 
                # Run the decoder
                cmd_decode=cmd+config['decoding']['decoding_script_folder'] +'/'+ config['decoding']['decoding_script']+ ' '+os.path.abspath(config_dec_file)+' '+ out_dec_folder + ' \"'+ files_dec + '\"'
                run_shell(cmd_decode,log_file)
-
+
+               # Create decoding info file
+               with open(info_file, 'a'):
+                   os.utime(info_file, None)
+
            # remove ark files if needed
            if not forward_save_files[k]:
                list_rem=glob.glob(files_dec)
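Both the forward and the decoding file names above now carry an _e<epoch> suffix, so artifacts produced with different ep_to_decode values do not collide. A small illustrative sketch of the naming (folder, dataset and output names are made up):

decode_epoch_cfg = ''                 # hypothetical [decoding] ep_to_decode value; empty -> current epoch
ep, N_ep_str_format = 23, '02d'
decode_epoch = decode_epoch_cfg if decode_epoch_cfg else format(ep, N_ep_str_format)

out_folder, data, forward_out = 'exp/TIMIT_MLP_fmllr', 'TIMIT_test', 'out_dnn1'
info_file = out_folder + '/exp_files/decoding_' + data + '_' + forward_out + '_e' + decode_epoch + '.info'
out_dec_folder = out_folder + '/decode_' + data + '_' + forward_out + '_e' + decode_epoch

print(info_file)        # exp/TIMIT_MLP_fmllr/exp_files/decoding_TIMIT_test_out_dnn1_e23.info
print(out_dec_folder)   # exp/TIMIT_MLP_fmllr/decode_TIMIT_test_out_dnn1_e23

The empty "with open(info_file, 'a')" block followed by os.utime in the hunk above is simply a portable touch: it creates the epoch-tagged .info marker so a rerun can skip decoding work that was already completed.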
diff --git a/utils.py b/utils.py
index 4729a386..07e3cb2d 100644
--- a/utils.py
+++ b/utils.py
@@ -21,7 +21,6 @@
 import math
 
-
 def run_command(cmd):
     """from http://blog.kagesenshi.org/2008/02/teeing-python-subprocesspopen-output.html
     """
@@ -674,7 +673,7 @@ def split_chunks(seq, size):
 
 def create_configs(config):
-    # This function create the chunk-specific config files
+    # This function creates the chunk-specific config files
     cfg_file_proto_chunk=config['cfg_proto']['cfg_proto_chunk']
     N_ep=int(config['exp']['N_epochs_tr'])
     N_ep_str_format='0'+str(max(math.ceil(np.log10(N_ep)),1))+'d'
@@ -692,30 +691,28 @@ def create_configs(config):
     # Read the batch size string
     batch_size_tr_str=config['batches']['batch_size_train']
-    batch_size_tr_arr=expand_str_ep(batch_size_tr_str,'int',N_ep,'|','*')
-
+    batch_size_tr_arr=expand_str_ep(batch_size_tr_str,'int',N_ep,'|','*')
+
     # Read the max_seq_length_train
     max_seq_length_tr_arr=expand_str_ep(max_seq_length_train,'int',N_ep,'|','*')
-
     cfg_file_proto=config['cfg_proto']['cfg_proto']
     [config,name_data,name_arch]=check_cfg(cfg_file,config,cfg_file_proto)
-
     arch_lst=get_all_archs(config)
     lr={}
     improvement_threshold={}
     halving_factor={}
     pt_files={}
     drop_rates={}
+
     for arch in arch_lst:
         lr_arr=expand_str_ep(config[arch]['arch_lr'],'float',N_ep,'|','*')
         lr[arch]=lr_arr
-
         improvement_threshold[arch]=float(config[arch]['arch_improvement_threshold'])
         halving_factor[arch]=float(config[arch]['arch_halving_factor'])
         pt_files[arch]=config[arch]['arch_pretrain_file']
-
+
         # Loop over all the sections and look for a "_drop" field (to perform dropout scheduling
         for (field_key, field_val) in config.items(arch):
             if "_drop" in field_key:
@@ -724,18 +721,15 @@ def create_configs(config):
                 drop_rates[arch]=[]
                 for lay_id in range(N_lay):
                     drop_rates[arch].append(expand_str_ep(drop_lay[lay_id],'float',N_ep,'|','*'))
-
+
                 # Check dropout factors
                 for dropout_factor in drop_rates[arch][0]:
                     if float(dropout_factor)<0.0 or float(dropout_factor)>1.0:
                         sys.stderr.write('The dropout rate should be between 0 and 1. Got %s in %s.\n' %(dropout_factor,field_key))
                         sys.exit(0)
-
-
     if strtobool(config['batches']['increase_seq_length_train']):
         max_seq_length_train_curr=int(config['batches']['start_seq_len_train'])
-
 
     for ep in range(N_ep):
 
@@ -743,7 +737,7 @@ def create_configs(config):
         for tr_data in tr_data_lst:
 
             # Compute the total number of chunks for each training epoch
-            N_ck_tr=compute_n_chunks(out_folder,tr_data,ep,N_ep_str_format,'train')
+            N_ck_tr=compute_n_chunks(out_folder,tr_data,format(ep, N_ep_str_format),'train')
             N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_tr)),1))+'d'
 
             # ***Epoch training***
@@ -766,10 +760,10 @@ def create_configs(config):
                 config_chunk_file=out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.cfg'
                 lst_chunk_file.write(config_chunk_file+'\n')
-
+
                 if strtobool(config['batches']['increase_seq_length_train'])==False:
                     max_seq_length_train_curr=int(max_seq_length_tr_arr[ep])
-
+
                 # Write chunk-specific cfg file
                 write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst_file,info_file,'train',tr_data,lr,max_seq_length_train_curr,name_data,ep,ck,batch_size_tr_arr[ep],drop_rates)
 
@@ -780,7 +774,7 @@ def create_configs(config):
         for valid_data in valid_data_lst:
 
             # Compute the number of chunks for each validation dataset
-            N_ck_valid=compute_n_chunks(out_folder,valid_data,ep,N_ep_str_format,'valid')
+            N_ck_valid=compute_n_chunks(out_folder,valid_data,format(ep, N_ep_str_format),'valid')
             N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_valid)),1))+'d'
 
             for ck in range(N_ck_valid):
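Throughout create_configs (and in run_exp.py above), compute_n_chunks is now handed the already formatted epoch string instead of the (ep, N_ep_str_format) pair. For reference, this is what that zero-padded format string evaluates to (the epoch values below are just an example):

import math
import numpy as np

N_ep = 24
N_ep_str_format = '0' + str(max(math.ceil(np.log10(N_ep)), 1)) + 'd'
print(N_ep_str_format)              # -> 02d
print(format(3, N_ep_str_format))   # -> 03
print(format(23, N_ep_str_format))  # -> 23 (matches the _ep23 tag in the file names)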
@@ -800,36 +794,50 @@ def create_configs(config):
                     max_seq_length_train_curr=max_seq_length_train_curr*int(config['batches']['multply_factor_seq_len_train'])
                     if max_seq_length_train_curr>int(max_seq_length_tr_arr[ep]):
                         max_seq_length_train_curr=int(max_seq_length_tr_arr[ep])
-
         for forward_data in forward_data_lst:
 
+            if config['decoding']['ep_to_decode']:
+                decode_epoch = config['decoding']['ep_to_decode']
+            else:
+                decode_epoch = format(ep, N_ep_str_format)
+
             # Compute the number of chunks
-            N_ck_forward=compute_n_chunks(out_folder,forward_data,ep,N_ep_str_format,'forward')
+            N_ck_forward=compute_n_chunks(out_folder,forward_data,decode_epoch,'forward')
             N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_forward)),1))+'d'
 
             for ck in range(N_ck_forward):
-
+
                 # path of the list of features for this chunk
-                lst_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'_*.lst'
+                lst_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format)+'_*.lst'
 
                 # output file
-                info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.info'
-                config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.cfg'
+                info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format) +'.info'
+                config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format)+'.cfg'
                 lst_chunk_file.write(config_chunk_file+'\n')
+
+                # Select the model the data will be forwarded with, according to ep_to_decode
+                if config['decoding']['ep_to_decode']:
+                    if int(decode_epoch) >= ep - int(config['exp']['n_mdls_store']):
+                        for arch in pt_files.keys():
+                            model_files[arch] = re.sub(r"_ep\d+_", "_ep"+ decode_epoch + "_", model_files[arch])
+                    else:
+                        print("You didn't store enough models to decode with the model from epoch",
+                              config['decoding']['ep_to_decode'],
+                              ". Change 'ep_to_decode' to match a stored model, or store more models ('n_mdls_store').")
+                        sys.exit()
 
                 # Write chunk-specific cfg file
                 write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,model_files,lst_file,info_file,'forward',forward_data,lr,max_seq_length_train_curr,name_data,ep,ck,batch_size_tr_arr[ep],drop_rates)
 
     lst_chunk_file.close()
 
-
 def create_lists(config):
 
     # splitting data into chunks (see out_folder/additional_files)
     out_folder=config['exp']['out_folder']
     seed=int(config['exp']['seed'])
-    N_ep=int(config['exp']['N_epochs_tr'])
+    N_ep=int(config['exp']['N_epochs_tr'])
     N_ep_str_format='0'+str(max(math.ceil(np.log10(N_ep)),1))+'d'
 
     # Setting the random seed
@@ -960,15 +968,18 @@ def create_lists(config):
                 for snt in forward_chunks_fea[ck]:
                     #print(snt.split(',')[i])
                     forward_chunks_fea_split.append(snt.split(',')[i])
+
+                if config['decoding']['ep_to_decode']:
+                    decode_epoch = config['decoding']['ep_to_decode']
+                else:
+                    decode_epoch = format(ep, N_ep_str_format)
 
-                output_lst_file=out_folder+'/exp_files/forward_'+dataset+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck,N_ck_str_format)+'_'+fea_names[i]+'.lst'
+                output_lst_file=out_folder+'/exp_files/forward_'+dataset+'_ep'+ decode_epoch +'_ck'+format(ck,N_ck_str_format)+'_'+fea_names[i]+'.lst'
                 f=open(output_lst_file,'w')
                 forward_chunks_fea_wr=map(lambda x:x+'\n', forward_chunks_fea_split)
                 f.writelines(forward_chunks_fea_wr)
-                f.close()
+                f.close()
 
-
-
 def write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst_file,info_file,to_do,data_set_name,lr,max_seq_length_train_curr,name_data,ep,ck,batch_size,drop_rates):
 
     # writing the chunk-specific cfg file
@@ -984,7 +995,7 @@ def write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst
     # change seed for randomness
     config_chunk['exp']['seed']=str(int(config_chunk['exp']['seed'])+ep+ck)
-
+
     config_chunk['batches']['batch_size_train']=batch_size
 
     for arch in pt_files.keys():
@@ -993,17 +1004,16 @@ def write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst
     # writing the current learning rate
     for lr_arch in lr.keys():
         config_chunk[lr_arch]['arch_lr']=str(lr[lr_arch][ep])
-
+
         for (field_key, field_val) in config.items(lr_arch):
             if "_drop" in field_key:
                 N_lay=len(drop_rates[lr_arch])
                 drop_arr=[]
                 for lay in range(N_lay):
                     drop_arr.append(drop_rates[lr_arch][lay][ep])
-
+
                 config_chunk[lr_arch][field_key]=str(','.join(drop_arr))
-
     # Data_chunk section
     config_chunk.add_section('data_chunk')
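The re.sub call above is what actually redirects forwarding to an older checkpoint: it swaps the epoch tag inside the stored model path. A standalone illustration (the path is made up):

import re

model_file = 'exp/TIMIT_MLP_fmllr/exp_files/train_TIMIT_tr_ep23_ck00_architecture1.pkl'
decode_epoch = '20'
print(re.sub(r"_ep\d+_", "_ep" + decode_epoch + "_", model_file))
# -> exp/TIMIT_MLP_fmllr/exp_files/train_TIMIT_tr_ep20_ck00_architecture1.pkl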
@@ -1044,11 +1054,10 @@ def write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst
     # Write cfg_file_chunk
     with open(config_chunk_file, 'w') as configfile:
         config_chunk.write(configfile)
-
+
     # Check cfg_file_chunk
     [config_proto_chunk,name_data_ck,name_arch_ck]=check_consistency_with_proto(config_chunk_file,cfg_file_proto_chunk)
-
 
 def parse_fea_field(fea):
 
     # Adding the required fields into a list
@@ -1147,9 +1156,8 @@ def parse_lab_field(lab):
 
     return [lab_names,lab_folders,lab_opts]
 
-
-def compute_n_chunks(out_folder,data_list,ep,N_ep_str_format,step):
-    list_ck=sorted(glob.glob(out_folder+'/exp_files/'+step+'_'+data_list+'_ep'+format(ep, N_ep_str_format)+'*.lst'))
+def compute_n_chunks(out_folder,data_list,epoch,step):
+    list_ck=sorted(glob.glob(out_folder+'/exp_files/'+step+'_'+data_list+'_ep'+ epoch +'*.lst'))
     last_ck=list_ck[-1]
     N_ck=int(re.findall('_ck(.+)_', last_ck)[-1].split('_')[0])+1
     return N_ck
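compute_n_chunks, reworked above to take the pre-formatted epoch string directly, still derives the chunk count from the highest chunk index among the matching .lst files. A quick illustration of that extraction step (the file name is made up):

import re

last_ck = 'exp/TIMIT_MLP_fmllr/exp_files/forward_TIMIT_test_ep23_ck04_fmllr.lst'
N_ck = int(re.findall('_ck(.+)_', last_ck)[-1].split('_')[0]) + 1
print(N_ck)  # -> 5, i.e. chunks ck00 ... ck04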
"0.3f"),format(valid_peformance_dict[valid_data][1], "0.3f"))) - print('-----') for lr_arch in lr.keys(): res_file.write('lr_%s=%s ' %(lr_arch,lr[lr_arch][ep])) @@ -2097,43 +2096,42 @@ def change_lr_cfg(cfg_file,lr,ep): field='arch_lr' for lr_arch in lr.keys(): - config.set(lr_arch,field,str(lr[lr_arch][ep])) # Write cfg_file_chunk with open(cfg_file, 'w') as configfile: config.write(configfile) - + def shift(arr, num, fill_value=np.nan): if num >= 0: return np.concatenate((np.full(num, fill_value), arr[:-num])) else: return np.concatenate((arr[-num:], np.full(-num, fill_value))) - + def expand_str_ep(str_compact,type_inp,N_ep,split_elem,mult_elem): - + lst_out=[] - + str_compact_lst=str_compact.split(split_elem) - + for elem in str_compact_lst: elements=elem.split(mult_elem) - + if type_inp=='int': try: int(elements[0]) except ValueError: sys.stderr.write('The string "%s" must contain integers. Got %s.\n' %(str_compact,elements[0])) sys.exit(0) - + if type_inp=='float': try: float(elements[0]) except ValueError: sys.stderr.write('The string "%s" must contain floats. Got %s.\n' %(str_compact,elements[0])) sys.exit(0) - + if len(elements)==2: try: int(elements[1]) @@ -2141,19 +2139,17 @@ def expand_str_ep(str_compact,type_inp,N_ep,split_elem,mult_elem): except ValueError: sys.stderr.write('The string "%s" must contain integers. Got %s\n' %(str_compact,elements[1])) sys.exit(0) - + if len(elements)==1: lst_out.append(elements[0]) - + if len(str_compact_lst)==1 and len(elements)==1: lst_out.extend([elements[0] for i in range(N_ep-1)]) - - + + # Final check if len(lst_out)!=N_ep: sys.stderr.write('The total number of elements specified in the string "%s" is equal to %i not equal to the total number of epochs %s.\n' %(str_compact,len(lst_out),N_ep)) sys.exit(0) - + return lst_out - -