From 0d5438bacf70c807a1baea9fb43e758ddfa2da71 Mon Sep 17 00:00:00 2001 From: Guy Leroy Date: Sun, 26 Apr 2020 10:07:02 +0100 Subject: [PATCH 1/3] Add if __name__ == '__main__' in examples and fix typo in notebook --- .../densenet_evaluation_array.py | 107 +++---- .../densenet_evaluation_dict.py | 109 +++---- .../densenet_training_array.py | 229 +++++++------- .../densenet_training_dict.py | 229 +++++++------- .../densenet_evaluation_array.py | 117 ++++---- .../densenet_evaluation_dict.py | 121 ++++---- .../densenet_training_array.py | 257 ++++++++-------- .../densenet_training_dict.py | 275 ++++++++--------- .../unet_segmentation_3d_ignite.ipynb | 2 +- .../segmentation_3d/unet_evaluation_array.py | 97 +++--- .../segmentation_3d/unet_evaluation_dict.py | 111 +++---- .../segmentation_3d/unet_training_array.py | 243 +++++++-------- .../segmentation_3d/unet_training_dict.py | 241 +++++++-------- .../unet_evaluation_array.py | 147 ++++----- .../unet_evaluation_dict.py | 155 +++++----- .../unet_training_array.py | 273 ++++++++--------- .../unet_training_dict.py | 281 +++++++++--------- 17 files changed, 1505 insertions(+), 1489 deletions(-) diff --git a/examples/classification_3d/densenet_evaluation_array.py b/examples/classification_3d/densenet_evaluation_array.py index 091108d03a..af6547a785 100644 --- a/examples/classification_3d/densenet_evaluation_array.py +++ b/examples/classification_3d/densenet_evaluation_array.py @@ -19,61 +19,62 @@ from monai.data import NiftiDataset, CSVSaver from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 
+ ]) -# Define transforms for image -val_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - ToTensor() -]) + # Define transforms for image + val_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + ToTensor() + ]) -# Define nifti dataset -val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) -# create a validation data loader -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + # Define nifti dataset + val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) + # create a validation data loader + val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -# Create DenseNet121 -device = torch.device('cuda:0') -model = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -).to(device) + # Create DenseNet121 + device = torch.device('cuda:0') + model = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ).to(device) -model.load_state_dict(torch.load('best_metric_model.pth')) -model.eval() -with torch.no_grad(): - num_correct = 0. - metric_count = 0 - saver = CSVSaver(output_dir='./output') - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - val_outputs = model(val_images).argmax(dim=1) - value = torch.eq(val_outputs, val_labels) - metric_count += len(value) - num_correct += value.sum().item() - saver.save_batch(val_outputs, val_data[2]) - metric = num_correct / metric_count - print('evaluation metric:', metric) - saver.finalize() + model.load_state_dict(torch.load('best_metric_model.pth')) + model.eval() + with torch.no_grad(): + num_correct = 0. 
+ metric_count = 0 + saver = CSVSaver(output_dir='./output') + for val_data in val_loader: + val_images, val_labels = val_data[0].to(device), val_data[1].to(device) + val_outputs = model(val_images).argmax(dim=1) + value = torch.eq(val_outputs, val_labels) + metric_count += len(value) + num_correct += value.sum().item() + saver.save_batch(val_outputs, val_data[2]) + metric = num_correct / metric_count + print('evaluation metric:', metric) + saver.finalize() diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py index 93ecb535dc..ba6f4d9b76 100644 --- a/examples/classification_3d/densenet_evaluation_dict.py +++ b/examples/classification_3d/densenet_evaluation_dict.py @@ -19,62 +19,63 @@ from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, ToTensord from monai.data import CSVSaver -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) -val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) + val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] -# Define transforms for image -val_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - ToTensord(keys=['img']) -]) + # Define transforms for image + val_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + ToTensord(keys=['img']) + ]) -# create a validation data loader -val_ds = 
monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -# Create DenseNet121 -device = torch.device('cuda:0') -model = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -).to(device) + # Create DenseNet121 + device = torch.device('cuda:0') + model = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ).to(device) -model.load_state_dict(torch.load('best_metric_model.pth')) -model.eval() -with torch.no_grad(): - num_correct = 0. - metric_count = 0 - saver = CSVSaver(output_dir='./output') - for val_data in val_loader: - val_images, val_labels = val_data['img'].to(device), val_data['label'].to(device) - val_outputs = model(val_images).argmax(dim=1) - value = torch.eq(val_outputs, val_labels) - metric_count += len(value) - num_correct += value.sum().item() - saver.save_batch(val_outputs, {'filename_or_obj': val_data['img.filename_or_obj']}) - metric = num_correct / metric_count - print('evaluation metric:', metric) - saver.finalize() + model.load_state_dict(torch.load('best_metric_model.pth')) + model.eval() + with torch.no_grad(): + num_correct = 0. + metric_count = 0 + saver = CSVSaver(output_dir='./output') + for val_data in val_loader: + val_images, val_labels = val_data['img'].to(device), val_data['label'].to(device) + val_outputs = model(val_images).argmax(dim=1) + value = torch.eq(val_outputs, val_labels) + metric_count += len(value) + num_correct += value.sum().item() + saver.save_batch(val_outputs, {'filename_or_obj': val_data['img.filename_or_obj']}) + metric = num_correct / metric_count + print('evaluation metric:', metric) + saver.finalize() diff --git a/examples/classification_3d/densenet_training_array.py b/examples/classification_3d/densenet_training_array.py index dbaab72d01..e3c42a4288 100644 --- a/examples/classification_3d/densenet_training_array.py +++ b/examples/classification_3d/densenet_training_array.py @@ -20,125 +20,126 @@ from monai.data import NiftiDataset from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, RandRotate90, ToTensor -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - 
'/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) -# Define transforms -train_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - RandRotate90(), - ToTensor() -]) -val_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - ToTensor() -]) + # Define transforms + train_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + RandRotate90(), + ToTensor() + ]) + val_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + ToTensor() + ]) -# Define nifti dataset, data loader -check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) -check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) -im, label = monai.utils.misc.first(check_loader) -print(type(im), im.shape, label) + # Define nifti dataset, data loader + check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) + check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + im, label = monai.utils.misc.first(check_loader) + print(type(im), im.shape, label) -# create a training data loader -train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) -train_loader = 
DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + # create a training data loader + train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) + train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) -# create a validation data loader -val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) -# Create DenseNet121, CrossEntropyLoss and Adam optimizer -device = torch.device('cuda:0') -model = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -).to(device) -loss_function = torch.nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(model.parameters(), 1e-5) + # Create DenseNet121, CrossEntropyLoss and Adam optimizer + device = torch.device('cuda:0') + model = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ).to(device) + loss_function = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), 1e-5) -# start a typical PyTorch training -val_interval = 2 -best_metric = -1 -best_metric_epoch = -1 -epoch_loss_values = list() -metric_values = list() -writer = SummaryWriter() -for epoch in range(5): - print('-' * 10) - print('epoch {}/{}'.format(epoch + 1, 5)) - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data[0].to(device), batch_data[1].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) - writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + metric_values = list() + writer = SummaryWriter() + for epoch in range(5): + print('-' * 10) + print('epoch {}/{}'.format(epoch + 1, 5)) + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data[0].to(device), batch_data[1].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) + writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - num_correct = 0. 
- metric_count = 0 - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - val_outputs = model(val_images) - value = torch.eq(val_outputs.argmax(dim=1), val_labels) - metric_count += len(value) - num_correct += value.sum().item() - metric = num_correct / metric_count - metric_values.append(metric) - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), 'best_metric_model.pth') - print('saved new best metric model') - print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format( - epoch + 1, metric, best_metric, best_metric_epoch)) - writer.add_scalar('val_accuracy', metric, epoch + 1) -print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) -writer.close() + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + num_correct = 0. + metric_count = 0 + for val_data in val_loader: + val_images, val_labels = val_data[0].to(device), val_data[1].to(device) + val_outputs = model(val_images) + value = torch.eq(val_outputs.argmax(dim=1), val_labels) + metric_count += len(value) + num_correct += value.sum().item() + metric = num_correct / metric_count + metric_values.append(metric) + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), 'best_metric_model.pth') + print('saved new best metric model') + print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format( + epoch + 1, metric, best_metric, best_metric_epoch)) + writer.add_scalar('val_accuracy', metric, epoch + 1) + print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) + writer.close() diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py index 82c01e56a5..c69d8da23e 100644 --- a/examples/classification_3d/densenet_training_dict.py +++ b/examples/classification_3d/densenet_training_dict.py @@ -20,126 +20,127 @@ from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, RandRotate90d, ToTensord from monai.metrics import compute_roc_auc -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - 
'/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) -train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])] -val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])] + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) + train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])] + val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])] -# Define transforms for image -train_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]), - ToTensord(keys=['img']) -]) -val_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - ToTensord(keys=['img']) -]) + # Define transforms for image + train_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]), + ToTensord(keys=['img']) + ]) + val_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + ToTensord(keys=['img']) + ]) -# Define dataset, data loader -check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -check_loader = 
DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -check_data = monai.utils.misc.first(check_loader) -print(check_data['img'].shape, check_data['label']) + # Define dataset, data loader + check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + check_data = monai.utils.misc.first(check_loader) + print(check_data['img'].shape, check_data['label']) -# create a training data loader -train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a training data loader + train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) -# create a validation data loader -val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -# Create DenseNet121, CrossEntropyLoss and Adam optimizer -device = torch.device('cuda:0') -model = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -).to(device) -loss_function = torch.nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(model.parameters(), 1e-5) + # Create DenseNet121, CrossEntropyLoss and Adam optimizer + device = torch.device('cuda:0') + model = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ).to(device) + loss_function = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), 1e-5) -# start a typical PyTorch training -val_interval = 2 -best_metric = -1 -best_metric_epoch = -1 -writer = SummaryWriter() -for epoch in range(5): - print('-' * 10) - print('epoch {}/{}'.format(epoch + 1, 5)) - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data['img'].to(device), batch_data['label'].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) - writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) - epoch_loss /= step - print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + writer = SummaryWriter() + for epoch in range(5): + print('-' * 10) + print('epoch {}/{}'.format(epoch + 1, 5)) + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data['img'].to(device), batch_data['label'].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) + writer.add_scalar('train_loss', 
loss.item(), epoch_len * epoch + step) + epoch_loss /= step + print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - y_pred = torch.tensor([], dtype=torch.float32, device=device) - y = torch.tensor([], dtype=torch.long, device=device) - for val_data in val_loader: - val_images, val_labels = val_data['img'].to(device), val_data['label'].to(device) - y_pred = torch.cat([y_pred, model(val_images)], dim=0) - y = torch.cat([y, val_labels], dim=0) + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + y_pred = torch.tensor([], dtype=torch.float32, device=device) + y = torch.tensor([], dtype=torch.long, device=device) + for val_data in val_loader: + val_images, val_labels = val_data['img'].to(device), val_data['label'].to(device) + y_pred = torch.cat([y_pred, model(val_images)], dim=0) + y = torch.cat([y, val_labels], dim=0) - acc_value = torch.eq(y_pred.argmax(dim=1), y) - acc_metric = acc_value.sum().item() / len(acc_value) - auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, add_softmax=True) - if acc_metric > best_metric: - best_metric = acc_metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), 'best_metric_model.pth') - print('saved new best metric model') - print('current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}'.format( - epoch + 1, acc_metric, auc_metric, best_metric, best_metric_epoch)) - writer.add_scalar('val_accuracy', acc_metric, epoch + 1) -print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) -writer.close() + acc_value = torch.eq(y_pred.argmax(dim=1), y) + acc_metric = acc_value.sum().item() / len(acc_value) + auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, add_softmax=True) + if acc_metric > best_metric: + best_metric = acc_metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), 'best_metric_model.pth') + print('saved new best metric model') + print('current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}'.format( + epoch + 1, acc_metric, auc_metric, best_metric, best_metric_epoch)) + writer.add_scalar('val_accuracy', acc_metric, epoch + 1) + print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) + writer.close() diff --git a/examples/classification_3d_ignite/densenet_evaluation_array.py b/examples/classification_3d_ignite/densenet_evaluation_array.py index c693b25b29..f0c3bcbcd0 100644 --- a/examples/classification_3d_ignite/densenet_evaluation_array.py +++ b/examples/classification_3d_ignite/densenet_evaluation_array.py @@ -22,73 +22,74 @@ from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor from monai.handlers import StatsHandler, ClassificationSaver, CheckpointLoader -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - 
'/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) -# define transforms for image -val_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - ToTensor() -]) -# define nifti dataset -val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) -# create DenseNet121 -net = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -) -device = torch.device('cuda:0') + # define transforms for image + val_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + ToTensor() + ]) + # define nifti dataset + val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False) + # create DenseNet121 + net = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ) + device = torch.device('cuda:0') -metric_name = 'Accuracy' -# add evaluation metric to the evaluator engine -val_metrics = {metric_name: Accuracy()} + metric_name = 'Accuracy' + # add evaluation metric to the evaluator engine + val_metrics = {metric_name: Accuracy()} -def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch[0], batch[1]), device, non_blocking) + def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch[0], batch[1]), device, non_blocking) -# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, -# user can add output_transform to return other values -evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, + # user can add output_transform to return other values + evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) -# add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler( - name='evaluator', - output_transform=lambda x: None # no need to print loss value, so disable per iteration output -) -val_stats_handler.attach(evaluator) + # add stats event handler to print validation stats via evaluator + val_stats_handler = StatsHandler( + 
name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output + ) + val_stats_handler.attach(evaluator) -# for the array data format, assume the 3rd item of batch data is the meta_data -prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2], - output_transform=lambda output: output[0].argmax(1)) -prediction_saver.attach(evaluator) + # for the array data format, assume the 3rd item of batch data is the meta_data + prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2], + output_transform=lambda output: output[0].argmax(1)) + prediction_saver.attach(evaluator) -# the model was trained by "densenet_training_array" example -CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) + # the model was trained by "densenet_training_array" example + CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) -# create a validation data loader -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -state = evaluator.run(val_loader) + state = evaluator.run(val_loader) diff --git a/examples/classification_3d_ignite/densenet_evaluation_dict.py b/examples/classification_3d_ignite/densenet_evaluation_dict.py index 51dae84d57..7c8a0880cf 100644 --- a/examples/classification_3d_ignite/densenet_evaluation_dict.py +++ b/examples/classification_3d_ignite/densenet_evaluation_dict.py @@ -21,75 +21,76 @@ from monai.handlers import StatsHandler, CheckpointLoader, ClassificationSaver from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, ToTensord -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) -val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + 
'/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) + val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)] -# define transforms for image -val_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - ToTensord(keys=['img']) -]) + # define transforms for image + val_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + ToTensord(keys=['img']) + ]) -# create DenseNet121 -net = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -) -device = torch.device('cuda:0') + # create DenseNet121 + net = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ) + device = torch.device('cuda:0') -def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch['img'], batch['label']), device, non_blocking) + def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch['img'], batch['label']), device, non_blocking) -metric_name = 'Accuracy' -# add evaluation metric to the evaluator engine -val_metrics = {metric_name: Accuracy()} -# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, -# user can add output_transform to return other values -evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + metric_name = 'Accuracy' + # add evaluation metric to the evaluator engine + val_metrics = {metric_name: Accuracy()} + # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, + # user can add output_transform to return other values + evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) -# add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler( - name='evaluator', - output_transform=lambda x: None # no need to print loss value, so disable per iteration output -) -val_stats_handler.attach(evaluator) + # add stats event handler to print validation stats via evaluator + val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None # no need to print loss value, so disable per iteration output + ) + val_stats_handler.attach(evaluator) -# for the array data format, assume the 3rd item of batch data is the meta_data -prediction_saver = ClassificationSaver(output_dir='tempdir', name='evaluator', - batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']}, - output_transform=lambda output: output[0].argmax(1)) -prediction_saver.attach(evaluator) + # for the array data format, assume the 3rd item of batch data is the meta_data + prediction_saver = ClassificationSaver(output_dir='tempdir', name='evaluator', + batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj']}, + output_transform=lambda output: output[0].argmax(1)) + prediction_saver.attach(evaluator) -# the 
model was trained by "densenet_training_dict" example -CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) + # the model was trained by "densenet_training_dict" example + CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator) -# create a validation data loader -val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -state = evaluator.run(val_loader) + state = evaluator.run(val_loader) diff --git a/examples/classification_3d_ignite/densenet_training_array.py b/examples/classification_3d_ignite/densenet_training_array.py index f3c2bcb301..8abb3ab219 100644 --- a/examples/classification_3d_ignite/densenet_training_array.py +++ b/examples/classification_3d_ignite/densenet_training_array.py @@ -23,131 +23,132 @@ from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, RandRotate90, ToTensor from monai.handlers import StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) - -# define transforms -train_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - RandRotate90(), - ToTensor() -]) -val_transforms = Compose([ - ScaleIntensity(), - AddChannel(), - Resize((96, 96, 96)), - ToTensor() -]) - -# define nifti dataset, data loader -check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) -check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) -im, label = monai.utils.misc.first(check_loader) -print(type(im), im.shape, label) - 
-# create DenseNet121, CrossEntropyLoss and Adam optimizer -net = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -) -loss = torch.nn.CrossEntropyLoss() -lr = 1e-5 -opt = torch.optim.Adam(net.parameters(), lr) -device = torch.device('cuda:0') - -# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, -# user can add output_transform to return other values, like: y_pred, y, etc. -trainer = create_supervised_trainer(net, opt, loss, device, False) - -# adding checkpoint handler to save models (network params and optimizer stats) during training -checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) -trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, - handler=checkpoint_handler, - to_save={'net': net, 'opt': opt}) - -# StatsHandler prints loss at every iteration and print metrics at every epoch, -# we don't set metrics for trainer here, so just print loss, user can also customize print functions -# and can use output_transform to convert engine.state.output if it's not loss value -train_stats_handler = StatsHandler(name='trainer') -train_stats_handler.attach(trainer) - -# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler -train_tensorboard_stats_handler = TensorBoardStatsHandler() -train_tensorboard_stats_handler.attach(trainer) - -# set parameters for validation -validation_every_n_epochs = 1 - -metric_name = 'Accuracy' -# add evaluation metric to the evaluator engine -val_metrics = {metric_name: Accuracy()} -# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, -# user can add output_transform to return other values -evaluator = create_supervised_evaluator(net, val_metrics, device, True) - -# add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler( - name='evaluator', - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer -val_stats_handler.attach(evaluator) - -# add handler to record metrics to TensorBoard at every epoch -val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer -val_tensorboard_stats_handler.attach(evaluator) - -# add early stopping handler to evaluator -early_stopper = EarlyStopping(patience=4, - score_function=stopping_fn_from_metric(metric_name), - trainer=trainer) -evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) - -# create a validation data loader -val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) - - -@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) -def run_validation(engine): - evaluator.run(val_loader) - - -# create a training data loader -train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) -train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) - -train_epochs = 30 -state = trainer.run(train_loader, train_epochs) +if __name__ == '__main__': + monai.config.print_config() 
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) + + # define transforms + train_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + RandRotate90(), + ToTensor() + ]) + val_transforms = Compose([ + ScaleIntensity(), + AddChannel(), + Resize((96, 96, 96)), + ToTensor() + ]) + + # define nifti dataset, data loader + check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms) + check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + im, label = monai.utils.misc.first(check_loader) + print(type(im), im.shape, label) + + # create DenseNet121, CrossEntropyLoss and Adam optimizer + net = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ) + loss = torch.nn.CrossEntropyLoss() + lr = 1e-5 + opt = torch.optim.Adam(net.parameters(), lr) + device = torch.device('cuda:0') + + # ignite trainer expects batch=(img, label) and returns output=loss at every iteration, + # user can add output_transform to return other values, like: y_pred, y, etc. 
+ trainer = create_supervised_trainer(net, opt, loss, device, False) + + # adding checkpoint handler to save models (network params and optimizer stats) during training + checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) + trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, + handler=checkpoint_handler, + to_save={'net': net, 'opt': opt}) + + # StatsHandler prints loss at every iteration and print metrics at every epoch, + # we don't set metrics for trainer here, so just print loss, user can also customize print functions + # and can use output_transform to convert engine.state.output if it's not loss value + train_stats_handler = StatsHandler(name='trainer') + train_stats_handler.attach(trainer) + + # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler + train_tensorboard_stats_handler = TensorBoardStatsHandler() + train_tensorboard_stats_handler.attach(trainer) + + # set parameters for validation + validation_every_n_epochs = 1 + + metric_name = 'Accuracy' + # add evaluation metric to the evaluator engine + val_metrics = {metric_name: Accuracy()} + # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, + # user can add output_transform to return other values + evaluator = create_supervised_evaluator(net, val_metrics, device, True) + + # add stats event handler to print validation stats via evaluator + val_stats_handler = StatsHandler( + name='evaluator', + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer + val_stats_handler.attach(evaluator) + + # add handler to record metrics to TensorBoard at every epoch + val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer + val_tensorboard_stats_handler.attach(evaluator) + + # add early stopping handler to evaluator + early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) + evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + + # create a validation data loader + val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()) + + + @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) + def run_validation(engine): + evaluator.run(val_loader) + + + # create a training data loader + train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms) + train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + + train_epochs = 30 + state = trainer.run(train_loader, train_epochs) diff --git a/examples/classification_3d_ignite/densenet_training_dict.py b/examples/classification_3d_ignite/densenet_training_dict.py index cd5339bce7..6493cd250a 100644 --- a/examples/classification_3d_ignite/densenet_training_dict.py +++ b/examples/classification_3d_ignite/densenet_training_dict.py @@ -22,140 +22,141 @@ from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, RandRotate90d, ToTensord from monai.handlers import 
StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric, ROCAUC -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ -images = [ - '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', - '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' -] -# 2 binary labels for gender classification: man and woman -labels = np.array([ - 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 -]) -train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])] -val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])] - -# define transforms for image -train_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]), - ToTensord(keys=['img']) -]) -val_transforms = Compose([ - LoadNiftid(keys=['img']), - AddChanneld(keys=['img']), - ScaleIntensityd(keys=['img']), - Resized(keys=['img'], spatial_size=(96, 96, 96)), - ToTensord(keys=['img']) -]) - -# define dataset, data loader -check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) -check_data = monai.utils.misc.first(check_loader) -print(check_data['img'].shape, check_data['label']) - -# create DenseNet121, CrossEntropyLoss and Adam optimizer -net = monai.networks.nets.densenet.densenet121( - spatial_dims=3, - in_channels=1, - out_channels=2, -) -loss = torch.nn.CrossEntropyLoss() -lr = 1e-5 -opt = torch.optim.Adam(net.parameters(), lr) -device = torch.device('cuda:0') - - -# ignite trainer expects batch=(img, label) and returns output=loss at every iteration, -# user can add output_transform to return other values, like: y_pred, y, etc. 
-def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch['img'], batch['label']), device, non_blocking) - - -trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) - -# adding checkpoint handler to save models (network params and optimizer stats) during training -checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) -trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, - handler=checkpoint_handler, - to_save={'net': net, 'opt': opt}) - -# StatsHandler prints loss at every iteration and print metrics at every epoch, -# we don't set metrics for trainer here, so just print loss, user can also customize print functions -# and can use output_transform to convert engine.state.output if it's not loss value -train_stats_handler = StatsHandler(name='trainer') -train_stats_handler.attach(trainer) - -# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler -train_tensorboard_stats_handler = TensorBoardStatsHandler() -train_tensorboard_stats_handler.attach(trainer) - -# set parameters for validation -validation_every_n_epochs = 1 - -metric_name = 'Accuracy' -# add evaluation metric to the evaluator engine -val_metrics = {metric_name: Accuracy(), 'AUC': ROCAUC(to_onehot_y=True, add_softmax=True)} -# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, -# user can add output_transform to return other values -evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) - -# add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler( - name='evaluator', - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer -val_stats_handler.attach(evaluator) - -# add handler to record metrics to TensorBoard at every epoch -val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch) # fetch global epoch number from trainer -val_tensorboard_stats_handler.attach(evaluator) - -# add early stopping handler to evaluator -early_stopper = EarlyStopping(patience=4, - score_function=stopping_fn_from_metric(metric_name), - trainer=trainer) -evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) - -# create a validation data loader -val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) - - -@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) -def run_validation(engine): - evaluator.run(val_loader) - - -# create a training data loader -train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available()) - -train_epochs = 30 -state = trainer.run(train_loader, train_epochs) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/ + images = [ + '/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz', + 
'/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz', + '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz' + ] + # 2 binary labels for gender classification: man and woman + labels = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 + ]) + train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])] + val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])] + + # define transforms for image + train_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + RandRotate90d(keys=['img'], prob=0.8, spatial_axes=[0, 2]), + ToTensord(keys=['img']) + ]) + val_transforms = Compose([ + LoadNiftid(keys=['img']), + AddChanneld(keys=['img']), + ScaleIntensityd(keys=['img']), + Resized(keys=['img'], spatial_size=(96, 96, 96)), + ToTensord(keys=['img']) + ]) + + # define dataset, data loader + check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available()) + check_data = monai.utils.misc.first(check_loader) + print(check_data['img'].shape, check_data['label']) + + # create DenseNet121, CrossEntropyLoss and Adam optimizer + net = monai.networks.nets.densenet.densenet121( + spatial_dims=3, + in_channels=1, + out_channels=2, + ) + loss = torch.nn.CrossEntropyLoss() + lr = 1e-5 + opt = torch.optim.Adam(net.parameters(), lr) + device = torch.device('cuda:0') + + + # ignite trainer expects batch=(img, label) and returns output=loss at every iteration, + # user can add output_transform to return other values, like: y_pred, y, etc. 
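+    # the dict-based transforms above yield batches like {'img': ..., 'label': ...}; this small
+    # adapter unpacks them into the (x, y) tuple that ignite's default update loop expects,
+    # reusing ignite's _prepare_batch for device placement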
+    def prepare_batch(batch, device=None, non_blocking=False):
+        return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)
+
+    # adding checkpoint handler to save models (network params and optimizer stats) during training
+    checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                              handler=checkpoint_handler,
+                              to_save={'net': net, 'opt': opt})
+
+    # StatsHandler prints loss at every iteration and prints metrics at every epoch,
+    # we don't set metrics for the trainer here, so it just prints loss; users can also customize print functions
+    # and can use output_transform to convert engine.state.output if it's not a loss value
+    train_stats_handler = StatsHandler(name='trainer')
+    train_stats_handler.attach(trainer)
+
+    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
+    train_tensorboard_stats_handler = TensorBoardStatsHandler()
+    train_tensorboard_stats_handler.attach(trainer)
+
+    # set parameters for validation
+    validation_every_n_epochs = 1
+
+    metric_name = 'Accuracy'
+    # add evaluation metric to the evaluator engine
+    val_metrics = {metric_name: Accuracy(), 'AUC': ROCAUC(to_onehot_y=True, add_softmax=True)}
+    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
+    # user can add output_transform to return other values
+    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+    # add stats event handler to print validation stats via evaluator
+    val_stats_handler = StatsHandler(
+        name='evaluator',
+        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    val_stats_handler.attach(evaluator)
+
+    # add handler to record metrics to TensorBoard at every epoch
+    val_tensorboard_stats_handler = TensorBoardStatsHandler(
+        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    val_tensorboard_stats_handler.attach(evaluator)
+
+    # add early stopping handler to evaluator
+    early_stopper = EarlyStopping(patience=4,
+                                  score_function=stopping_fn_from_metric(metric_name),
+                                  trainer=trainer)
+    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+    # create a validation data loader
+    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+
+    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+    def run_validation(engine):
+        evaluator.run(val_loader)
+
+
+    # create a training data loader
+    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
+
+    train_epochs = 30
+    state = trainer.run(train_loader, train_epochs)
diff --git a/examples/notebooks/unet_segmentation_3d_ignite.ipynb b/examples/notebooks/unet_segmentation_3d_ignite.ipynb
index 6088a3cba4..30861be6a9 100644
--- a/examples/notebooks/unet_segmentation_3d_ignite.ipynb
+++ 
b/examples/notebooks/unet_segmentation_3d_ignite.ipynb @@ -200,7 +200,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Add Vadliation every N epochs" + "## Add Validation every N epochs" ] }, { diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py index fbcd03e382..3d0dd937eb 100644 --- a/examples/segmentation_3d/unet_evaluation_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -26,58 +26,59 @@ from monai.data import create_test_image_3d, sliding_window_inference, NiftiSaver, NiftiDataset from monai.metrics import compute_meandice -config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -tempdir = tempfile.mkdtemp() -print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + tempdir = tempfile.mkdtemp() + print('generating synthetic data to {} (this may take a while)'.format(tempdir)) + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) -images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) -segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) + images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) + segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -# define transforms for image and segmentation -imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) -segtrans = Compose([AddChannel(), ToTensor()]) -val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) -# sliding window inference for one image at every iteration -val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) + # define transforms for image and segmentation + imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) + segtrans = Compose([AddChannel(), ToTensor()]) + val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) + # sliding window inference for one image at every iteration + val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) -device = torch.device('cuda:0') -model = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, -).to(device) + device = torch.device('cuda:0') + model = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) -model.load_state_dict(torch.load('best_metric_model.pth')) -model.eval() -with torch.no_grad(): - metric_sum = 0. 
- metric_count = 0 - saver = NiftiSaver(output_dir='./output') - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - # define sliding window size and batch size for windows inference - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, - to_onehot_y=False, add_sigmoid=True) - metric_count += len(value) - metric_sum += value.sum().item() - val_outputs = (val_outputs.sigmoid() >= 0.5).float() - saver.save_batch(val_outputs, val_data[2]) - metric = metric_sum / metric_count - print('evaluation metric:', metric) -shutil.rmtree(tempdir) + model.load_state_dict(torch.load('best_metric_model.pth')) + model.eval() + with torch.no_grad(): + metric_sum = 0. + metric_count = 0 + saver = NiftiSaver(output_dir='./output') + for val_data in val_loader: + val_images, val_labels = val_data[0].to(device), val_data[1].to(device) + # define sliding window size and batch size for windows inference + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, + to_onehot_y=False, add_sigmoid=True) + metric_count += len(value) + metric_sum += value.sum().item() + val_outputs = (val_outputs.sigmoid() >= 0.5).float() + saver.save_batch(val_outputs, val_data[2]) + metric = metric_sum / metric_count + print('evaluation metric:', metric) + shutil.rmtree(tempdir) diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py index 308d22bc78..c756acbc2f 100644 --- a/examples/segmentation_3d/unet_evaluation_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -26,65 +26,66 @@ from monai.networks.nets import UNet from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord -monai.config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) +if __name__ == '__main__': + monai.config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) -tempdir = tempfile.mkdtemp() -print('generating synthetic data to {} (this may take a while)'.format(tempdir)) -for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + tempdir = tempfile.mkdtemp() + print('generating synthetic data to {} (this may take a while)'.format(tempdir)) + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) -images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) -segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)] + images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) + segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) + val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)] -# define transforms for image and segmentation -val_transforms = Compose([ - LoadNiftid(keys=['img', 'seg']), 
-    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
-    ScaleIntensityd(keys=['img', 'seg']),
-    ToTensord(keys=['img', 'seg'])
-])
-val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-# sliding window inference need to input 1 image in every iteration
-val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
-                        pin_memory=torch.cuda.is_available())
+    # define transforms for image and segmentation
+    val_transforms = Compose([
+        LoadNiftid(keys=['img', 'seg']),
+        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
+        ScaleIntensityd(keys=['img', 'seg']),
+        ToTensord(keys=['img', 'seg'])
+    ])
+    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+    # sliding window inference takes one image per iteration
+    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
+                            pin_memory=torch.cuda.is_available())
-device = torch.device('cuda:0')
-model = UNet(
-    dimensions=3,
-    in_channels=1,
-    out_channels=1,
-    channels=(16, 32, 64, 128, 256),
-    strides=(2, 2, 2, 2),
-    num_res_units=2,
-).to(device)
+    device = torch.device('cuda:0')
+    model = UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    ).to(device)
-model.load_state_dict(torch.load('best_metric_model.pth'))
-model.eval()
-with torch.no_grad():
-    metric_sum = 0.
-    metric_count = 0
-    saver = NiftiSaver(output_dir='./output')
-    for val_data in val_loader:
-        val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
-        # define sliding window size and batch size for windows inference
-        roi_size = (96, 96, 96)
-        sw_batch_size = 4
-        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
-        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
-                                 to_onehot_y=False, add_sigmoid=True)
-        metric_count += len(value)
-        metric_sum += value.sum().item()
-        val_outputs = (val_outputs.sigmoid() >= 0.5).float()
-        saver.save_batch(val_outputs, {'filename_or_obj': val_data['img.filename_or_obj'],
-                                       'affine': val_data['img.affine']})
-    metric = metric_sum / metric_count
-    print('evaluation metric:', metric)
-shutil.rmtree(tempdir)
+    model.load_state_dict(torch.load('best_metric_model.pth'))
+    model.eval()
+    with torch.no_grad():
+        metric_sum = 0.
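+        # compute_meandice returns one Dice score per image here (single class, batch_size=1),
+        # so a running sum and count give the mean over the whole validation set after the loop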
+        metric_count = 0
+        saver = NiftiSaver(output_dir='./output')
+        for val_data in val_loader:
+            val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
+            # define sliding window size and batch size for windows inference
+            roi_size = (96, 96, 96)
+            sw_batch_size = 4
+            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
+            value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
+                                     to_onehot_y=False, add_sigmoid=True)
+            metric_count += len(value)
+            metric_sum += value.sum().item()
+            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
+            saver.save_batch(val_outputs, {'filename_or_obj': val_data['img.filename_or_obj'],
+                                           'affine': val_data['img.affine']})
+        metric = metric_sum / metric_count
+        print('evaluation metric:', metric)
+    shutil.rmtree(tempdir)
diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py
index c20a0e03bb..a69dcc7901 100644
--- a/examples/segmentation_3d/unet_training_array.py
+++ b/examples/segmentation_3d/unet_training_array.py
@@ -27,134 +27,135 @@ from monai.metrics import compute_meandice
 from monai.visualize.img2tensorboard import plot_2d_or_3d_image
 
-monai.config.print_config()
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+if __name__ == '__main__':
+    monai.config.print_config()
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# create a temporary directory and 40 random image, mask paris
-tempdir = tempfile.mkdtemp()
-print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(40):
-    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
+    # create a temporary directory and 40 random image/mask pairs
+    tempdir = tempfile.mkdtemp()
+    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
+    for i in range(40):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
 
-    n = nib.Nifti1Image(im, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
 
-    n = nib.Nifti1Image(seg, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
 
-images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
-segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
 
-# define transforms for image and segmentation
-train_imtrans = Compose([
-    ScaleIntensity(),
-    AddChannel(),
-    RandSpatialCrop((96, 96, 96), random_size=False),
-    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
-    ToTensor()
-])
-train_segtrans = Compose([
-    AddChannel(),
-    RandSpatialCrop((96, 96, 96), random_size=False),
-    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
-    ToTensor()
-])
-val_imtrans = Compose([
-    ScaleIntensity(),
-    AddChannel(),
-    ToTensor()
-])
-val_segtrans = Compose([
-    AddChannel(),
-    ToTensor()
-])
+    # define transforms for image and segmentation
+    train_imtrans = Compose([
+        ScaleIntensity(),
+        AddChannel(),
+        RandSpatialCrop((96, 96, 96), random_size=False),
+        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
+        ToTensor()
+    ])
+    train_segtrans = Compose([
+        AddChannel(),
+        RandSpatialCrop((96, 96, 96), random_size=False),
+        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
+        ToTensor()
+    ])
+    val_imtrans = Compose([
+        ScaleIntensity(),
+        
AddChannel(), + ToTensor() + ]) + val_segtrans = Compose([ + AddChannel(), + ToTensor() + ]) -# define nifti dataset, data loader -check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) -check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) -im, seg = monai.utils.misc.first(check_loader) -print(im.shape, seg.shape) + # define nifti dataset, data loader + check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) + check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) + im, seg = monai.utils.misc.first(check_loader) + print(im.shape, seg.shape) -# create a training data loader -train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) -train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()) -# create a validation data loader -val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) -val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available()) + # create a training data loader + train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) + train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available()) -# create UNet, DiceLoss and Adam optimizer -device = torch.device('cuda:0') -model = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, -).to(device) -loss_function = monai.losses.DiceLoss(do_sigmoid=True) -optimizer = torch.optim.Adam(model.parameters(), 1e-3) + # create UNet, DiceLoss and Adam optimizer + device = torch.device('cuda:0') + model = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) + loss_function = monai.losses.DiceLoss(do_sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) -# start a typical PyTorch training -val_interval = 2 -best_metric = -1 -best_metric_epoch = -1 -epoch_loss_values = list() -metric_values = list() -writer = SummaryWriter() -for epoch in range(5): - print('-' * 10) - print('epoch {}/{}'.format(epoch + 1, 5)) - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data[0].to(device), batch_data[1].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) - writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + metric_values = list() + writer = SummaryWriter() + for epoch in 
range(5): + print('-' * 10) + print('epoch {}/{}'.format(epoch + 1, 5)) + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data[0].to(device), batch_data[1].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) + writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - metric_sum = 0. - metric_count = 0 - val_images = None - val_labels = None - val_outputs = None - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, - to_onehot_y=False, add_sigmoid=True) - metric_count += len(value) - metric_sum += value.sum().item() - metric = metric_sum / metric_count - metric_values.append(metric) - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), 'best_metric_model.pth') - print('saved new best metric model') - print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format( - epoch + 1, metric, best_metric, best_metric_epoch)) - writer.add_scalar('val_mean_dice', metric, epoch + 1) - # plot the last model output as GIF image in TensorBoard with the corresponding image and label - plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image') - plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label') - plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output') -shutil.rmtree(tempdir) -print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) -writer.close() + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + metric_sum = 0. 
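+                # accumulate Dice over the validation set; the resulting mean drives the
+                # best-model checkpointing below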
+                metric_count = 0
+                val_images = None
+                val_labels = None
+                val_outputs = None
+                for val_data in val_loader:
+                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
+                    roi_size = (96, 96, 96)
+                    sw_batch_size = 4
+                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
+                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
+                                             to_onehot_y=False, add_sigmoid=True)
+                    metric_count += len(value)
+                    metric_sum += value.sum().item()
+                metric = metric_sum / metric_count
+                metric_values.append(metric)
+                if metric > best_metric:
+                    best_metric = metric
+                    best_metric_epoch = epoch + 1
+                    torch.save(model.state_dict(), 'best_metric_model.pth')
+                    print('saved new best metric model')
+                print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
+                    epoch + 1, metric, best_metric, best_metric_epoch))
+                writer.add_scalar('val_mean_dice', metric, epoch + 1)
+                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
+                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
+                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
+                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
+    shutil.rmtree(tempdir)
+    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
+    writer.close()
diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py
index c0d0d6a37b..b3f34f747f 100644
--- a/examples/segmentation_3d/unet_training_dict.py
+++ b/examples/segmentation_3d/unet_training_dict.py
@@ -28,133 +28,134 @@ from monai.metrics import compute_meandice
 from monai.visualize import plot_2d_or_3d_image
 
-monai.config.print_config()
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+if __name__ == '__main__':
+    monai.config.print_config()
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-# create a temporary directory and 40 random image, mask paris
-tempdir = tempfile.mkdtemp()
-print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(40):
-    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+    # create a temporary directory and 40 random image/mask pairs
+    tempdir = tempfile.mkdtemp()
+    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
+    for i in range(40):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
 
-    n = nib.Nifti1Image(im, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))
 
-    n = nib.Nifti1Image(seg, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
 
-images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
-segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
-train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
-val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])]
+    images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
+    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+    train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
+    val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])]
 
-# define transforms for 
image and segmentation -train_transforms = Compose([ - LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), - ScaleIntensityd(keys=['img', 'seg']), - RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), - RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]), - ToTensord(keys=['img', 'seg']) -]) -val_transforms = Compose([ - LoadNiftid(keys=['img', 'seg']), - AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), - ScaleIntensityd(keys=['img', 'seg']), - ToTensord(keys=['img', 'seg']) -]) + # define transforms for image and segmentation + train_transforms = Compose([ + LoadNiftid(keys=['img', 'seg']), + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + ScaleIntensityd(keys=['img', 'seg']), + RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), + RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]), + ToTensord(keys=['img', 'seg']) + ]) + val_transforms = Compose([ + LoadNiftid(keys=['img', 'seg']), + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + ScaleIntensityd(keys=['img', 'seg']), + ToTensord(keys=['img', 'seg']) + ]) -# define dataset, data loader -check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training -check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available()) -check_data = monai.utils.misc.first(check_loader) -print(check_data['img'].shape, check_data['seg'].shape) + # define dataset, data loader + check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) + check_data = monai.utils.misc.first(check_loader) + print(check_data['img'].shape, check_data['seg'].shape) -# create a training data loader -train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) -# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training -train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, - collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) -# create a validation data loader -val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) -val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available()) + # create a training data loader + train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, + collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) -# create UNet, DiceLoss and Adam optimizer -device = torch.device('cuda:0') -model = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - 
channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, -).to(device) -loss_function = monai.losses.DiceLoss(do_sigmoid=True) -optimizer = torch.optim.Adam(model.parameters(), 1e-3) + # create UNet, DiceLoss and Adam optimizer + device = torch.device('cuda:0') + model = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) + loss_function = monai.losses.DiceLoss(do_sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) -# start a typical PyTorch training -val_interval = 2 -best_metric = -1 -best_metric_epoch = -1 -epoch_loss_values = list() -metric_values = list() -writer = SummaryWriter() -for epoch in range(5): - print('-' * 10) - print('epoch {}/{}'.format(epoch + 1, 5)) - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) - writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + metric_values = list() + writer = SummaryWriter() + for epoch in range(5): + print('-' * 10) + print('epoch {}/{}'.format(epoch + 1, 5)) + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item())) + writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss)) - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - metric_sum = 0. 
- metric_count = 0 - val_images = None - val_labels = None - val_outputs = None - for val_data in val_loader: - val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device) - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, - to_onehot_y=False, add_sigmoid=True) - metric_count += len(value) - metric_sum += value.sum().item() - metric = metric_sum / metric_count - metric_values.append(metric) - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), 'best_metric_model.pth') - print('saved new best metric model') - print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format( - epoch + 1, metric, best_metric, best_metric_epoch)) - writer.add_scalar('val_mean_dice', metric, epoch + 1) - # plot the last model output as GIF image in TensorBoard with the corresponding image and label - plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image') - plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label') - plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output') -shutil.rmtree(tempdir) -print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) -writer.close() + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + metric_sum = 0. + metric_count = 0 + val_images = None + val_labels = None + val_outputs = None + for val_data in val_loader: + val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device) + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, + to_onehot_y=False, add_sigmoid=True) + metric_count += len(value) + metric_sum += value.sum().item() + metric = metric_sum / metric_count + metric_values.append(metric) + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), 'best_metric_model.pth') + print('saved new best metric model') + print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format( + epoch + 1, metric, best_metric, best_metric_epoch)) + writer.add_scalar('val_mean_dice', metric, epoch + 1) + # plot the last model output as GIF image in TensorBoard with the corresponding image and label + plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image') + plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label') + plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output') + shutil.rmtree(tempdir) + print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch)) + writer.close() diff --git a/examples/segmentation_3d_ignite/unet_evaluation_array.py b/examples/segmentation_3d_ignite/unet_evaluation_array.py index d3e2c28ef4..df03a4e724 100644 --- a/examples/segmentation_3d_ignite/unet_evaluation_array.py +++ b/examples/segmentation_3d_ignite/unet_evaluation_array.py @@ -28,76 +28,77 @@ from monai.networks.nets import UNet from monai.networks.utils import predict_segmentation -config.print_config() -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - -tempdir = tempfile.mkdtemp() -print('generating synthetic data to {} (this may take a 
while)'.format(tempdir)) -for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) - -images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) -segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) - -# define transforms for image and segmentation -imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) -segtrans = Compose([AddChannel(), ToTensor()]) -ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) - -device = torch.device('cuda:0') -net = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, -) -net.to(device) - -# define sliding window size and batch size for windows inference -roi_size = (96, 96, 96) -sw_batch_size = 4 - - -def _sliding_window_processor(engine, batch): - net.eval() - with torch.no_grad(): - val_images, val_labels = batch[0].to(device), batch[1].to(device) - seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) - return seg_probs, val_labels - - -evaluator = Engine(_sliding_window_processor) - -# add evaluation metric to the evaluator engine -MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') - -# StatsHandler prints loss at every iteration and print metrics at every epoch, -# we don't need to print loss for evaluator, so just print metrics, user can also customize print functions -val_stats_handler = StatsHandler( - name='evaluator', - output_transform=lambda x: None # no need to print loss value, so disable per iteration output -) -val_stats_handler.attach(evaluator) - -# for the array data format, assume the 3rd item of batch data is the meta_data -file_saver = SegmentationSaver( - output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', - batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0])) -file_saver.attach(evaluator) - -# the model was trained by "unet_training_array" example -ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}) -ckpt_saver.attach(evaluator) - -# sliding window inference for one image at every iteration -loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) -state = evaluator.run(loader) -shutil.rmtree(tempdir) +if __name__ == '__main__': + config.print_config() + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + tempdir = tempfile.mkdtemp() + print('generating synthetic data to {} (this may take a while)'.format(tempdir)) + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) + + images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) + segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) + + # define transforms for image and segmentation + imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) + segtrans = Compose([AddChannel(), ToTensor()]) + ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) + + device = torch.device('cuda:0') + net = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + 
channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    )
+    net.to(device)
+
+    # define sliding window size and batch size for windows inference
+    roi_size = (96, 96, 96)
+    sw_batch_size = 4
+
+
+    def _sliding_window_processor(engine, batch):
+        net.eval()
+        with torch.no_grad():
+            val_images, val_labels = batch[0].to(device), batch[1].to(device)
+            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
+            return seg_probs, val_labels
+
+
+    evaluator = Engine(_sliding_window_processor)
+
+    # add evaluation metric to the evaluator engine
+    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')
+
+    # StatsHandler prints loss at every iteration and prints metrics at every epoch,
+    # we don't need to print loss for the evaluator, so it just prints metrics; users can also customize print functions
+    val_stats_handler = StatsHandler(
+        name='evaluator',
+        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
+    )
+    val_stats_handler.attach(evaluator)
+
+    # for the array data format, assume the 3rd item of batch data is the meta_data
+    file_saver = SegmentationSaver(
+        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
+        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
+    file_saver.attach(evaluator)
+
+    # the model was trained by the "unet_training_array" example
+    ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
+    ckpt_saver.attach(evaluator)
+
+    # sliding window inference for one image at every iteration
+    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
+    state = evaluator.run(loader)
+    shutil.rmtree(tempdir)
diff --git a/examples/segmentation_3d_ignite/unet_evaluation_dict.py b/examples/segmentation_3d_ignite/unet_evaluation_dict.py
index 2660342d0a..cb0b5af662 100644
--- a/examples/segmentation_3d_ignite/unet_evaluation_dict.py
+++ b/examples/segmentation_3d_ignite/unet_evaluation_dict.py
@@ -28,80 +28,81 @@ from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord
 from monai.handlers import SegmentationSaver, CheckpointLoader, StatsHandler, MeanDice
 
-monai.config.print_config()
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-tempdir = tempfile.mkdtemp()
-print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(5):
-    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
-
-    n = nib.Nifti1Image(im, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
-
-    n = nib.Nifti1Image(seg, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
-
-images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
-segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
-val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]
-
-# define transforms for image and segmentation
-val_transforms = Compose([
-    LoadNiftid(keys=['img', 'seg']),
-    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
-    ScaleIntensityd(keys=['img', 'seg']),
-    ToTensord(keys=['img', 'seg'])
-])
-val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-
-device = torch.device('cuda:0')
-net = UNet(
-    dimensions=3,
-    in_channels=1,
-    out_channels=1,
-    channels=(16, 32, 64, 128, 256),
-    strides=(2, 2, 2, 2),
-    num_res_units=2,
-)
-net.to(device)
-
-# define sliding window size and batch size for windows inference
-roi_size = (96, 96, 96)
-sw_batch_size = 4
-
-
-def _sliding_window_processor(engine, batch):
-    net.eval()
-    with torch.no_grad():
-        val_images, val_labels = batch['img'].to(device), batch['seg'].to(device)
-        seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
-        return seg_probs, val_labels
-
-
-evaluator = Engine(_sliding_window_processor)
-
-# add evaluation metric to the evaluator engine
-MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')
-
-# StatsHandler prints loss at every iteration and print metrics at every epoch,
-# we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
-val_stats_handler = StatsHandler(
-    name='evaluator',
-    output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
-)
-val_stats_handler.attach(evaluator)
-
-# convert the necessary metadata from batch data
-SegmentationSaver(output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
-                  batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'],
-                                                 'affine': batch['img.affine']},
-                  output_transform=lambda output: predict_segmentation(output[0])).attach(evaluator)
-# the model was trained by "unet_training_dict" example
-CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}).attach(evaluator)
-
-# sliding window inference for one image at every iteration
-val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
-                        pin_memory=torch.cuda.is_available())
-state = evaluator.run(val_loader)
-shutil.rmtree(tempdir)
+if __name__ == '__main__':
+    monai.config.print_config()
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+    tempdir = tempfile.mkdtemp()
+    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
+    for i in range(5):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+    val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]
+
+    # define transforms for image and segmentation
+    val_transforms = Compose([
+        LoadNiftid(keys=['img', 'seg']),
+        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
+        ScaleIntensityd(keys=['img', 'seg']),
+        ToTensord(keys=['img', 'seg'])
+    ])
+    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+
+    device = torch.device('cuda:0')
+    net = UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    )
+    net.to(device)
+
+    # define sliding window size and batch size for windows inference
+    roi_size = (96, 96, 96)
+    sw_batch_size = 4
+
+
+    def _sliding_window_processor(engine, batch):
+        net.eval()
+        with torch.no_grad():
+            val_images, val_labels = batch['img'].to(device), batch['seg'].to(device)
+            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
+            return seg_probs, val_labels
+
+
+    evaluator = Engine(_sliding_window_processor)
+
+    # add evaluation metric to the evaluator engine
+    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')
+
+    # StatsHandler prints loss at every iteration and prints metrics at every epoch,
+    # we don't need to print loss for the evaluator, so it just prints metrics; users can also customize print functions
+    val_stats_handler = StatsHandler(
+        name='evaluator',
+        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
+    )
+    val_stats_handler.attach(evaluator)
+
+    # convert the necessary metadata from batch data
+    SegmentationSaver(output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
+                      batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'],
+                                                     'affine': batch['img.affine']},
+                      output_transform=lambda output: predict_segmentation(output[0])).attach(evaluator)
+    # the model was trained by the "unet_training_dict" example
+    CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}).attach(evaluator)
+
+    # sliding window inference for one image at every iteration
+    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
+                            pin_memory=torch.cuda.is_available())
+    state = evaluator.run(val_loader)
+    shutil.rmtree(tempdir)
diff --git a/examples/segmentation_3d_ignite/unet_training_array.py b/examples/segmentation_3d_ignite/unet_training_array.py
index 63f223a6cb..d0b2ce5998 100644
--- a/examples/segmentation_3d_ignite/unet_training_array.py
+++ b/examples/segmentation_3d_ignite/unet_training_array.py
@@ -29,139 +29,140 @@ StatsHandler, TensorBoardStatsHandler, TensorBoardImageHandler, MeanDice, stopping_fn_from_metric
 from monai.networks.utils import predict_segmentation
 
-monai.config.print_config()
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-# create a temporary directory and 40 random image, mask paris
-tempdir = tempfile.mkdtemp()
-print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(40):
-    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
-
-    n = nib.Nifti1Image(im, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
-
-    n = nib.Nifti1Image(seg, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
-
-images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
-segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
-
-# define transforms for image and segmentation
-train_imtrans = Compose([
-    ScaleIntensity(),
-    AddChannel(),
-    RandSpatialCrop((96, 96, 96), random_size=False),
-    ToTensor()
-])
-train_segtrans = Compose([
-    AddChannel(),
-    RandSpatialCrop((96, 96, 96), random_size=False),
-    ToTensor()
-])
-val_imtrans = Compose([
-    ScaleIntensity(),
-    AddChannel(),
-    Resize((96, 96, 96)),
-    ToTensor()
-])
-val_segtrans = Compose([
-    AddChannel(),
-    Resize((96, 96, 96)),
-    ToTensor()
-])
-
-# define nifti dataset, data loader
-check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
-check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
-im, seg = monai.utils.misc.first(check_loader)
-print(im.shape, seg.shape)
-
-# create a training data loader
-train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
-train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
-# create a validation data loader
-val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
-val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
-
-# create UNet, DiceLoss and Adam optimizer 
-net = monai.networks.nets.UNet(
-    dimensions=3,
-    in_channels=1,
-    out_channels=1,
-    channels=(16, 32, 64, 128, 256),
-    strides=(2, 2, 2, 2),
-    num_res_units=2,
-)
-loss = monai.losses.DiceLoss(do_sigmoid=True)
-lr = 1e-3
-opt = torch.optim.Adam(net.parameters(), lr)
-device = torch.device('cuda:0')
-
-# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
-# user can add output_transform to return other values, like: y_pred, y, etc.
-trainer = create_supervised_trainer(net, opt, loss, device, False)
-
-# adding checkpoint handler to save models (network params and optimizer stats) during training
-checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
-trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
-                          handler=checkpoint_handler,
-                          to_save={'net': net, 'opt': opt})
-
-# StatsHandler prints loss at every iteration and print metrics at every epoch,
-# we don't set metrics for trainer here, so just print loss, user can also customize print functions
-# and can use output_transform to convert engine.state.output if it's not a loss value
-train_stats_handler = StatsHandler(name='trainer')
-train_stats_handler.attach(trainer)
-
-# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
-train_tensorboard_stats_handler = TensorBoardStatsHandler()
-train_tensorboard_stats_handler.attach(trainer)
-
-validation_every_n_epochs = 1
-# Set parameters for validation
-metric_name = 'Mean_Dice'
-# add evaluation metric to the evaluator engine
-val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
-
-# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
-# user can add output_transform to return other values
-evaluator = create_supervised_evaluator(net, val_metrics, device, True)
-
-
-@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
-def run_validation(engine):
-    evaluator.run(val_loader)
-
-
-# add early stopping handler to evaluator
-early_stopper = EarlyStopping(patience=4,
-                              score_function=stopping_fn_from_metric(metric_name),
-                              trainer=trainer)
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
-
-# add stats event handler to print validation stats via evaluator
-val_stats_handler = StatsHandler(
-    name='evaluator',
-    output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
-    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
-val_stats_handler.attach(evaluator)
-
-# add handler to record metrics to TensorBoard at every validation epoch
-val_tensorboard_stats_handler = TensorBoardStatsHandler(
-    output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
-    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
-val_tensorboard_stats_handler.attach(evaluator)
-
-# add handler to draw the first image and the corresponding label and model output in the last batch
-# here we draw the 3D output as GIF format along Depth axis, at every validation epoch
-val_tensorboard_image_handler = TensorBoardImageHandler(
-    batch_transform=lambda batch: (batch[0], batch[1]),
-    output_transform=lambda output: predict_segmentation(output[0]),
-    global_iter_transform=lambda x: trainer.state.epoch
-)
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)
-
-train_epochs = 30
-state = trainer.run(train_loader, train_epochs)
-shutil.rmtree(tempdir)
+if __name__ == '__main__':
+    monai.config.print_config()
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+    # create a temporary directory and 40 random image/mask pairs
+    tempdir = tempfile.mkdtemp()
+    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
+    for i in range(40):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
+
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+
+    # define transforms for image and segmentation
+    train_imtrans = Compose([
+        ScaleIntensity(),
+        AddChannel(),
+        RandSpatialCrop((96, 96, 96), random_size=False),
+        ToTensor()
+    ])
+    train_segtrans = Compose([
+        AddChannel(),
+        RandSpatialCrop((96, 96, 96), random_size=False),
+        ToTensor()
+    ])
+    val_imtrans = Compose([
+        ScaleIntensity(),
+        AddChannel(),
+        Resize((96, 96, 96)),
+        ToTensor()
+    ])
+    val_segtrans = Compose([
+        AddChannel(),
+        Resize((96, 96, 96)),
+        ToTensor()
+    ])
+
+    # define nifti dataset, data loader
+    check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
+    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
+    im, seg = monai.utils.misc.first(check_loader)
+    print(im.shape, seg.shape)
+
+    # create a training data loader
+    train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
+    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
+    # create a validation data loader
+    val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
+    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+
+    # create UNet, DiceLoss and Adam optimizer
+    net = monai.networks.nets.UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    )
+    loss = monai.losses.DiceLoss(do_sigmoid=True)
+    lr = 1e-3
+    opt = torch.optim.Adam(net.parameters(), lr)
+    device = torch.device('cuda:0')
+
+    # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
+    # user can add output_transform to return other values, like: y_pred, y, etc.
+    trainer = create_supervised_trainer(net, opt, loss, device, False)
+
+    # adding checkpoint handler to save models (network params and optimizer stats) during training
+    checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                              handler=checkpoint_handler,
+                              to_save={'net': net, 'opt': opt})
+
+    # StatsHandler prints loss at every iteration and prints metrics at every epoch;
+    # we don't set metrics for the trainer here, so just print loss; users can also customize the print functions
+    # and can use output_transform to convert engine.state.output if it's not a loss value
+    train_stats_handler = StatsHandler(name='trainer')
+    train_stats_handler.attach(trainer)
+
+    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
+    train_tensorboard_stats_handler = TensorBoardStatsHandler()
+    train_tensorboard_stats_handler.attach(trainer)
+
+    validation_every_n_epochs = 1
+    # Set parameters for validation
+    metric_name = 'Mean_Dice'
+    # add evaluation metric to the evaluator engine
+    val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
+
+    # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
+    # user can add output_transform to return other values
+    evaluator = create_supervised_evaluator(net, val_metrics, device, True)
+
+
+    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+    def run_validation(engine):
+        evaluator.run(val_loader)
+
+
+    # add early stopping handler to evaluator
+    early_stopper = EarlyStopping(patience=4,
+                                  score_function=stopping_fn_from_metric(metric_name),
+                                  trainer=trainer)
+    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+    # add stats event handler to print validation stats via evaluator
+    val_stats_handler = StatsHandler(
+        name='evaluator',
+        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    val_stats_handler.attach(evaluator)
+
+    # add handler to record metrics to TensorBoard at every validation epoch
+    val_tensorboard_stats_handler = TensorBoardStatsHandler(
+        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    val_tensorboard_stats_handler.attach(evaluator)
+
+    # add handler to draw the first image and the corresponding label and model output in the last batch
+    # here we draw the 3D output in GIF format along the depth axis, at every validation epoch
+    val_tensorboard_image_handler = TensorBoardImageHandler(
+        batch_transform=lambda batch: (batch[0], batch[1]),
+        output_transform=lambda output: predict_segmentation(output[0]),
+        global_iter_transform=lambda x: trainer.state.epoch
+    )
+    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)
+
+    train_epochs = 30
+    state = trainer.run(train_loader, train_epochs)
+    shutil.rmtree(tempdir)
diff --git a/examples/segmentation_3d_ignite/unet_training_dict.py b/examples/segmentation_3d_ignite/unet_training_dict.py
index 89e04dd22b..e43882fa6f 100644
--- a/examples/segmentation_3d_ignite/unet_training_dict.py
+++ b/examples/segmentation_3d_ignite/unet_training_dict.py
@@ -30,143 +30,144 @@
 from monai.data import create_test_image_3d, list_data_collate
 from monai.networks.utils import predict_segmentation
 
-monai.config.print_config()
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-# create a temporary directory and 40 random image, mask paris
-tempdir = tempfile.mkdtemp()
-print('generating synthetic data to {} (this may take a while)'.format(tempdir))
-for i in range(40):
-    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
-
-    n = nib.Nifti1Image(im, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))
-
-    n = nib.Nifti1Image(seg, np.eye(4))
-    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
-
-images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
-segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
-train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
-val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])]
-
-# define transforms for image and segmentation
-train_transforms = Compose([
-    LoadNiftid(keys=['img', 'seg']),
-    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
-    ScaleIntensityd(keys=['img', 'seg']),
-    RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4),
-    RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]),
-    ToTensord(keys=['img', 'seg'])
-])
-val_transforms = Compose([
-    LoadNiftid(keys=['img', 'seg']),
-    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
-    ScaleIntensityd(keys=['img', 'seg']),
-    ToTensord(keys=['img', 'seg'])
-])
-
-# define dataset, data loader
-check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
-check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
-                          pin_memory=torch.cuda.is_available())
-check_data = monai.utils.misc.first(check_loader)
-print(check_data['img'].shape, check_data['seg'].shape)
-
-# create a training data loader
-train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
-train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
-                          collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
-# create a validation data loader
-val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
-                        pin_memory=torch.cuda.is_available())
-
-# create UNet, DiceLoss and Adam optimizer
-net = monai.networks.nets.UNet(
-    dimensions=3,
-    in_channels=1,
-    out_channels=1,
-    channels=(16, 32, 64, 128, 256),
-    strides=(2, 2, 2, 2),
-    num_res_units=2,
-)
-loss = monai.losses.DiceLoss(do_sigmoid=True)
-lr = 1e-3
-opt = torch.optim.Adam(net.parameters(), lr)
-device = torch.device('cuda:0')
-
-# ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
-# user can add output_transform to return other values, like: y_pred, y, etc.
-def prepare_batch(batch, device=None, non_blocking=False):
-    return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)
-
-
-trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)
-
-# adding checkpoint handler to save models (network params and optimizer stats) during training
-checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
-trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
-                          handler=checkpoint_handler,
-                          to_save={'net': net, 'opt': opt})
-
-# StatsHandler prints loss at every iteration and print metrics at every epoch,
-# we don't set metrics for trainer here, so just print loss, user can also customize print functions
-# and can use output_transform to convert engine.state.output if it's not loss value
-train_stats_handler = StatsHandler(name='trainer')
-train_stats_handler.attach(trainer)
-
-# TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
-train_tensorboard_stats_handler = TensorBoardStatsHandler()
-train_tensorboard_stats_handler.attach(trainer)
-
-validation_every_n_iters = 5
-# set parameters for validation
-metric_name = 'Mean_Dice'
-# add evaluation metric to the evaluator engine
-val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
-
-# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
-# user can add output_transform to return other values
-evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
-
-
-@trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
-def run_validation(engine):
-    evaluator.run(val_loader)
-
-
-# add early stopping handler to evaluator
-early_stopper = EarlyStopping(patience=4,
-                              score_function=stopping_fn_from_metric(metric_name),
-                              trainer=trainer)
-evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
-
-# add stats event handler to print validation stats via evaluator
-val_stats_handler = StatsHandler(
-    name='evaluator',
-    output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
-    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
-val_stats_handler.attach(evaluator)
-
-# add handler to record metrics to TensorBoard at every validation epoch
-val_tensorboard_stats_handler = TensorBoardStatsHandler(
-    output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
-    global_epoch_transform=lambda x: trainer.state.iteration)  # fetch global iteration number from trainer
-val_tensorboard_stats_handler.attach(evaluator)
-
-# add handler to draw the first image and the corresponding label and model output in the last batch
-# here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
-val_tensorboard_image_handler = TensorBoardImageHandler(
-    batch_transform=lambda batch: (batch['img'], batch['seg']),
-    output_transform=lambda output: predict_segmentation(output[0]),
-    global_iter_transform=lambda x: trainer.state.epoch
-)
-evaluator.add_event_handler(
-    event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler)
-
-train_epochs = 5
-state = trainer.run(train_loader, train_epochs)
-shutil.rmtree(tempdir)
+if __name__ == '__main__':
+    monai.config.print_config()
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+    # create a temporary directory and 40 random image/mask pairs
+    tempdir = tempfile.mkdtemp()
+    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
+    for i in range(40):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))
+
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+    images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
+    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+    train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
+    val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])]
+
+    # define transforms for image and segmentation
+    train_transforms = Compose([
+        LoadNiftid(keys=['img', 'seg']),
+        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
+        ScaleIntensityd(keys=['img', 'seg']),
+        RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4),
+        RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]),
+        ToTensord(keys=['img', 'seg'])
+    ])
+    val_transforms = Compose([
+        LoadNiftid(keys=['img', 'seg']),
+        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
+        ScaleIntensityd(keys=['img', 'seg']),
+        ToTensord(keys=['img', 'seg'])
+    ])
+
+    # define dataset, data loader
+    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
+    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
+                              pin_memory=torch.cuda.is_available())
+    check_data = monai.utils.misc.first(check_loader)
+    print(check_data['img'].shape, check_data['seg'].shape)
+
+    # create a training data loader
+    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
+    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
+                              collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
+    # create a validation data loader
+    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
+                            pin_memory=torch.cuda.is_available())
+
+    # create UNet, DiceLoss and Adam optimizer
+    net = monai.networks.nets.UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    )
+    loss = monai.losses.DiceLoss(do_sigmoid=True)
+    lr = 1e-3
+    opt = torch.optim.Adam(net.parameters(), lr)
+    device = torch.device('cuda:0')
+
+    # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
+    # user can add output_transform to return other values, like: y_pred, y, etc.
+    def prepare_batch(batch, device=None, non_blocking=False):
+        return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)
+
+
+    trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)
+
+    # adding checkpoint handler to save models (network params and optimizer stats) during training
+    checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                              handler=checkpoint_handler,
+                              to_save={'net': net, 'opt': opt})
+
+    # StatsHandler prints loss at every iteration and prints metrics at every epoch;
+    # we don't set metrics for the trainer here, so just print loss; users can also customize the print functions
+    # and can use output_transform to convert engine.state.output if it's not a loss value
+    train_stats_handler = StatsHandler(name='trainer')
+    train_stats_handler.attach(trainer)
+
+    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
+    train_tensorboard_stats_handler = TensorBoardStatsHandler()
+    train_tensorboard_stats_handler.attach(trainer)
+
+    validation_every_n_iters = 5
+    # set parameters for validation
+    metric_name = 'Mean_Dice'
+    # add evaluation metric to the evaluator engine
+    val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
+
+    # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
+    # user can add output_transform to return other values
+    evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+
+    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
+    def run_validation(engine):
+        evaluator.run(val_loader)
+
+
+    # add early stopping handler to evaluator
+    early_stopper = EarlyStopping(patience=4,
+                                  score_function=stopping_fn_from_metric(metric_name),
+                                  trainer=trainer)
+    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+    # add stats event handler to print validation stats via evaluator
+    val_stats_handler = StatsHandler(
+        name='evaluator',
+        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
+    val_stats_handler.attach(evaluator)
+
+    # add handler to record metrics to TensorBoard at every validation epoch
+    val_tensorboard_stats_handler = TensorBoardStatsHandler(
+        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
+        global_epoch_transform=lambda x: trainer.state.iteration)  # fetch global iteration number from trainer
+    val_tensorboard_stats_handler.attach(evaluator)
+
+    # add handler to draw the first image and the corresponding label and model output in the last batch
+    # here we draw the 3D output in GIF format along the depth axis, every 2 validation iterations.
+    val_tensorboard_image_handler = TensorBoardImageHandler(
+        batch_transform=lambda batch: (batch['img'], batch['seg']),
+        output_transform=lambda output: predict_segmentation(output[0]),
+        global_iter_transform=lambda x: trainer.state.epoch
+    )
+    evaluator.add_event_handler(
+        event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler)
+
+    train_epochs = 5
+    state = trainer.run(train_loader, train_epochs)
+    shutil.rmtree(tempdir)

From d183167baa5f4328aa806ab79b929d309fd021ce Mon Sep 17 00:00:00 2001
From: Guy Leroy
Date: Sun, 26 Apr 2020 10:21:30 +0100
Subject: [PATCH 2/3] Fixed flake8 E301

---
 examples/classification_3d_ignite/densenet_training_dict.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/classification_3d_ignite/densenet_training_dict.py b/examples/classification_3d_ignite/densenet_training_dict.py
index 6493cd250a..8323e44ce8 100644
--- a/examples/classification_3d_ignite/densenet_training_dict.py
+++ b/examples/classification_3d_ignite/densenet_training_dict.py
@@ -94,6 +94,7 @@
     # ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
     # user can add output_transform to return other values, like: y_pred, y, etc.
     def prepare_batch(batch, device=None, non_blocking=False):
+
         return _prepare_batch((batch['img'], batch['label']), device, non_blocking)

From eb6d48fb1f050cb8bf99329cf0abc5be1b31d938 Mon Sep 17 00:00:00 2001
From: Guy Leroy
Date: Sun, 26 Apr 2020 17:00:04 +0100
Subject: [PATCH 3/3] Refactored main def in examples

---
 examples/classification_3d/densenet_evaluation_array.py      | 5 ++++-
 examples/classification_3d/densenet_evaluation_dict.py       | 5 ++++-
 examples/classification_3d/densenet_training_array.py        | 5 ++++-
 examples/classification_3d/densenet_training_dict.py         | 5 ++++-
 .../classification_3d_ignite/densenet_evaluation_array.py    | 5 ++++-
 .../classification_3d_ignite/densenet_evaluation_dict.py     | 5 ++++-
 examples/classification_3d_ignite/densenet_training_array.py | 5 ++++-
 examples/classification_3d_ignite/densenet_training_dict.py  | 5 ++++-
 examples/segmentation_3d/unet_evaluation_array.py            | 5 ++++-
 examples/segmentation_3d/unet_evaluation_dict.py             | 5 ++++-
 examples/segmentation_3d/unet_training_array.py              | 5 ++++-
 examples/segmentation_3d/unet_training_dict.py               | 5 ++++-
 examples/segmentation_3d_ignite/unet_evaluation_array.py     | 5 ++++-
 examples/segmentation_3d_ignite/unet_evaluation_dict.py      | 5 ++++-
 examples/segmentation_3d_ignite/unet_training_array.py       | 5 ++++-
 examples/segmentation_3d_ignite/unet_training_dict.py        | 5 ++++-
 16 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/examples/classification_3d/densenet_evaluation_array.py b/examples/classification_3d/densenet_evaluation_array.py
index af6547a785..1716133a09 100644
--- a/examples/classification_3d/densenet_evaluation_array.py
+++ b/examples/classification_3d/densenet_evaluation_array.py
@@ -19,7 +19,7 @@
 from monai.data import NiftiDataset, CSVSaver
 from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -78,3 +78,6 @@
     metric = num_correct / metric_count
     print('evaluation metric:', metric)
     saver.finalize()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d/densenet_evaluation_dict.py b/examples/classification_3d/densenet_evaluation_dict.py
index ba6f4d9b76..45bbd46b06 100644
--- a/examples/classification_3d/densenet_evaluation_dict.py
+++ b/examples/classification_3d/densenet_evaluation_dict.py
@@ -19,7 +19,7 @@
 from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, ToTensord
 from monai.data import CSVSaver
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -79,3 +79,6 @@
     metric = num_correct / metric_count
     print('evaluation metric:', metric)
     saver.finalize()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d/densenet_training_array.py b/examples/classification_3d/densenet_training_array.py
index e3c42a4288..0fbbda6caa 100644
--- a/examples/classification_3d/densenet_training_array.py
+++ b/examples/classification_3d/densenet_training_array.py
@@ -20,7 +20,7 @@
 from monai.data import NiftiDataset
 from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, RandRotate90, ToTensor
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -143,3 +143,6 @@
     writer.add_scalar('val_accuracy', metric, epoch + 1)
     print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
     writer.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d/densenet_training_dict.py b/examples/classification_3d/densenet_training_dict.py
index c69d8da23e..ee8c944adb 100644
--- a/examples/classification_3d/densenet_training_dict.py
+++ b/examples/classification_3d/densenet_training_dict.py
@@ -20,7 +20,7 @@
 from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, RandRotate90d, ToTensord
 from monai.metrics import compute_roc_auc
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -144,3 +144,6 @@
     writer.add_scalar('val_accuracy', acc_metric, epoch + 1)
     print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
     writer.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d_ignite/densenet_evaluation_array.py b/examples/classification_3d_ignite/densenet_evaluation_array.py
index f0c3bcbcd0..37b6425aa4 100644
--- a/examples/classification_3d_ignite/densenet_evaluation_array.py
+++ b/examples/classification_3d_ignite/densenet_evaluation_array.py
@@ -22,7 +22,7 @@
 from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor
 from monai.handlers import StatsHandler, ClassificationSaver, CheckpointLoader
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -93,3 +93,6 @@ def prepare_batch(batch, device=None, non_blocking=False):
 
     val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
     state = evaluator.run(val_loader)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d_ignite/densenet_evaluation_dict.py b/examples/classification_3d_ignite/densenet_evaluation_dict.py
index 7c8a0880cf..f308efd94b 100644
--- a/examples/classification_3d_ignite/densenet_evaluation_dict.py
+++ b/examples/classification_3d_ignite/densenet_evaluation_dict.py
@@ -21,7 +21,7 @@
 from monai.handlers import StatsHandler, CheckpointLoader, ClassificationSaver
 from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, ToTensord
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -94,3 +94,6 @@ def prepare_batch(batch, device=None, non_blocking=False):
 
     val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
     state = evaluator.run(val_loader)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d_ignite/densenet_training_array.py b/examples/classification_3d_ignite/densenet_training_array.py
index 8abb3ab219..41e1c008aa 100644
--- a/examples/classification_3d_ignite/densenet_training_array.py
+++ b/examples/classification_3d_ignite/densenet_training_array.py
@@ -23,7 +23,7 @@
 from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, RandRotate90, ToTensor
 from monai.handlers import StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -152,3 +152,6 @@ def run_validation(engine):
 
     train_epochs = 30
     state = trainer.run(train_loader, train_epochs)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/classification_3d_ignite/densenet_training_dict.py b/examples/classification_3d_ignite/densenet_training_dict.py
index 8323e44ce8..cd7f8b854e 100644
--- a/examples/classification_3d_ignite/densenet_training_dict.py
+++ b/examples/classification_3d_ignite/densenet_training_dict.py
@@ -22,7 +22,7 @@
 from monai.transforms import Compose, LoadNiftid, AddChanneld, ScaleIntensityd, Resized, RandRotate90d, ToTensord
 from monai.handlers import StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric, ROCAUC
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -161,3 +161,6 @@ def run_validation(engine):
 
     train_epochs = 30
     state = trainer.run(train_loader, train_epochs)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py
index 3d0dd937eb..6e974d4c8f 100644
--- a/examples/segmentation_3d/unet_evaluation_array.py
+++ b/examples/segmentation_3d/unet_evaluation_array.py
@@ -26,7 +26,7 @@
 from monai.data import create_test_image_3d, sliding_window_inference, NiftiSaver, NiftiDataset
 from monai.metrics import compute_meandice
 
-if __name__ == '__main__':
+def main():
     config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -82,3 +82,6 @@
     metric = metric_sum / metric_count
     print('evaluation metric:', metric)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py
index c756acbc2f..573504b594 100644
--- a/examples/segmentation_3d/unet_evaluation_dict.py
+++ b/examples/segmentation_3d/unet_evaluation_dict.py
@@ -26,7 +26,7 @@
 from monai.networks.nets import UNet
 from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -89,3 +89,6 @@
     metric = metric_sum / metric_count
     print('evaluation metric:', metric)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py
index a69dcc7901..41bb719cdc 100644
--- a/examples/segmentation_3d/unet_training_array.py
+++ b/examples/segmentation_3d/unet_training_array.py
@@ -27,7 +27,7 @@
 from monai.metrics import compute_meandice
 from monai.visualize.img2tensorboard import plot_2d_or_3d_image
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -159,3 +159,6 @@
     shutil.rmtree(tempdir)
     print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
     writer.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py
index b3f34f747f..9e6958afda 100644
--- a/examples/segmentation_3d/unet_training_dict.py
+++ b/examples/segmentation_3d/unet_training_dict.py
@@ -28,7 +28,7 @@
 from monai.metrics import compute_meandice
 from monai.visualize import plot_2d_or_3d_image
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -159,3 +159,6 @@
     shutil.rmtree(tempdir)
     print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
     writer.close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d_ignite/unet_evaluation_array.py b/examples/segmentation_3d_ignite/unet_evaluation_array.py
index df03a4e724..26be6fa9e6 100644
--- a/examples/segmentation_3d_ignite/unet_evaluation_array.py
+++ b/examples/segmentation_3d_ignite/unet_evaluation_array.py
@@ -28,7 +28,7 @@
 from monai.networks.nets import UNet
 from monai.networks.utils import predict_segmentation
 
-if __name__ == '__main__':
+def main():
     config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -102,3 +102,6 @@ def _sliding_window_processor(engine, batch):
     loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
     state = evaluator.run(loader)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d_ignite/unet_evaluation_dict.py b/examples/segmentation_3d_ignite/unet_evaluation_dict.py
index cb0b5af662..5d8a35a170 100644
--- a/examples/segmentation_3d_ignite/unet_evaluation_dict.py
+++ b/examples/segmentation_3d_ignite/unet_evaluation_dict.py
@@ -28,7 +28,7 @@
 from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord
 from monai.handlers import SegmentationSaver, CheckpointLoader, StatsHandler, MeanDice
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -106,3 +106,6 @@ def _sliding_window_processor(engine, batch):
                             pin_memory=torch.cuda.is_available())
     state = evaluator.run(val_loader)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d_ignite/unet_training_array.py b/examples/segmentation_3d_ignite/unet_training_array.py
index d0b2ce5998..9bafd34b75 100644
--- a/examples/segmentation_3d_ignite/unet_training_array.py
+++ b/examples/segmentation_3d_ignite/unet_training_array.py
@@ -29,7 +29,7 @@
     StatsHandler, TensorBoardStatsHandler, TensorBoardImageHandler, MeanDice, stopping_fn_from_metric
 from monai.networks.utils import predict_segmentation
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -166,3 +166,6 @@ def run_validation(engine):
     train_epochs = 30
     state = trainer.run(train_loader, train_epochs)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/segmentation_3d_ignite/unet_training_dict.py b/examples/segmentation_3d_ignite/unet_training_dict.py
index e43882fa6f..94d59c0d2d 100644
--- a/examples/segmentation_3d_ignite/unet_training_dict.py
+++ b/examples/segmentation_3d_ignite/unet_training_dict.py
@@ -30,7 +30,7 @@
 from monai.data import create_test_image_3d, list_data_collate
 from monai.networks.utils import predict_segmentation
 
-if __name__ == '__main__':
+def main():
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
@@ -171,3 +171,6 @@ def run_validation(engine):
     train_epochs = 5
     state = trainer.run(train_loader, train_epochs)
     shutil.rmtree(tempdir)
+
+
+if __name__ == '__main__':
+    main()
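
A note on why the pattern applied across this series matters: with the multiprocessing `spawn` start method (the default on Windows), every `DataLoader` worker process re-imports the launching script, so any top-level training or evaluation code would execute again in each worker and crash. Moving the work into `main()` and calling it behind an `if __name__ == '__main__':` guard keeps the module side-effect free on import. Below is a minimal, self-contained sketch of the idiom, with a toy tensor dataset standing in for the examples' NIfTI pipelines (illustrative only, not code from this repository):

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # all script-level work lives in main(), so importing this module
    # (as every spawned DataLoader worker does) has no side effects
    ds = TensorDataset(torch.randn(8, 1, 4, 4), torch.randint(0, 2, (8,)))
    loader = DataLoader(ds, batch_size=2, num_workers=2)  # workers re-import this module
    for img, label in loader:
        pass  # training/evaluation logic would go here


if __name__ == '__main__':
    # executed only when the script is run directly, never on import
    main()

This is exactly the shape PATCH 3 converges on: module-level code moves into `main()`, and the guard shrinks to a two-line trailer at the bottom of each example.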