from __future__ import absolute_import, division, print_function, unicode_literals import tensorflow as tf ''' tf.config.optimizer.set_jit(True) gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: # Restrict TensorFlow to only use the first GPU try: tf.config.experimental.set_visible_devices(gpus[0], 'GPU') except RuntimeError as e: # Visible devices must be set at program startup print(e) ''' from tensorflow.keras.optimizers import Adam from unet_config import* import os import datetime from Unet3D import Unet3D import numpy as np import random def dice_coe(y_true,y_pred, loss_type='jaccard', smooth=1.): y_true_f = tf.reshape(y_true,[-1]) y_pred_f = tf.reshape(y_pred,[-1]) intersection = tf.reduce_sum(y_true_f * y_pred_f) if loss_type == 'jaccard': union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f)) elif loss_type == 'sorensen': union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f) else: raise ValueError("Unknown `loss_type`: %s" % loss_type) return (2. * intersection + smooth) / (union + smooth) def dice_loss(y_true,y_pred, loss_type='jaccard', smooth=1.): y_true_f = tf.cast(tf.reshape(y_true,[-1]),tf.float32) y_pred_f =tf.cast(tf.reshape(y_pred,[-1]),tf.float32) intersection = tf.reduce_sum(y_true_f * y_pred_f) if loss_type == 'jaccard': union = tf.reduce_sum(tf.square(y_pred_f)) + tf.reduce_sum(tf.square(y_true_f)) elif loss_type == 'sorensen': union = tf.reduce_sum(y_pred_f) + tf.reduce_sum(y_true_f) else: raise ValueError("Unknown `loss_type`: %s" % loss_type) return (1-(2. * intersection + smooth) / (union + smooth)) @tf.function def decode_SEGct(Serialized_example): features={ 'image':tf.io.FixedLenFeature([],tf.string), 'mask':tf.io.FixedLenFeature([],tf.string), 'Height':tf.io.FixedLenFeature([],tf.int64), 'Weight':tf.io.FixedLenFeature([],tf.int64), 'Depth':tf.io.FixedLenFeature([],tf.int64), 'Sub_id':tf.io.FixedLenFeature([],tf.string) } examples=tf.io.parse_single_example(Serialized_example,features) ##Decode_image_float image_1 = tf.io.decode_raw(examples['image'], float) #Decode_mask_as_int32 mask_1 = tf.io.decode_raw(examples['mask'], tf.int32) ##Subject id is already in bytes format #sub_id=examples['Sub_id'] img_shape=[examples['Height'],examples['Weight'],examples['Depth']] #img_shape2=[img_shape[0],img_shape[1],img_shape[2]] print(img_shape) #Resgapping_the_data img=tf.reshape(image_1,img_shape) mask=tf.reshape(mask_1,img_shape) #Because CNN expect(batch,H,W,D,CHANNEL) img=tf.expand_dims(img, axis=-1) mask=tf.expand_dims(mask, axis=-1) ###casting_values img=tf.cast(img, tf.float32) mask=tf.cast(mask,tf.int32) return img,mask def getting_list(path): a=[file for file in os.listdir(path) if file.endswith('.tfrecords')] all_tfrecoeds=random.sample(a, len(a)) #all_tfrecoeds.sort(key=lambda f: int(filter(str.isdigit, f))) list_of_tfrecords=[] for i in range(len(all_tfrecoeds)): tf_path=path+all_tfrecoeds[i] list_of_tfrecords.append(tf_path) return list_of_tfrecords #--Traing Decoder def load_training_tfrecords(record_mask_file,batch_size): dataset=tf.data.Dataset.list_files(record_mask_file).interleave(lambda x: tf.data.TFRecordDataset(x),cycle_length=NUMBER_OF_PARALLEL_CALL,num_parallel_calls=NUMBER_OF_PARALLEL_CALL) dataset=dataset.map(decode_SEGct,num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size) batched_dataset=dataset.prefetch(PARSHING) return batched_dataset #--Validation Decoder def load_validation_tfrecords(record_mask_file,batch_size): dataset=tf.data.Dataset.list_files(record_mask_file).interleave(tf.data.TFRecordDataset,cycle_length=NUMBER_OF_PARALLEL_CALL,num_parallel_calls=NUMBER_OF_PARALLEL_CALL) dataset=dataset.map(decode_SEGct,num_parallel_calls=NUMBER_OF_PARALLEL_CALL).repeat(TRAING_EPOCH).batch(batch_size) batched_dataset=dataset.prefetch(PARSHING) return batched_dataset def Training(): #TensorBoard logdir = os.path.join("LungSEG_Log_March30_2020", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1) ##csv_logger csv_logger = tf.keras.callbacks.CSVLogger(TRAINING_CSV) ##Model-checkpoings path=TRAINING_SAVE_MODEL_PATH model_path=os.path.join(path, "LungSEGModel_{val_loss:.2f}_{epoch}.h5") Model_callback= tf.keras.callbacks.ModelCheckpoint(filepath=model_path,save_best_only=False,save_weights_only=True,monitor=ModelCheckpoint_MOTITOR,verbose=1) tf_train=getting_list(TRAINING_TF_RECORDS) tf_val=getting_list(VALIDATION_TF_RECORDS) traing_data=load_training_tfrecords(tf_train,BATCH_SIZE) Val_batched_dataset=load_validation_tfrecords(tf_val,BATCH_SIZE) if (NUM_OF_GPU==1): if RESUME_TRAINING==1: inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') Model_3D=Unet3D(inputs,num_classes=NUMBER_OF_CLASSES) Model_3D.load_weights(RESUME_TRAIING_MODEL) initial_epoch_of_training=TRAINING_INITIAL_EPOCH Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy',dice_coe]) Model_3D.summary() else: initial_epoch_of_training=0 inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') Model_3D=Unet3D(inputs,num_classes=NUMBER_OF_CLASSES) Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy',dice_coe]) Model_3D.summary() Model_3D.fit(traing_data, steps_per_epoch=TRAINING_STEP_PER_EPOCH, epochs=TRAING_EPOCH, initial_epoch=initial_epoch_of_training, validation_data=Val_batched_dataset, validation_steps=VALIDATION_STEP, callbacks=[tensorboard_callback,csv_logger,Model_callback]) ###Multigpu---- else: mirrored_strategy = tf.distribute.MirroredStrategy(DISTRIIBUTED_STRATEGY_GPUS) with mirrored_strategy.scope(): if RESUME_TRAINING==1: inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') Model_3D=Unet3D(inputs,num_classes=NUMBER_OF_CLASSES) Model_3D.load_weights(RESUME_TRAIING_MODEL) initial_epoch_of_training=TRAINING_INITIAL_EPOCH Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy',dice_coe]) Model_3D.summary() else: initial_epoch_of_training=0 inputs = tf.keras.Input(shape=INPUT_PATCH_SIZE, name='CT') Model_3D=Unet3D(inputs,num_classes=NUMBER_OF_CLASSES) Model_3D.compile(optimizer=OPTIMIZER, loss=[dice_loss], metrics=['accuracy',dice_coe]) Model_3D.summary() Model_3D.fit(traing_data,steps_per_epoch=TRAINING_STEP_PER_EPOCH,epochs=TRAING_EPOCH,initial_epoch=initial_epoch_of_training,validation_data=Val_batched_dataset,validation_steps=VALIDATION_STEP, callbacks=[tensorboard_callback,csv_logger,Model_callback]) if __name__ == '__main__': Training()