diff --git a/CHANGELOG.md b/CHANGELOG.md
index e67da5ff..188b419a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.
 
 ## [unreleased]
 
+🚨 **examples**: compare training with PyTorch ([#94](https://github.com/owkin/GrAIdient/pull/94))\
 🪜 **layer_2d:** ColorJitterHSV, Image & ImageTests ([#93](https://github.com/owkin/GrAIdient/pull/93))\
 🪜 **layer_2d:** Flip2D & config_kernels ([#92](https://github.com/owkin/GrAIdient/pull/92))\
 🔨 **layer_2d:** remove computeVQ ([#91](https://github.com/owkin/GrAIdient/pull/91))\
diff --git a/Docs/Examples/AutoEncoder.md b/Docs/Examples/AutoEncoder.md
index 298ce671..eb9b1451 100644
--- a/Docs/Examples/AutoEncoder.md
+++ b/Docs/Examples/AutoEncoder.md
@@ -65,5 +65,6 @@ conda env remove --name graiexamples
 ## Steps
 
 1. Dump the training dataset.
-1. Train a simple UNet like auto encoder model on the training dataset.
-1. Train a simple StyleGAN like auto encoder model on the training dataset.
+1. Train a simple auto encoder model.
+1. Train a UNet like auto encoder model.
+1. Train a StyleGAN like auto encoder model.
diff --git a/Docs/Examples/VisionTransformer.md b/Docs/Examples/VisionTransformer.md
index 66056c2a..6dfdf405 100644
--- a/Docs/Examples/VisionTransformer.md
+++ b/Docs/Examples/VisionTransformer.md
@@ -85,4 +85,4 @@ conda env remove --name graiexamples
 ## Steps
 
 1. Dump the training dataset.
-1. Train a simple Vision Transformer model on the training dataset.
+1. Train a simple Vision Transformer model.
diff --git a/Tests/GrAIExamples/AutoEncoderExample.swift b/Tests/GrAIExamples/AutoEncoderExample.swift
index 742a1eaf..e8654bdd 100644
--- a/Tests/GrAIExamples/AutoEncoderExample.swift
+++ b/Tests/GrAIExamples/AutoEncoderExample.swift
@@ -3,12 +3,13 @@
 // GrAIExamples
 //
 // Created by Aurélien PEDEN on 23/03/2023.
+// Modified by Jean-François Reboud on 21/05/2023.
 //
 
 import XCTest
 import GrAIdient
 
-/// Test that we can train a simple Auto Encoder model on the CIFAR dataset.
+/// Train a simple Auto Encoder model on the CIFAR dataset.
 final class AutoEncoderExample: XCTestCase
 {
     /// Directory to dump outputs from the tests.
@@ -17,14 +18,6 @@ final class AutoEncoderExample: XCTestCase
     /// Batch size of data.
     let _batchSize = 16
 
-    /// Size of one image (height and width are the same).
-    let _size = 32
-
-    /// Mean of the preprocessing to apply to data.
-    let _mean: (Float, Float, Float) = (123.675, 116.28, 103.53)
-    /// Deviation of the preprocessing to apply to data.
-    let _std: (Float, Float, Float) = (58.395, 57.12, 57.375)
-
     /// Initialize test.
     override func setUp()
     {
@@ -64,146 +57,87 @@ final class AutoEncoderExample
     }
 
     ///
-    /// Build an encoder branch.
+    /// Build an encoder branch with `nbBlocks` blocks of dimension reduction (factor of 2).
     ///
-    /// - Parameter params: Contextual parameters linking to the model.
-    /// - Returns: A tuple of layers at different image resolutions.
+    /// - Parameters:
+    ///     - size: Size of one image (height and width are the same) after resize.
+    ///     - nbBlocks: Number of reduction blocks.
+    ///     - params: Contextual parameters linking to the model.
+    /// - Returns: A list of layers at different image resolutions.
    ///
-    func _buildEncoder(params: GrAI.Model.Params)
-        -> (Layer2D, Layer2D, Layer2D, Layer2D, Layer2D)
+    func _buildEncoder(
+        size: Int,
+        nbBlocks: Int,
+        params: GrAI.Model.Params) -> [Layer2D]
     {
-        var layer, layer1, layer2, layer3, layer4: Layer2D
+        var layer: Layer2D
+        var layers = [Layer2D]()
+
         layer = Input2D(
             nbChannels: 3,
-            width: _size, height: _size,
-            params: params
-        )
-
-        layer = Convolution2D(
-            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
-            activation: ReLU.str, biases: true, bn: false,
+            width: size, height: size,
             params: params
         )
-        layer4 = layer
 
-        layer = Convolution2D(
-            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-        layer3 = layer
-
-        layer = Convolution2D(
-            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-        layer2 = layer
-
-        layer = Convolution2D(
-            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-        layer1 = layer
-
-        layer = Convolution2D(
-            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-        return (layer, layer1, layer2, layer3, layer4)
+        for _ in 0..<nbBlocks
+        {
+            layer = Convolution2D(
+                layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
+                activation: ReLU.str, biases: true, bn: false,
+                params: params
+            )
+            layers.insert(layer, at: 0)
+        }
+        return layers
     }
@@ ... @@
     func _buildUNetDecoder(
-        layersPrev: (Layer2D, Layer2D, Layer2D, Layer2D, Layer2D),
+        layersPrev: [Layer2D],
         params: GrAI.Model.Params) -> Layer2D
     {
-        var (layer, layer1, layer2, layer3, layer4) = layersPrev
+        var layer: Layer2D = layersPrev.first!
+        var numLayer = 0
 
-        layer = Deconvolution2D(
-            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
-            activation: nil, biases: true, bn: false,
-            params: params
-        )
-        layer = Concat2D(
-            layersPrev: [layer1, layer],
-            params: params
-        )
-        layer = Convolution2D(
-            layerPrev: layer,
-            size: 3, nbChannels: 8, stride: 1,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-
-        layer = Deconvolution2D(
-            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
-            activation: nil, biases: true, bn: false,
-            params: params
-        )
-        layer = Concat2D(
-            layersPrev: [layer2, layer],
-            params: params
-        )
-        layer = Convolution2D(
-            layerPrev: layer,
-            size: 3, nbChannels: 8, stride: 1,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-
-        layer = Deconvolution2D(
-            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
-            activation: nil, biases: true, bn: false,
-            params: params
-        )
-        layer = Concat2D(
-            layersPrev: [layer3, layer],
-            params: params
-        )
-        layer = Convolution2D(
-            layerPrev: layer,
-            size: 3, nbChannels: 8, stride: 1,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-
-        layer = Deconvolution2D(
-            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
-            activation: nil, biases: true, bn: false,
-            params: params
-        )
-        layer = Concat2D(
-            layersPrev: [layer4, layer],
-            params: params
-        )
-        layer = Convolution2D(
-            layerPrev: layer,
-            size: 3, nbChannels: 8, stride: 1,
-            activation: ReLU.str, biases: true, bn: false,
-            params: params
-        )
-
-        layer = Deconvolution2D(
-            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
-            activation: nil, biases: true, bn: false,
-            params: params
-        )
-        layer = Convolution2D(
-            layerPrev: layer,
-            size: 3, nbChannels: 3, stride: 1,
-            activation: Sigmoid.str, biases: true, bn: false,
-            params: params
-        )
+        while numLayer < layersPrev.count
+        {
+            layer = Deconvolution2D(
+                layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
+                activation: nil, biases: true, bn: false,
+                params: params
+            )
+
+            if numLayer + 1 < layersPrev.count
+            {
+                layer = Concat2D(
+                    layersPrev: [layersPrev[numLayer + 1], layer],
+                    params: params
+                )
+                layer = Convolution2D(
+                    layerPrev: layer,
+                    size: 3, nbChannels: 8, stride: 1,
+                    activation: ReLU.str, biases: true, bn: false,
+                    params: params
+                )
+            }
+            else
+            {
+                layer = Convolution2D(
+                    layerPrev: layer,
+                    size: 3, nbChannels: 3, stride: 1,
+                    activation: Sigmoid.str, biases: true, bn: false,
+                    params: params
+                )
+            }
+            numLayer += 1
+        }
         return layer
     }
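The refactor above replaces five hand-unrolled blocks with a loop: the encoder returns its intermediate resolutions as one array, and the UNet decoder walks that array to wire one skip connection per block. The sketch below (not part of the PR; `buildTinyUNet` is a hypothetical name) spells the pairing out for two blocks, assuming the encoder collects its outputs deepest-first, which is what the decoder's `layersPrev[numLayer + 1]` indexing implies.

```swift
import GrAIdient

// Minimal two-block UNet-like auto encoder, using only calls that
// appear in the diff above. Illustrative only.
func buildTinyUNet(size: Int) -> Model
{
    let context = ModelContext(name: "TinyUNet", models: [])
    let params = GrAI.Model.Params(context: context)

    var layer: Layer2D = Input2D(
        nbChannels: 3, width: size, height: size, params: params
    )

    // Encoder: two stride-2 reductions, collected deepest-first.
    var layers = [Layer2D]()
    for _ in 0..<2
    {
        layer = Convolution2D(
            layerPrev: layer, size: 3, nbChannels: 8, stride: 2,
            activation: ReLU.str, biases: true, bn: false, params: params
        )
        layers.insert(layer, at: 0)
    }

    // Decoder: mirror the encoder, one skip connection per block,
    // except for the last block which maps back to 3 channels.
    var numLayer = 0
    while numLayer < layers.count
    {
        layer = Deconvolution2D(
            layerPrev: layer, size: 2, nbChannels: 8, stride: 2,
            activation: nil, biases: true, bn: false, params: params
        )
        if numLayer + 1 < layers.count
        {
            layer = Concat2D(
                layersPrev: [layers[numLayer + 1], layer], params: params
            )
            layer = Convolution2D(
                layerPrev: layer, size: 3, nbChannels: 8, stride: 1,
                activation: ReLU.str, biases: true, bn: false, params: params
            )
        }
        else
        {
            layer = Convolution2D(
                layerPrev: layer, size: 3, nbChannels: 3, stride: 1,
                activation: Sigmoid.str, biases: true, bn: false, params: params
            )
        }
        numLayer += 1
    }
    return Model(model: context.model, modelsPrev: [])
}
```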
@@ -216,22 +150,19 @@ final class AutoEncoderExample
     /// - Returns: The last layer of the style branch.
     ///
     func _buildStyleMapping(
-        layersPrev: (Layer2D, Layer2D, Layer2D, Layer2D, Layer2D),
+        layersPrev: [Layer2D],
         params: GrAI.Model.Params) -> Layer1D
     {
-        let (layer1, layer2, layer3, layer4, layer5) = layersPrev
-
+        var layers = [Layer1D]()
+        for layerPrev in layersPrev
+        {
+            layers.append(
+                AvgPool2D(layerPrev: layerPrev, params: params)
+            )
+        }
         var layer: Layer1D = Concat1D(
-            layersPrev: [
-                AvgPool2D(layerPrev: layer1, params: params),
-                AvgPool2D(layerPrev: layer2, params: params),
-                AvgPool2D(layerPrev: layer3, params: params),
-                AvgPool2D(layerPrev: layer4, params: params),
-                AvgPool2D(layerPrev: layer5, params: params)
-            ],
-            params: params
+            layersPrev: layers, params: params
         )
 
         for _ in 0..<8
         {
             layer = FullyConnected(
@@ -244,22 +175,25 @@ final class AutoEncoderExample
     }
 
     ///
-    /// Build a StyleGAN like decoder branch.
+    /// Build a StyleGAN like decoder branch with `nbBlocks` blocks
+    /// of dimension augmentation (factor of 2).
     ///
     /// - Parameters:
+    ///     - nbBlocks: Number of augmentation blocks.
     ///     - style: The last layer of the style branch.
     ///     - params: Contextual parameters linking to the model.
     /// - Returns: The last layer of the decoder branch.
     ///
-    func _buildStyleDecoder(style: Layer1D, params: GrAI.Model.Params)
-        -> Layer2D
+    func _buildStyleDecoder(
+        nbBlocks: Int,
+        style: Layer1D,
+        params: GrAI.Model.Params) -> Layer2D
     {
         var layer: Layer2D
         layer = Constant2D(
-            nbChannels: 8, height: 4, width: 4,
+            nbChannels: 8, height: 2, width: 2,
             params: params
         )
-
         layer = AdaIN(
             layersPrev: [
                 layer,
@@ -289,10 +223,10 @@ final class AutoEncoderExample
             params: params
         )
 
-        for _ in 0..<5
+        for _ in 0..<nbBlocks
         {
@@ ... @@
-    func _buildModel(modelType: ModelClass) -> Model
+    func _buildModel(
+        modelType: ModelClass,
+        size: Int,
+        nbBlocks: Int) -> Model
     {
         // Create the context to build a graph of layers where
         // there is no previous model dependency: layer id starts at 0.
         let context = ModelContext(name: "AutoEncoder", models: [])
         let params = GrAI.Model.Params(context: context)
 
-        let layersPrev = _buildEncoder(params: params)
+        let layersPrev = _buildEncoder(
+            size: size,
+            nbBlocks: nbBlocks,
+            params: params
+        )
 
-        var layer: Layer2D
         switch modelType
         {
         case .Style:
-            layer = _buildStyleDecoder(
+            _ = _buildStyleDecoder(
+                nbBlocks: nbBlocks,
                 style: _buildStyleMapping(
                     layersPrev: layersPrev,
                     params: params
@@ -365,145 +300,86 @@ final class AutoEncoderExample
                 params: params
             )
         case .UNet:
-            layer = _buildUNetDecoder(
+            _ = _buildUNetDecoder(
                 layersPrev: layersPrev,
                 params: params
             )
         }
-
-        _ = MSE2D(layerPrev: layer, params: params)
 
         return Model(model: context.model, modelsPrev: [])
     }
 
     ///
     /// Train the model.
     ///
-    /// - Parameter model: The model to train.
+    /// - Parameters:
+    ///     - model: The model to train.
+    ///     - size: Size of one image (height and width are the same) after resize.
     ///
-    func _trainModel(model: Model)
+    func _trainModel(model: Model, size: Int)
    {
-        let cifar8 = CIFAR.loadDataset(
-            datasetPath: _outputDir + "/datasetTrain8",
-            size: _size
+        let trainer = try! CIFARAutoEncoderTrainer(
+            model: model, size: size
+        )
+        trainer.run(
+            batchSize: _batchSize,
+            label: 8,
+            nbEpochs: 5,
+            keep: 1000
        )
-
-        // Get optimizer parameters for iterating over batch size elements.
-        let params = _getOptimizerParams(nbLoops: _batchSize)
-
-        cifar8.initSamples(batchSize: _batchSize)
-
-        // Keep a subset of the dataset to have a quicker training.
-        cifar8.keep(1000)
-
-        // Small trick to force full batches throughout the training:
-        // this enables us to set the ground truth once and for all.
-        let nbWholeBatches =
-            cifar8.nbSamples / cifar8.batchSize * cifar8.batchSize
-        cifar8.keep(nbWholeBatches)
-
-        // Initialize for training.
-        model.initialize(params: params, phase: .Training)
-
-        let firstLayer: Input2D = model.layers.first as! Input2D
-        let lastLayer: MSE2D = model.layers.last as! MSE2D
-
-        let nbEpochs = 5
-        for epoch in 0..<nbEpochs
@@ ... @@
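Note that `_buildModel` no longer appends the `MSE2D` loss head: that now lives in the trainer's own `_buildModel` (shown further below), so the example only assembles the auto encoder graph. A hypothetical test body, mirroring how `_trainModel` is wired above (the test name and the argument values are illustrative, not taken from the PR):

```swift
// With size = 32 and nbBlocks = 4, the encoder reduces
// 32 -> 16 -> 8 -> 4 -> 2 and the decoder mirrors it back to 32.
func test_TrainUNetModel()
{
    let size = 32
    let model = _buildModel(
        modelType: .UNet,
        size: size,
        nbBlocks: 4
    )
    _trainModel(model: model, size: size)
}
```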
diff --git a/Tests/GrAIExamples/Base/CIFAR.swift b/Tests/GrAIExamples/Base/CIFAR.swift
--- a/Tests/GrAIExamples/Base/CIFAR.swift
+++ b/Tests/GrAIExamples/Base/CIFAR.swift
@@ ... @@
         var dataset = [UInt8]()
         for dataFile in 1...5
         {
-            let data = pythonLib.load_CIFAR_data(dataFile, label, size)
+            let data = pythonLib.load_CIFAR_train(dataFile, label, size)
             dataset += Array(data)!
         }
@@ -108,4 +108,41 @@ class CIFAR: DataSamplerImpl
         }
         return CIFAR(data: dataset, size: size)
     }
+
+    ///
+    /// Build an iterator on CIFAR dataset.
+    ///
+    /// - Parameters:
+    ///     - train: Train or test dataset.
+    ///     - batchSize: The batch size.
+    ///     - label: The label we want the data associated to.
+    ///     - shuffle: Whether to shuffle indices of data.
+    ///
+    /// - Returns: A Python iterator.
+    ///
+    static func buildIterator(
+        train: Bool,
+        batchSize: Int,
+        label: Int,
+        shuffle: Bool) -> PythonObject
+    {
+        let pythonLib = Python.import("python_lib")
+        return pythonLib.iter_CIFAR(train, batchSize, label, shuffle)
+    }
+
+    ///
+    /// Load next data from a Python iterator.
+    ///
+    /// - Parameter iterator: The Python iterator.
+    /// - Returns: The flattened samples and the batch size.
+    ///
+    static func getSamples(_ iterator: PythonObject) -> ([Float], Int)
+    {
+        let pythonLib = Python.import("python_lib")
+        let data = pythonLib.next_data_CIFAR(iterator)
+
+        let samples = [Float](data.tuple2.0)!
+        let batchSize = Int(data.tuple2.1)!
+
+        return (samples, batchSize)
+    }
 }
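The two new wrappers bridge the PyTorch-backed iterator into Swift. A short usage sketch (assumed to run inside the GrAIExamples test target, where the Python bridge is already set up; the parameter values are illustrative):

```swift
// Pull one batch from the PyTorch-backed CIFAR iterator.
let iterator = CIFAR.buildIterator(
    train: true,
    batchSize: 16,
    label: 0,
    shuffle: false
)
let (samples, batchSize) = CIFAR.getSamples(iterator)
if batchSize == 0
{
    // `next_data_CIFAR` signals exhaustion with an empty batch.
    print("Iterator is exhausted.")
}
else
{
    // `samples` is a flattened (batch, channel, height, width) buffer.
    print("Got \(batchSize) samples of \(samples.count / batchSize) floats each.")
}
```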
diff --git a/Tests/GrAIExamples/Base/CIFARAutoEncoderTrainer.swift b/Tests/GrAIExamples/Base/CIFARAutoEncoderTrainer.swift
new file mode 100644
index 00000000..f8a8126a
--- /dev/null
+++ b/Tests/GrAIExamples/Base/CIFARAutoEncoderTrainer.swift
@@ -0,0 +1,358 @@
+//
+// CIFARAutoEncoderTrainer.swift
+// GrAIExamples
+//
+// Created by Jean-François Reboud on 21/05/2023.
+//
+
+import Foundation
+import GrAIdient
+
+/// Error occurring when the trainer cannot be built.
+public enum TrainerError: Error
+{
+    /// Model size is not coherent.
+    case Size
+    /// Model structure is not expected.
+    case Structural
+}
+
+extension TrainerError: CustomStringConvertible
+{
+    public var description: String
+    {
+        switch self
+        {
+        case .Size:
+            return "Model size is not coherent."
+        case .Structural:
+            return "Model first layer should be an Input2D."
+        }
+    }
+}
+
+/// Train an auto encoder model on the CIFAR dataset.
+class CIFARAutoEncoderTrainer
+{
+    /// Directory to dump outputs from the tests.
+    let _outputDir = NSTemporaryDirectory()
+
+    /// Size of one image (height and width are the same) in the CIFAR dataset.
+    let _originalSize = 32
+    /// Size of one image (height and width are the same) after resize.
+    let _size: Int
+
+    /// Mean of the preprocessing to apply to data.
+    let _mean: (Float, Float, Float) = (123.675, 116.28, 103.53)
+    /// Deviation of the preprocessing to apply to data.
+    let _std: (Float, Float, Float) = (58.395, 57.12, 57.375)
+
+    /// Dataset to get the data from.
+    var _dataset: CIFAR! = nil
+    /// Final model that is being trained.
+    var _model: Model! = nil
+    /// Resizer model.
+    var _resizer: Model? = nil
+    /// Base model to train.
+    let _baseModel: Model
+
+    ///
+    /// Create the trainer.
+    ///
+    /// `size` makes it possible to simulate the fact that the model analyzes a coarse image:
+    /// the inputs and ground truths are resized to `size` in order to do so.
+    ///
+    /// Throws an error if the original model's first layer is not an `Input2D`, or if the size
+    /// of the latter is not the size expected by the trainer.
+    ///
+    /// - Parameters:
+    ///     - model: The original model (auto encoder structure) to train.
+    ///     - size: Size of one image (height and width are the same).
+    ///
+    init(model: Model, size: Int) throws
+    {
+        _size = size
+
+        if size > _originalSize || size < 2
+        {
+            throw TrainerError.Size
+        }
+
+        guard let firstLayer = model.layers.first as? Input2D else
+        {
+            throw TrainerError.Structural
+        }
+
+        let height = firstLayer.height
+        let width = firstLayer.width
+        if height != _size || width != _size
+        {
+            throw TrainerError.Size
+        }
+
+        _baseModel = model
+    }
+
+    ///
+    /// Create the final model (containing the original one, plus some additional layers) to train.
+    ///
+    /// - Returns: The final model to train.
+    ///
+    private func _buildModel() -> Model
+    {
+        let context = ModelContext(name: "Final", models: [_baseModel])
+        let params = GrAI.Model.Params(context: context)
+
+        _ = MSE2D(
+            layerPrev: _baseModel.layers.last as! Layer2D,
+            params: params
+        )
+
+        var model = Model(name: "Final")
+        model.layers = _baseModel.layers + context.model.layers
+        model = Model(model: model, modelsPrev: [])
+
+        return model
+    }
+
+    ///
+    /// Create a resizer.
+    ///
+    /// - Returns: The resizer model.
+    ///
+    private func _buildResizer() -> Model?
+    {
+        if _size != _originalSize
+        {
+            let context = ModelContext(name: "Resizer", models: [])
+            let params = GrAI.Model.Params(context: context)
+
+            var layer: Layer2D = Input2D(
+                nbChannels: 3,
+                width: _originalSize,
+                height: _originalSize,
+                params: params
+            )
+            layer = ResizeBilinear(
+                layerPrev: layer,
+                dimension: _size,
+                params: params
+            )
+            return Model(model: context.model, modelsPrev: [])
+        }
+        else
+        {
+            return nil
+        }
+    }
+
+    ///
+    /// Get optimizer parameters for model training.
+    ///
+    /// - Parameter nbLoops: Number of steps per epoch.
+    /// - Returns: The optimizer parameters.
+    ///
+    func _getOptimizerParams(nbLoops: Int) -> GrAI.Optimizer.Params
+    {
+        var optimizerParams = GrAI.Optimizer.Params()
+        optimizerParams.nbLoops = nbLoops
+
+        // Simple optimizer scheduler: always the same optimizer during
+        // the training.
+        optimizerParams.optimizer = ConstEpochsScheduler(
+            GrAI.Optimizer.Class.Adam
+        )
+
+        // Simple variable scheduler: always the same variable during
+        // the training.
+        optimizerParams.variables["alpha"] = ConstEpochsVar(
+            value: ConstVal(1e-3)
+        )
+        optimizerParams.variables["lambda"] = ConstEpochsVar(
+            value: ConstVal(1e-6)
+        )
+
+        // Other schedulers can be built thanks to `GrAI.Optimizer.Params`.
+        return optimizerParams
+    }
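Because the initializer is throwing, callers can surface the two `TrainerError` cases cleanly. A sketch (assuming `model` was built elsewhere with an `Input2D` first layer; sizes and counts are illustrative, and `run` is defined just below):

```swift
do
{
    // Rejected if `model`'s first layer is not an Input2D of size 16,
    // or if 16 fell outside [2, 32].
    let trainer = try CIFARAutoEncoderTrainer(model: model, size: 16)
    trainer.run(batchSize: 16, label: 8, nbEpochs: 2, keep: 500)
}
catch
{
    print(error)
}
```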
+
+    ///
+    /// Initialize dataset, model and optimizer parameters.
+    ///
+    /// - Parameters:
+    ///     - batchSize: The number of samples per batch of data.
+    ///     - label: The class of the CIFAR dataset to use.
+    ///     - keep: The number of elements to keep in the dataset.
+    ///
+    func initTrain(batchSize: Int, label: Int, keep: Int? = nil)
+    {
+        // Create dataset.
+        CIFAR.dumpTrain(
+            datasetPath: _outputDir + "/datasetTrain\(label)",
+            label: label,
+            size: _originalSize
+        )
+
+        // Load dataset.
+        _dataset = CIFAR.loadDataset(
+            datasetPath: _outputDir + "/datasetTrain\(label)",
+            size: _originalSize
+        )
+        _dataset.initSamples(batchSize: batchSize)
+        if let nbElems = keep
+        {
+            _dataset.keep(nbElems)
+        }
+
+        // Get optimizer parameters for iterating over batch size elements.
+        let params = _getOptimizerParams(nbLoops: batchSize)
+
+        // Build model.
+        _model = _buildModel()
+
+        // Build resizer model.
+        _resizer = _buildResizer()
+
+        // Initialize for training.
+        _model.initialize(params: params, phase: .Training)
+        _resizer?.initKernel()
+    }
+
+    ///
+    /// One training step.
+    ///
+    /// - Returns: The loss on the last training step.
+    ///
+    func step() -> Float
+    {
+        let firstLayer: Input2D = _model.layers.first as! Input2D
+        let lastLayer: MSE2D = _model.layers.last as! MSE2D
+
+        // Get data.
+        let samples = _dataset.getSamples()!
+        let batchSize = samples.count
+
+        // Pre processing.
+        let data = preprocess(
+            samples,
+            height: _originalSize,
+            width: _originalSize,
+            mean: _mean,
+            std: _std,
+            imageFormat: .Neuron
+        )
+
+        // Reset gradient validity for backward pass
+        // and update the batch size.
+        _model.updateKernel(batchSize: batchSize)
+
+        let dataLayer: Layer2D
+        // Resize data when `_size` is lower than `_originalSize`.
+        if let resizer = _resizer
+        {
+            let resizerFirstLayer = resizer.layers.first as! Input2D
+            dataLayer = resizer.layers.last as! Layer2D
+
+            resizer.updateKernel(batchSize: batchSize)
+
+            // Set data.
+            try! resizerFirstLayer.setDataGPU(
+                data,
+                batchSize: batchSize,
+                format: .Neuron
+            )
+
+            // Forward.
+            try! resizer.forward()
+
+            // Set resized data.
+            try! firstLayer.setDataGPU(dataLayer.outs, batchSize: batchSize)
+        }
+        else
+        {
+            // Set data.
+            try! firstLayer.setDataGPU(
+                data,
+                batchSize: batchSize,
+                format: .Neuron
+            )
+            dataLayer = firstLayer
+        }
+
+        // Forward.
+        try! _model.forward()
+
+        // Apply loss derivative: take into account the potential coarse image.
+        try! lastLayer.lossDerivativeGPU(
+            dataLayer.outs,
+            batchSize: batchSize
+        )
+
+        // Backward.
+        try! _model.backward()
+
+        // Update weights.
+        try! _model.update()
+
+        // Get loss result.
+        // Note that the backward pass is enabled by `lossDerivativeGPU`,
+        // whereas `getLossGPU` is just an indicator.
+        let loss = try! lastLayer.getLossGPU(
+            dataLayer.outs,
+            batchSize: batchSize
+        )
+
+        // Update internal step.
+        // This is not mandatory except if we used another
+        // optimizer scheduler: see `_getOptimizerParams`.
+        _model.incStep()
+
+        return loss
+    }
+
+    ///
+    /// Run the training on multiple steps and multiple epochs.
+    ///
+    /// - Parameters:
+    ///     - batchSize: The number of samples per batch of data.
+    ///     - label: The class of the CIFAR dataset to use.
+    ///     - nbEpochs: The number of epochs for the training to continue.
+    ///     - keep: The number of elements to keep in the dataset.
+    ///
+    func run(batchSize: Int, label: Int, nbEpochs: Int, keep: Int? = nil)
+    {
+        initTrain(
+            batchSize: batchSize,
+            label: label,
+            keep: keep
+        )
+
+        for epoch in 0..<nbEpochs
@@ ... @@
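The tail of `run` is truncated in this excerpt; from `initTrain` and `step()` it presumably just loops over epochs and dataset batches. An equivalent driving loop, written as external usage rather than a reconstruction of the missing lines (all counts are illustrative):

```swift
// Same effect as run(batchSize:label:nbEpochs:keep:), but with
// custom logging around each step.
let trainer = try! CIFARAutoEncoderTrainer(model: model, size: 32)
trainer.initTrain(batchSize: 16, label: 8, keep: 1000)

for epoch in 0..<5
{
    var loss: Float = 0.0
    for _ in 0..<62   // ~ 1000 elements / batch size 16, whole batches only
    {
        loss = trainer.step()
    }
    print("Epoch \(epoch): loss \(loss).")
}
```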
diff --git a/Tests/GrAIExamples/Base/SimpleAutoEncoder.swift b/Tests/GrAIExamples/Base/SimpleAutoEncoder.swift
new file mode 100644
--- /dev/null
+++ b/Tests/GrAIExamples/Base/SimpleAutoEncoder.swift
@@ ... @@
+    static func build(size: Int) -> Model
+    {
+        let context = ModelContext(name: "SimpleAutoEncoder", curID: 0)
+        let params = GrAI.Model.Params(context: context)
+
+        var layer: Layer2D
+        layer = Input2D(
+            nbChannels: 3,
+            width: size,
+            height: size,
+            params: params
+        )
+
+        layer = Convolution2D(
+            layerPrev: layer,
+            size: 3, nbChannels: 12, stride: 2,
+            activation: ReLU.str, biases: true, bn: false,
+            params: params
+        )
+        layer = Convolution2D(
+            layerPrev: layer,
+            size: 3, nbChannels: 24, stride: 2,
+            activation: ReLU.str, biases: true, bn: false,
+            params: params
+        )
+        layer = Convolution2D(
+            layerPrev: layer,
+            size: 3, nbChannels: 48, stride: 2,
+            activation: ReLU.str, biases: true, bn: false,
+            params: params
+        )
+
+        layer = Deconvolution2D(
+            layerPrev: layer,
+            size: 2, nbChannels: 24, stride: 2,
+            activation: nil, biases: true, bn: false,
+            params: params
+        )
+        layer = Deconvolution2D(
+            layerPrev: layer,
+            size: 2, nbChannels: 12, stride: 2,
+            activation: nil, biases: true, bn: false,
+            params: params
+        )
+        layer = Deconvolution2D(
+            layerPrev: layer,
+            size: 2, nbChannels: 3, stride: 2,
+            activation: Sigmoid.str, biases: true, bn: false,
+            params: params
+        )
+
+        let model = Model(model: context.model, modelsPrev: [])
+
+        // Load weights from `PyTorch`.
+        let pythonLib = Python.import("python_lib")
+        let data = pythonLib.load_simple_auto_encoder_weights()
+
+        let weights = [[Float]](data.tuple2.0)!
+
+        // Apply weights on the `GrAIdient` model's layers.
+        var cur = 0
+        for num_layer in 0..<model.layers.count
@@ ... @@
diff --git a/Tests/GrAIExamples/Base/python_lib/cifar.py b/Tests/GrAIExamples/Base/python_lib/cifar.py
--- a/Tests/GrAIExamples/Base/python_lib/cifar.py
+++ b/Tests/GrAIExamples/Base/python_lib/cifar.py
@@ ... @@
+def next_tensor_CIFAR(iterator) -> Optional[torch.Tensor]:
+    """
+    Load next data from a CIFAR iterator.
+
+    Parameters
+    ----------
+    iterator
+        The CIFAR dataset iterator.
+
+    Returns
+    -------
+    torch.Tensor
+        The images tensor with inner shape:
+        (batch, channel, height, width).
+    """
+    try:
+        samples, _ = next(iterator)
+    except StopIteration:
+        return None
+    return samples
+
+
+def next_data_CIFAR(iterator) -> Tuple[List[float], int]:
+    """
+    Load and flatten next data from a CIFAR iterator.
+
+    Parameters
+    ----------
+    iterator
+        The CIFAR dataset iterator.
+
+    Returns
+    -------
+    List[float]
+        The list of flattened images with inner shape:
+        (batch, channel, height, width).
+    int
+        The batch size of the data.
+    """
+    samples = next_tensor_CIFAR(iterator)
+    if samples is not None:
+        return samples.flatten().tolist(), len(samples)
+    else:
+        return [], 0
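`next_data_CIFAR` hands Swift a flat buffer in (batch, channel, height, width) order. A small helper showing the index arithmetic that order implies (the function name and signature are illustrative, not part of the PR):

```swift
// Read one pixel value out of the flattened buffer returned by
// CIFAR.getSamples. Row-major over (elem, channel, row, col).
func pixelValue(
    _ samples: [Float],
    elem: Int, channel: Int, row: Int, col: Int,
    nbChannels: Int, height: Int, width: Int) -> Float
{
    let offset = ((elem * nbChannels + channel) * height + row) * width + col
    return samples[offset]
}
```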
diff --git a/Tests/GrAIExamples/Base/python_lib/model.py b/Tests/GrAIExamples/Base/python_lib/model.py
new file mode 100644
index 00000000..f3753138
--- /dev/null
+++ b/Tests/GrAIExamples/Base/python_lib/model.py
@@ -0,0 +1,72 @@
+import torch
+
+
+class SimpleAutoEncoder(torch.nn.Module):
+    """
+    Simple auto encoder model.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.encoder = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                3, 12,
+                kernel_size=3, stride=2, padding=1,
+                bias=True
+            ),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(
+                12, 24,
+                kernel_size=3, stride=2, padding=1,
+                bias=True
+            ),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(
+                24, 48,
+                kernel_size=3, stride=2, padding=1,
+                bias=True
+            ),
+            torch.nn.ReLU(),
+        )
+        self.decoder = torch.nn.Sequential(
+            torch.nn.ConvTranspose2d(48, 24, kernel_size=2, stride=2),
+            torch.nn.ConvTranspose2d(24, 12, kernel_size=2, stride=2),
+            torch.nn.ConvTranspose2d(12, 3, kernel_size=2, stride=2),
+            torch.nn.Sigmoid(),
+        )
+
+        self.encoder.apply(self.weight_init)
+        self.decoder.apply(self.weight_init)
+
+    @staticmethod
+    def weight_init(module: torch.nn.Module):
+        """
+        Initialize weights and biases.
+
+        Parameters
+        ----------
+        module: torch.nn.Module
+            The module to initialize.
+        """
+        if isinstance(module, torch.nn.Conv2d) or \
+           isinstance(module, torch.nn.ConvTranspose2d) or \
+           isinstance(module, torch.nn.Linear):
+            torch.nn.init.xavier_normal_(module.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The input tensor.
+
+        Returns
+        -------
+        _: torch.Tensor
+            The output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
diff --git a/Tests/GrAIExamples/Base/python_lib/trainer.py b/Tests/GrAIExamples/Base/python_lib/trainer.py
new file mode 100644
index 00000000..4a91aeca
--- /dev/null
+++ b/Tests/GrAIExamples/Base/python_lib/trainer.py
@@ -0,0 +1,72 @@
+import torch
+from typing import Optional
+
+from python_lib.cifar import (
+    iter_CIFAR,
+    next_tensor_CIFAR,
+)
+from python_lib.model import SimpleAutoEncoder
+
+
+def train_simple_auto_encoder(
+    batch_size: int,
+    label: int
+):
+    """
+    Build a simple auto encoder trainer.
+
+    Parameters
+    ----------
+    batch_size: int
+        The batch size.
+    label: int
+        The label we want the data associated to.
+
+    Returns
+    -------
+    A trainer on a simple auto encoder model.
+    """
+    torch.manual_seed(42)
+    model = SimpleAutoEncoder().cpu()
+
+    criterion = torch.nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    iter_data = iter_CIFAR(
+        train=True,
+        batch_size=batch_size,
+        label=label,
+        shuffle=False
+    )
+
+    while True:
+        samples = next_tensor_CIFAR(iter_data)
+        x = model(samples)
+        loss = criterion(x, samples)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        yield float(loss.detach().numpy())
+
+
+def step_simple_auto_encoder(trainer) -> Optional[float]:
+    """
+    Compute next loss from the simple auto encoder trainer.
+
+    Parameters
+    ----------
+    trainer
+        The auto encoder trainer.
+
+    Returns
+    -------
+    float
+        The loss computed.
+    """
+    try:
+        loss = next(trainer)
+    except StopIteration:
+        return None
+    return loss
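`train_simple_auto_encoder` is a generator, so the Swift side can step it lazily and compare losses against the GrAIdient run. A sketch of that driving code (assuming the PythonKit bridge already used by `Python.import` elsewhere in this diff; the step count is illustrative):

```swift
import PythonKit

// Step the PyTorch-side trainer from Swift and log its losses.
let pythonLib = Python.import("python_lib")
let pyTrainer = pythonLib.train_simple_auto_encoder(16, 8)

for step in 0..<10
{
    let loss = Float(pythonLib.step_simple_auto_encoder(pyTrainer))!
    print("PyTorch step \(step): loss \(loss).")
}
```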
+ """ + layers_weights: List[List[float]] = [] + layers_dims: List[List[int]] = [] + for module in modules: + submodules = list(module.children()) + if len(submodules) > 0: + (weights_list, dims_list) = _extract_and_transpose_weights( + submodules + ) + layers_weights += weights_list + layers_dims += dims_list + + else: + if hasattr(module, "weight"): + if isinstance(module, torch.nn.ConvTranspose2d): + weights = np.transpose( + module.weight.detach().numpy(), (1, 0, 2, 3) + ) + weights_list, dims_list = _flatten_weights(weights) + + else: + weights = module.weight.detach().numpy() + weights_list, dims_list = _flatten_weights(weights) + + layers_weights.append(weights_list) + layers_dims.append(dims_list) + + if hasattr(module, "bias"): + weights = module.bias.detach().numpy() + weights_list, dims_list = _flatten_weights(weights) + + layers_weights.append(weights_list) + layers_dims.append(dims_list) + + return layers_weights, layers_dims + + +def load_simple_auto_encoder_weights( +) -> Tuple[List[List[float]], List[List[int]]]: + """ + Get weights and biases for simple auto encoder model. + + Returns + ------- + (_, _): List[List[float]], List[List[int]] + The flattened weights, their shape. + """ + torch.manual_seed(42) + model = SimpleAutoEncoder() + return _extract_and_transpose_weights(list(model.children())) diff --git a/Tests/GrAIExamples/Base/setup.py b/Tests/GrAIExamples/Base/setup.py index ee5f51d9..ca515733 100644 --- a/Tests/GrAIExamples/Base/setup.py +++ b/Tests/GrAIExamples/Base/setup.py @@ -7,6 +7,8 @@ author='Jean-François Reboud', license='MIT', install_requires=[ + "torch==1.10.1", + "torchvision==0.11.2", "numpy==1.23.1", "opencv-python==4.6.0.66" ], diff --git a/Tests/GrAIExamples/CIFARTests.swift b/Tests/GrAIExamples/CIFARTests.swift index 2a8ea985..5fd7bc9a 100644 --- a/Tests/GrAIExamples/CIFARTests.swift +++ b/Tests/GrAIExamples/CIFARTests.swift @@ -18,6 +18,11 @@ final class CIFARTests: XCTestCase /// Size of one image (height and width are the same). let _size = 32 + /// Mean of the preprocessing to apply to data. + let _mean: (Float, Float, Float) = (123.675, 116.28, 103.53) + /// Deviation of the preprocessing to apply to data. + let _std: (Float, Float, Float) = (58.395, 57.12, 57.375) + /// Initialize test. override func setUp() { @@ -108,7 +113,7 @@ final class CIFARTests: XCTestCase XCTAssert(nbLoops == cifar.nbLoops) } - /// Test4: dump testing dataset and load it.. + /// Test4: dump testing dataset and load it. func test4_DumpTest() { let datasetPath = _outputDir + "/datasetTest" @@ -122,4 +127,66 @@ final class CIFARTests: XCTestCase size: _size ) } + + /// Test5: iterate on CIFAR, preprocess and compare with PyTorch results. + func test5_PreprocessSamples() + { + let cifar = CIFAR.loadDataset( + datasetPath: _outputDir + "/datasetTrain", + size: _size + ) + cifar.initSamples(batchSize: _batchSize) + + let iterator = CIFAR.buildIterator( + train: true, + batchSize: _batchSize, + label: 0, + shuffle: false + ) + + var nbLoops = 0 + var lastLoop = false + var batchSize = 0 + var samples2 = [Float]() + + while let samples1 = cifar.getSamples() + { + (samples2, batchSize) = CIFAR.getSamples(iterator) + + XCTAssert(!lastLoop) + if samples1.count != _batchSize + { + lastLoop = true + } + else + { + XCTAssert(samples1.count == _batchSize) + XCTAssert(batchSize == _batchSize) + } + + // Pre processing. 
diff --git a/Tests/GrAIExamples/CIFARTests.swift b/Tests/GrAIExamples/CIFARTests.swift
index 2a8ea985..5fd7bc9a 100644
--- a/Tests/GrAIExamples/CIFARTests.swift
+++ b/Tests/GrAIExamples/CIFARTests.swift
@@ -18,6 +18,11 @@ final class CIFARTests: XCTestCase
     /// Size of one image (height and width are the same).
     let _size = 32
 
+    /// Mean of the preprocessing to apply to data.
+    let _mean: (Float, Float, Float) = (123.675, 116.28, 103.53)
+    /// Deviation of the preprocessing to apply to data.
+    let _std: (Float, Float, Float) = (58.395, 57.12, 57.375)
+
     /// Initialize test.
     override func setUp()
     {
@@ -108,7 +113,7 @@ final class CIFARTests: XCTestCase
         XCTAssert(nbLoops == cifar.nbLoops)
     }
 
-    /// Test4: dump testing dataset and load it..
+    /// Test4: dump testing dataset and load it.
     func test4_DumpTest()
     {
         let datasetPath = _outputDir + "/datasetTest"
@@ -122,4 +127,66 @@ final class CIFARTests: XCTestCase
             size: _size
         )
     }
+
+    /// Test5: iterate on CIFAR, preprocess and compare with PyTorch results.
+    func test5_PreprocessSamples()
+    {
+        let cifar = CIFAR.loadDataset(
+            datasetPath: _outputDir + "/datasetTrain",
+            size: _size
+        )
+        cifar.initSamples(batchSize: _batchSize)
+
+        let iterator = CIFAR.buildIterator(
+            train: true,
+            batchSize: _batchSize,
+            label: 0,
+            shuffle: false
+        )
+
+        var nbLoops = 0
+        var lastLoop = false
+        var batchSize = 0
+        var samples2 = [Float]()
+
+        while let samples1 = cifar.getSamples()
+        {
+            (samples2, batchSize) = CIFAR.getSamples(iterator)
+
+            XCTAssert(!lastLoop)
+            if samples1.count != _batchSize
+            {
+                lastLoop = true
+            }
+            else
+            {
+                XCTAssert(samples1.count == _batchSize)
+                XCTAssert(batchSize == _batchSize)
+            }
+
+            // Pre processing.
+            let data: [Float] = preprocess(
+                samples1,
+                height: _size,
+                width: _size,
+                mean: _mean,
+                std: _std,
+                imageFormat: .Neuron
+            )
+
+            for (elem1, elem2) in zip(data, samples2)
+            {
+                XCTAssertEqual(elem1, elem2, accuracy: 0.0001)
+            }
+            nbLoops += 1
+        }
+
+        print("Number of loops per epoch: " + String(nbLoops))
+        XCTAssert(nbLoops == cifar.nbLoops)
+        XCTAssert(cifar.getSamples() == nil)
+
+        (samples2, batchSize) = CIFAR.getSamples(iterator)
+        XCTAssert(samples2.count == 0)
+        XCTAssert(batchSize == 0)
+    }
 }
diff --git a/Tests/GrAIExamples/TransformerExample.swift b/Tests/GrAIExamples/TransformerExample.swift
index b4d7b44e..57b2ff84 100644
--- a/Tests/GrAIExamples/TransformerExample.swift
+++ b/Tests/GrAIExamples/TransformerExample.swift
@@ -8,7 +8,7 @@
 import XCTest
 import GrAIdient
 
-/// Test that we can train a simple Vision Transformer model on the CIFAR dataset.
+/// Train a simple Vision Transformer model on the CIFAR dataset.
 final class TransformerExample: XCTestCase
 {
     /// Directory to dump outputs from the tests.
diff --git a/Tests/GrAIExamples/VGGExample.swift b/Tests/GrAIExamples/VGGExample.swift
index 19ef7c7b..a6221ed6 100644
--- a/Tests/GrAIExamples/VGGExample.swift
+++ b/Tests/GrAIExamples/VGGExample.swift
@@ -8,7 +8,7 @@
 import XCTest
 import GrAIdient
 
-/// Test that we can train a simple VGG model on the CIFAR dataset.
+/// Train a simple VGG model on the CIFAR dataset.
 final class VGGExample: XCTestCase
 {
     /// Directory to dump outputs from the tests.
diff --git a/Tests/GrAITorchTests/GrAITorchTests.swift b/Tests/GrAITorchTests/GrAITorchTests.swift
index aac790d7..49519b23 100644
--- a/Tests/GrAITorchTests/GrAITorchTests.swift
+++ b/Tests/GrAITorchTests/GrAITorchTests.swift
@@ -8,7 +8,7 @@
 import XCTest
 import GrAIdient
 
-/// Compare models created by GrAIdient and PyTorch.
+/// Compare models created in GrAIdient and PyTorch.
 final class GrAITorchTests: XCTestCase
 {
     /// Size of one image (height and width are the same).
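`test5_PreprocessSamples` only passes if both sides normalize identically. A sketch of the per-channel transform `preprocess` is assumed to apply, given that `_mean`/`_std` are expressed on the 0-255 pixel scale (the exact implementation lives in the test target's utilities, not in this diff):

```swift
// Assumed normalization: out = (pixel - mean_c) / std_c, per channel c,
// with raw CIFAR pixel values in [0, 255].
let mean: (Float, Float, Float) = (123.675, 116.28, 103.53)
let std: (Float, Float, Float) = (58.395, 57.12, 57.375)

let red: Float = 255
let normalized = (red - mean.0) / std.0
print(normalized)   // ~ 2.249 for a fully saturated red channel
```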