From 3b32c95a3d326862532178693344317963b6c7f3 Mon Sep 17 00:00:00 2001 From: Spnetic-5 Date: Wed, 9 Aug 2023 22:51:52 -0400 Subject: [PATCH 1/9] Added batch norm layer modules --- src/nf/nf_batch_norm_layer.f90 | 108 +++++++++++++++++++++++ src/nf/nf_batch_norm_layer_submodule.f90 | 106 ++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 src/nf/nf_batch_norm_layer.f90 create mode 100644 src/nf/nf_batch_norm_layer_submodule.f90 diff --git a/src/nf/nf_batch_norm_layer.f90 b/src/nf/nf_batch_norm_layer.f90 new file mode 100644 index 00000000..df7623de --- /dev/null +++ b/src/nf/nf_batch_norm_layer.f90 @@ -0,0 +1,108 @@ +module nf_batch_norm_layer + + !! This module provides a batch normalization `batch_norm_layer` type. + + use nf_base_layer, only: base_layer + implicit none + + private + public :: batch_norm_layer + + type, extends(base_layer) :: batch_norm_layer + + integer :: size + real, allocatable :: gamma(:) + real, allocatable :: beta(:) + real, allocatable :: running_mean(:) + real, allocatable :: running_var(:) + real, allocatable :: input(:,:) + real, allocatable :: output(:,:) + real, allocatable :: gamma_grad(:) + real, allocatable :: beta_grad(:) + real, allocatable :: input_grad(:,:) + + contains + + procedure :: forward + procedure :: backward + procedure :: get_gradients + procedure :: get_num_params + procedure :: get_params + procedure :: init + procedure :: set_params + + end type batch_norm_layer + + interface batch_norm_layer + pure module function batch_norm_layer_cons(size) result(res) + !! `batch_norm_layer` constructor function + integer, intent(in) :: size + type(batch_norm_layer) :: res + end function batch_norm_layer_cons + end interface batch_norm_layer + + interface + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(batch_norm_layer), intent(in out) :: self + !! A `batch_norm_layer` instance + integer, intent(in) :: input_shape(:) + !! Input layer dimensions + end subroutine init + + pure module subroutine forward(self, input) + !! Apply a forward pass on the `batch_normalization` layer. + class(batch_norm_layer), intent(in out) :: self + !! A `batch_norm_layer` instance + real, intent(in) :: input(:,:) + !! Input data + end subroutine forward + + pure module subroutine backward(self, input, gradient) + !! Apply a backward pass on the `batch_normalization` layer. + class(batch_norm_layer), intent(in out) :: self + !! A `batch_norm_layer` instance + real, intent(in) :: input(:,:) + !! Input data (previous layer) + real, intent(in) :: gradient(:,:) + !! Gradient (next layer) + end subroutine backward + + pure module function get_num_params(self) result(num_params) + !! Get the number of parameters in the layer. + class(batch_norm_layer), intent(in) :: self + !! A `batch_norm_layer` instance + integer :: num_params + !! Number of parameters + end function get_num_params + + pure module function get_params(self) result(params) + !! Return the parameters (gamma, beta, running_mean, running_var) of this layer. + class(batch_norm_layer), intent(in) :: self + !! A `batch_norm_layer` instance + real, allocatable :: params(:) + !! Parameters to get + end function get_params + + pure module function get_gradients(self) result(gradients) + !! Return the gradients of this layer. + class(batch_norm_layer), intent(in) :: self + !! A `batch_norm_layer` instance + real, allocatable :: gradients(:) + !! 
Gradients to get + end function get_gradients + + module subroutine set_params(self, params) + !! Set the parameters of the layer. + class(batch_norm_layer), intent(in out) :: self + !! A `batch_norm_layer` instance + real, intent(in) :: params(:) + !! Parameters to set + end subroutine set_params + + end interface + +end module nf_batch_norm_layer diff --git a/src/nf/nf_batch_norm_layer_submodule.f90 b/src/nf/nf_batch_norm_layer_submodule.f90 new file mode 100644 index 00000000..ca7e5a0e --- /dev/null +++ b/src/nf/nf_batch_norm_layer_submodule.f90 @@ -0,0 +1,106 @@ +submodule(nf_batch_norm_layer) nf_batch_norm_layer_submodule + + use nf_base_layer, only: base_layer + implicit none + +contains + + pure module function batch_norm_layer_cons(size) result(res) + implicit none + integer, intent(in) :: size + type(batch_norm_layer) :: res + + res % size = size + allocate(res % gamma(size), source=1.0) + allocate(res % beta(size)) + allocate(res % running_mean(size), source=0.0) + allocate(res % running_var(size), source=1.0) + allocate(res % input(size, size)) + allocate(res % output(size, size)) + allocate(res % gamma_grad(size)) + allocate(res % beta_grad(size)) + allocate(res % input_grad(size, size)) + + end function batch_norm_layer_cons + + module subroutine init(self, input_shape) + implicit none + class(batch_norm_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % input = 0 + self % output = 0 + + ! Initialize gamma, beta, running_mean, and running_var + self % gamma = 1.0 + self % beta = 0.0 + self % running_mean = 0.0 + self % running_var = 1.0 + + end subroutine init + + pure module subroutine forward(self, input) + implicit none + class(batch_norm_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, allocatable :: normalized_input(:,:) + + ! Store input for backward pass + self % input = input + + ! Calculate the normalized input + normalized_input = (input - reshape(self % running_mean, shape(input, 1))) * & + reshape(self % gamma, shape(input, 1)) / & + sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8) + + ! Batch normalization forward pass + self % output = normalized_input + reshape(self % beta, shape(input, 1)) + + ! Deallocate temporary array + deallocate(normalized_input) + + end subroutine forward + + pure module subroutine backward(self, input, gradient) + implicit none + class(batch_norm_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + + ! Calculate gradients for gamma, beta + self % gamma_grad = sum(gradient * (input - reshape(self % running_mean, shape(input, 1))) / & + sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8), dim=2) + self % beta_grad = sum(gradient, dim=2) + + ! 
Calculate gradients for input + self % input_grad = gradient * reshape(self % gamma, shape(gradient)) / & + sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8) + + end subroutine backward + + pure module function get_num_params(self) result(num_params) + class(batch_norm_layer), intent(in) :: self + integer :: num_params + num_params = 2 * self % size + end function get_num_params + + pure module function get_params(self) result(params) + class(batch_norm_layer), intent(in) :: self + real, allocatable :: params(:) + params = [self % gamma, self % beta] + end function get_params + + pure module function get_gradients(self) result(gradients) + class(batch_norm_layer), intent(in) :: self + real, allocatable :: gradients(:) + gradients = [self % gamma_grad, self % beta_grad] + end function get_gradients + + module subroutine set_params(self, params) + class(batch_norm_layer), intent(in out) :: self + real, intent(in) :: params(:) + self % gamma = params(1:self % size) + self % beta = params(self % size+1:2*self % size) + end subroutine set_params + +end submodule nf_batch_norm_layer_submodule From 1a0ff08f173bd347f482dff7c35a05f69ca6de0c Mon Sep 17 00:00:00 2001 From: Spnetic-5 Date: Wed, 9 Aug 2023 22:57:02 -0400 Subject: [PATCH 2/9] Update cmake --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index bd48e993..f9f67edc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ add_library(neural src/nf.f90 src/nf/nf_activation.f90 src/nf/nf_base_layer.f90 + src/nf/nf_batch_norm_layer.f90 + src/nf/nf_batch_norm_layer_submodule.f90 src/nf/nf_conv2d_layer.f90 src/nf/nf_conv2d_layer_submodule.f90 src/nf/nf_datasets.f90 From e4d8e1e9c156e57a19add5914d9002ad32cbdabd Mon Sep 17 00:00:00 2001 From: Spnetic-5 Date: Thu, 10 Aug 2023 19:00:23 -0400 Subject: [PATCH 3/9] Update forward & backward formulae --- src/nf/nf_batch_norm_layer_submodule.f90 | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/nf/nf_batch_norm_layer_submodule.f90 b/src/nf/nf_batch_norm_layer_submodule.f90 index ca7e5a0e..7ebc3969 100644 --- a/src/nf/nf_batch_norm_layer_submodule.f90 +++ b/src/nf/nf_batch_norm_layer_submodule.f90 @@ -49,12 +49,12 @@ pure module subroutine forward(self, input) self % input = input ! Calculate the normalized input - normalized_input = (input - reshape(self % running_mean, shape(input, 1))) * & - reshape(self % gamma, shape(input, 1)) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8) + normalized_input = (input - reshape(self % running_mean, shape(input, 1))) / & + sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8) ! Batch normalization forward pass - self % output = normalized_input + reshape(self % beta, shape(input, 1)) + self % output = (reshape(self % gamma, shape(input, 1)) * & + normalized_input) + reshape(self % beta, shape(input, 1)) ! Deallocate temporary array deallocate(normalized_input) @@ -69,12 +69,12 @@ pure module subroutine backward(self, input, gradient) ! Calculate gradients for gamma, beta self % gamma_grad = sum(gradient * (input - reshape(self % running_mean, shape(input, 1))) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8), dim=2) + sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8), dim=2) self % beta_grad = sum(gradient, dim=2) ! 
Calculate gradients for input - self % input_grad = gradient * reshape(self % gamma, shape(gradient)) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1.0e-8) + self % input_grad = gradient * reshape(self % gamma, shape(input, 1)) / & + sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8) end subroutine backward From 42335f18a7196efc5bf608bed11a01d03279d147 Mon Sep 17 00:00:00 2001 From: Spnetic-5 Date: Sat, 19 Aug 2023 23:03:13 -0400 Subject: [PATCH 4/9] Added draft test for batch norm layer --- src/nf.f90 | 2 +- src/nf/nf_batch_norm_layer.f90 | 6 +- src/nf/nf_batch_norm_layer_submodule.f90 | 31 ++++---- src/nf/nf_layer_constructors.f90 | 21 ++++- src/nf/nf_layer_constructors_submodule.f90 | 8 ++ src/nf/nf_layer_submodule.f90 | 1 + src/nf/nf_network_submodule.f90 | 2 +- test/CMakeLists.txt | 1 + test/test_batch_norm_layer.f90 | 89 ++++++++++++++++++++++ 9 files changed, 139 insertions(+), 22 deletions(-) create mode 100644 test/test_batch_norm_layer.f90 diff --git a/src/nf.f90 b/src/nf.f90 index eb2a903a..3c10f476 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape + batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape use nf_network, only: network use nf_optimizers, only: sgd, rmsprop, adam, adagrad use nf_activation, only: activation_function, elu, exponential, & diff --git a/src/nf/nf_batch_norm_layer.f90 b/src/nf/nf_batch_norm_layer.f90 index df7623de..cc34572f 100644 --- a/src/nf/nf_batch_norm_layer.f90 +++ b/src/nf/nf_batch_norm_layer.f90 @@ -10,7 +10,7 @@ module nf_batch_norm_layer type, extends(base_layer) :: batch_norm_layer - integer :: size + integer :: num_features real, allocatable :: gamma(:) real, allocatable :: beta(:) real, allocatable :: running_mean(:) @@ -34,9 +34,9 @@ module nf_batch_norm_layer end type batch_norm_layer interface batch_norm_layer - pure module function batch_norm_layer_cons(size) result(res) + pure module function batch_norm_layer_cons(num_features) result(res) !! 
`batch_norm_layer` constructor function - integer, intent(in) :: size + integer, intent(in) :: num_features type(batch_norm_layer) :: res end function batch_norm_layer_cons end interface batch_norm_layer diff --git a/src/nf/nf_batch_norm_layer_submodule.f90 b/src/nf/nf_batch_norm_layer_submodule.f90 index 7ebc3969..0df434a5 100644 --- a/src/nf/nf_batch_norm_layer_submodule.f90 +++ b/src/nf/nf_batch_norm_layer_submodule.f90 @@ -1,25 +1,24 @@ submodule(nf_batch_norm_layer) nf_batch_norm_layer_submodule - use nf_base_layer, only: base_layer implicit none contains - pure module function batch_norm_layer_cons(size) result(res) + pure module function batch_norm_layer_cons(num_features) result(res) implicit none - integer, intent(in) :: size + integer, intent(in) :: num_features type(batch_norm_layer) :: res - res % size = size - allocate(res % gamma(size), source=1.0) - allocate(res % beta(size)) - allocate(res % running_mean(size), source=0.0) - allocate(res % running_var(size), source=1.0) - allocate(res % input(size, size)) - allocate(res % output(size, size)) - allocate(res % gamma_grad(size)) - allocate(res % beta_grad(size)) - allocate(res % input_grad(size, size)) + res % num_features = num_features + allocate(res % gamma(num_features), source=1.0) + allocate(res % beta(num_features)) + allocate(res % running_mean(num_features), source=0.0) + allocate(res % running_var(num_features), source=1.0) + allocate(res % input(num_features, num_features)) + allocate(res % output(num_features, num_features)) + allocate(res % gamma_grad(num_features)) + allocate(res % beta_grad(num_features)) + allocate(res % input_grad(num_features, num_features)) end function batch_norm_layer_cons @@ -81,7 +80,7 @@ end subroutine backward pure module function get_num_params(self) result(num_params) class(batch_norm_layer), intent(in) :: self integer :: num_params - num_params = 2 * self % size + num_params = 2 * self % num_features end function get_num_params pure module function get_params(self) result(params) @@ -99,8 +98,8 @@ end function get_gradients module subroutine set_params(self, params) class(batch_norm_layer), intent(in out) :: self real, intent(in) :: params(:) - self % gamma = params(1:self % size) - self % beta = params(self % size+1:2*self % size) + self % gamma = params(1:self % num_features) + self % beta = params(self % num_features+1:2*self % num_features) end subroutine set_params end submodule nf_batch_norm_layer_submodule diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index ce9a7244..764cf9fe 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape + public :: batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape interface input @@ -106,6 +106,25 @@ pure module function flatten() result(res) !! Resulting layer instance end function flatten + pure module function batch_norm(num_features) result(res) + !! Batch normalization layer constructor. + !! + !! This layer is for adding batch normalization to the network. + !! A batch normalization layer can be used after conv2d or dense layers. + !! + !! Example: + !! + !! ``` + !! use nf, only :: batch_norm, layer + !! type(layer) :: batch_norm_layer + !! batch_norm_layer = batch_norm(num_features = 64) + !! ``` + integer, intent(in) :: num_features + !! Number of features in the Layer + type(layer) :: res + !! 
Resulting layer instance + end function batch_norm + pure module function conv2d(filters, kernel_size, activation) result(res) !! 2-d convolutional layer constructor. !! diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 002a83ba..c5661644 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -1,6 +1,7 @@ submodule(nf_layer_constructors) nf_layer_constructors_submodule use nf_layer, only: layer + use nf_batch_norm_layer, only: batch_norm_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_flatten_layer, only: flatten_layer @@ -14,6 +15,13 @@ contains + pure module function batch_norm(num_features) result(res) + integer, intent(in) :: num_features + type(layer) :: res + res % name = 'batch_norm' + allocate(res % p, source=batch_norm_layer(num_features)) + end function batch_norm + pure module function conv2d(filters, kernel_size, activation) result(res) integer, intent(in) :: filters integer, intent(in) :: kernel_size diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 07467643..1625a519 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -1,6 +1,7 @@ submodule(nf_layer) nf_layer_submodule use iso_fortran_env, only: stderr => error_unit + use nf_batch_norm_layer, only: batch_norm_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_flatten_layer, only: flatten_layer diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 5bafb7cf..388e4575 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -10,7 +10,7 @@ use nf_io_hdf5, only: get_hdf5_dataset use nf_keras, only: get_keras_h5_layers, keras_layer use nf_layer, only: layer - use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape + use nf_layer_constructors, only: batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape use nf_loss, only: quadratic_derivative use nf_optimizers, only: optimizer_base_type, sgd use nf_parallel, only: tile_indices diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 26646ec1..b8f7a091 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,6 +16,7 @@ foreach(execid cnn_from_keras conv2d_network optimizers + batch_norm_layer ) add_executable(test_${execid} test_${execid}.f90) target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS}) diff --git a/test/test_batch_norm_layer.f90 b/test/test_batch_norm_layer.f90 new file mode 100644 index 00000000..59b81245 --- /dev/null +++ b/test/test_batch_norm_layer.f90 @@ -0,0 +1,89 @@ +program test_batch_norm_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: batch_norm, input, layer + use nf_input3d_layer, only: input3d_layer + use nf_batch_norm_layer, only: batch_norm_layer + + implicit none + + type(layer) :: bn_layer, input_layer + integer, parameter :: num_features = 64 + real, allocatable :: sample_input(:,:) + real, allocatable :: output(:,:) + real, allocatable :: gradient(:,:) + integer, parameter :: input_shape(1) = [num_features] + real, allocatable :: gamma_grad(:), beta_grad(:) + real, parameter :: tolerance = 1e-7 + logical :: ok = .true. + + bn_layer = batch_norm(num_features) + + if (.not. bn_layer % name == 'batch_norm') then + ok = .false. + write(stderr, '(a)') 'batch_norm layer has its name set correctly.. 
failed' + end if + + if (bn_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'batch_norm layer should not be marked as initialized yet.. failed' + end if + + input_layer = input(input_shape) + call bn_layer % init(input_layer) + + if (.not. bn_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'batch_norm layer should now be marked as initialized.. failed' + end if + + if (.not. all(bn_layer % input_layer_shape == [num_features])) then + ok = .false. + write(stderr, '(a)') 'batch_norm layer input layer shape should be correct.. failed' + end if + + ! Initialize sample input and gradient + allocate(sample_input(num_features, 1)) + allocate(gradient(num_features, 1)) + sample_input = 1.0 + gradient = 2.0 + + ! Set input for the input layer + select type(this_layer => input_layer % p); type is(input3d_layer) + call this_layer % set(sample_input) + end select + + ! Initialize the batch normalization layer + bn_layer = batch_norm(num_features) + call bn_layer % init(input_layer) + + ! Perform forward and backward passes + call bn_layer % forward(input_layer) + call bn_layer % backward(input_layer, gradient) + + ! Retrieve output and check normalization + call bn_layer % get_output(output) + if (.not. all(abs(output - sample_input) < tolerance)) then + ok = .false. + write(stderr, '(a)') 'batch_norm layer output should be close to input.. failed' + end if + + ! Retrieve gamma and beta gradients + allocate(gamma_grad(num_features)) + allocate(beta_grad(num_features)) + call bn_layer % get_gradients(gamma_grad, beta_grad) + + if (.not. all(beta_grad == sum(gradient))) then + ok = .false. + write(stderr, '(a)') 'batch_norm layer beta gradients are incorrect.. failed' + end if + + ! Report test results + if (ok) then + print '(a)', 'test_batch_norm_layer: All tests passed.' + else + write(stderr, '(a)') 'test_batch_norm_layer: One or more tests failed.' 
+ stop 1 + end if + +end program test_batch_norm_layer From de67a889b1bc21ca7b73c8d7a3544b6085040cfa Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 24 Aug 2023 11:35:04 -0400 Subject: [PATCH 5/9] rename batch_norm -> batchnorm --- CMakeLists.txt | 4 +- src/nf.f90 | 2 +- ..._norm_layer.f90 => nf_batchnorm_layer.f90} | 56 +++++++++---------- ...e.f90 => nf_batchnorm_layer_submodule.f90} | 24 ++++---- src/nf/nf_layer_constructors.f90 | 12 ++-- src/nf/nf_layer_constructors_submodule.f90 | 10 ++-- src/nf/nf_layer_submodule.f90 | 2 +- src/nf/nf_network_submodule.f90 | 2 +- test/CMakeLists.txt | 2 +- ...orm_layer.f90 => test_batchnorm_layer.f90} | 30 +++++----- 10 files changed, 72 insertions(+), 72 deletions(-) rename src/nf/{nf_batch_norm_layer.f90 => nf_batchnorm_layer.f90} (63%) rename src/nf/{nf_batch_norm_layer_submodule.f90 => nf_batchnorm_layer_submodule.f90} (83%) rename test/{test_batch_norm_layer.f90 => test_batchnorm_layer.f90} (66%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f9f67edc..0fd1be6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,8 +26,8 @@ add_library(neural src/nf.f90 src/nf/nf_activation.f90 src/nf/nf_base_layer.f90 - src/nf/nf_batch_norm_layer.f90 - src/nf/nf_batch_norm_layer_submodule.f90 + src/nf/nf_batchnorm_layer.f90 + src/nf/nf_batchnorm_layer_submodule.f90 src/nf/nf_conv2d_layer.f90 src/nf/nf_conv2d_layer_submodule.f90 src/nf/nf_datasets.f90 diff --git a/src/nf.f90 b/src/nf.f90 index 3c10f476..b6b90b4f 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape + batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape use nf_network, only: network use nf_optimizers, only: sgd, rmsprop, adam, adagrad use nf_activation, only: activation_function, elu, exponential, & diff --git a/src/nf/nf_batch_norm_layer.f90 b/src/nf/nf_batchnorm_layer.f90 similarity index 63% rename from src/nf/nf_batch_norm_layer.f90 rename to src/nf/nf_batchnorm_layer.f90 index cc34572f..aa2826af 100644 --- a/src/nf/nf_batch_norm_layer.f90 +++ b/src/nf/nf_batchnorm_layer.f90 @@ -1,14 +1,14 @@ -module nf_batch_norm_layer +module nf_batchnorm_layer - !! This module provides a batch normalization `batch_norm_layer` type. + !! This module provides a batch normalization `batchnorm_layer` type. use nf_base_layer, only: base_layer implicit none private - public :: batch_norm_layer + public :: batchnorm_layer - type, extends(base_layer) :: batch_norm_layer + type, extends(base_layer) :: batchnorm_layer integer :: num_features real, allocatable :: gamma(:) @@ -31,15 +31,15 @@ module nf_batch_norm_layer procedure :: init procedure :: set_params - end type batch_norm_layer + end type batchnorm_layer - interface batch_norm_layer - pure module function batch_norm_layer_cons(num_features) result(res) - !! `batch_norm_layer` constructor function + interface batchnorm_layer + pure module function batchnorm_layer_cons(num_features) result(res) + !! `batchnorm_layer` constructor function integer, intent(in) :: num_features - type(batch_norm_layer) :: res - end function batch_norm_layer_cons - end interface batch_norm_layer + type(batchnorm_layer) :: res + end function batchnorm_layer_cons + end interface batchnorm_layer interface @@ -47,24 +47,24 @@ module subroutine init(self, input_shape) !! Initialize the layer data structures. !! !! This is a deferred procedure from the `base_layer` abstract type. 
- class(batch_norm_layer), intent(in out) :: self - !! A `batch_norm_layer` instance + class(batchnorm_layer), intent(in out) :: self + !! A `batchnorm_layer` instance integer, intent(in) :: input_shape(:) !! Input layer dimensions end subroutine init pure module subroutine forward(self, input) - !! Apply a forward pass on the `batch_normalization` layer. - class(batch_norm_layer), intent(in out) :: self - !! A `batch_norm_layer` instance + !! Apply a forward pass on the `batchnorm_layer`. + class(batchnorm_layer), intent(in out) :: self + !! A `batchnorm_layer` instance real, intent(in) :: input(:,:) !! Input data end subroutine forward pure module subroutine backward(self, input, gradient) - !! Apply a backward pass on the `batch_normalization` layer. - class(batch_norm_layer), intent(in out) :: self - !! A `batch_norm_layer` instance + !! Apply a backward pass on the `batchnorm_layer`. + class(batchnorm_layer), intent(in out) :: self + !! A `batchnorm_layer` instance real, intent(in) :: input(:,:) !! Input data (previous layer) real, intent(in) :: gradient(:,:) @@ -73,36 +73,36 @@ end subroutine backward pure module function get_num_params(self) result(num_params) !! Get the number of parameters in the layer. - class(batch_norm_layer), intent(in) :: self - !! A `batch_norm_layer` instance + class(batchnorm_layer), intent(in) :: self + !! A `batchnorm_layer` instance integer :: num_params !! Number of parameters end function get_num_params pure module function get_params(self) result(params) !! Return the parameters (gamma, beta, running_mean, running_var) of this layer. - class(batch_norm_layer), intent(in) :: self - !! A `batch_norm_layer` instance + class(batchnorm_layer), intent(in) :: self + !! A `batchnorm_layer` instance real, allocatable :: params(:) !! Parameters to get end function get_params pure module function get_gradients(self) result(gradients) !! Return the gradients of this layer. - class(batch_norm_layer), intent(in) :: self - !! A `batch_norm_layer` instance + class(batchnorm_layer), intent(in) :: self + !! A `batchnorm_layer` instance real, allocatable :: gradients(:) !! Gradients to get end function get_gradients module subroutine set_params(self, params) !! Set the parameters of the layer. - class(batch_norm_layer), intent(in out) :: self - !! A `batch_norm_layer` instance + class(batchnorm_layer), intent(in out) :: self + !! A `batchnorm_layer` instance real, intent(in) :: params(:) !! 
Parameters to set end subroutine set_params end interface -end module nf_batch_norm_layer +end module nf_batchnorm_layer diff --git a/src/nf/nf_batch_norm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90 similarity index 83% rename from src/nf/nf_batch_norm_layer_submodule.f90 rename to src/nf/nf_batchnorm_layer_submodule.f90 index 0df434a5..63a5398f 100644 --- a/src/nf/nf_batch_norm_layer_submodule.f90 +++ b/src/nf/nf_batchnorm_layer_submodule.f90 @@ -1,13 +1,13 @@ -submodule(nf_batch_norm_layer) nf_batch_norm_layer_submodule +submodule(nf_batchnorm_layer) nf_batchnorm_layer_submodule implicit none contains - pure module function batch_norm_layer_cons(num_features) result(res) + pure module function batchnorm_layer_cons(num_features) result(res) implicit none integer, intent(in) :: num_features - type(batch_norm_layer) :: res + type(batchnorm_layer) :: res res % num_features = num_features allocate(res % gamma(num_features), source=1.0) @@ -20,11 +20,11 @@ pure module function batch_norm_layer_cons(num_features) result(res) allocate(res % beta_grad(num_features)) allocate(res % input_grad(num_features, num_features)) - end function batch_norm_layer_cons + end function batchnorm_layer_cons module subroutine init(self, input_shape) implicit none - class(batch_norm_layer), intent(in out) :: self + class(batchnorm_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) self % input = 0 @@ -40,7 +40,7 @@ end subroutine init pure module subroutine forward(self, input) implicit none - class(batch_norm_layer), intent(in out) :: self + class(batchnorm_layer), intent(in out) :: self real, intent(in) :: input(:,:) real, allocatable :: normalized_input(:,:) @@ -62,7 +62,7 @@ end subroutine forward pure module subroutine backward(self, input, gradient) implicit none - class(batch_norm_layer), intent(in out) :: self + class(batchnorm_layer), intent(in out) :: self real, intent(in) :: input(:,:) real, intent(in) :: gradient(:,:) @@ -78,28 +78,28 @@ pure module subroutine backward(self, input, gradient) end subroutine backward pure module function get_num_params(self) result(num_params) - class(batch_norm_layer), intent(in) :: self + class(batchnorm_layer), intent(in) :: self integer :: num_params num_params = 2 * self % num_features end function get_num_params pure module function get_params(self) result(params) - class(batch_norm_layer), intent(in) :: self + class(batchnorm_layer), intent(in) :: self real, allocatable :: params(:) params = [self % gamma, self % beta] end function get_params pure module function get_gradients(self) result(gradients) - class(batch_norm_layer), intent(in) :: self + class(batchnorm_layer), intent(in) :: self real, allocatable :: gradients(:) gradients = [self % gamma_grad, self % beta_grad] end function get_gradients module subroutine set_params(self, params) - class(batch_norm_layer), intent(in out) :: self + class(batchnorm_layer), intent(in out) :: self real, intent(in) :: params(:) self % gamma = params(1:self % num_features) self % beta = params(self % num_features+1:2*self % num_features) end subroutine set_params -end submodule nf_batch_norm_layer_submodule +end submodule nf_batchnorm_layer_submodule diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 764cf9fe..b036f1bd 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape + public :: 
batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape interface input @@ -106,7 +106,7 @@ pure module function flatten() result(res) !! Resulting layer instance end function flatten - pure module function batch_norm(num_features) result(res) + pure module function batchnorm(num_features) result(res) !! Batch normalization layer constructor. !! !! This layer is for adding batch normalization to the network. @@ -115,15 +115,15 @@ pure module function batch_norm(num_features) result(res) !! Example: !! !! ``` - !! use nf, only :: batch_norm, layer - !! type(layer) :: batch_norm_layer - !! batch_norm_layer = batch_norm(num_features = 64) + !! use nf, only :: batchnorm, layer + !! type(layer) :: batchnorm_layer + !! batchnorm_layer = batchnorm(num_features = 64) !! ``` integer, intent(in) :: num_features !! Number of features in the Layer type(layer) :: res !! Resulting layer instance - end function batch_norm + end function batchnorm pure module function conv2d(filters, kernel_size, activation) result(res) !! 2-d convolutional layer constructor. diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index c5661644..914df2f7 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -1,7 +1,7 @@ submodule(nf_layer_constructors) nf_layer_constructors_submodule use nf_layer, only: layer - use nf_batch_norm_layer, only: batch_norm_layer + use nf_batchnorm_layer, only: batchnorm_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_flatten_layer, only: flatten_layer @@ -15,12 +15,12 @@ contains - pure module function batch_norm(num_features) result(res) + pure module function batchnorm(num_features) result(res) integer, intent(in) :: num_features type(layer) :: res - res % name = 'batch_norm' - allocate(res % p, source=batch_norm_layer(num_features)) - end function batch_norm + res % name = 'batchnorm' + allocate(res % p, source=batchnorm_layer(num_features)) + end function batchnorm pure module function conv2d(filters, kernel_size, activation) result(res) integer, intent(in) :: filters diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 1625a519..94d9d17e 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -1,7 +1,7 @@ submodule(nf_layer) nf_layer_submodule use iso_fortran_env, only: stderr => error_unit - use nf_batch_norm_layer, only: batch_norm_layer + use nf_batchnorm_layer, only: batchnorm_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_flatten_layer, only: flatten_layer diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 388e4575..ecff74d2 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -10,7 +10,7 @@ use nf_io_hdf5, only: get_hdf5_dataset use nf_keras, only: get_keras_h5_layers, keras_layer use nf_layer, only: layer - use nf_layer_constructors, only: batch_norm, conv2d, dense, flatten, input, maxpool2d, reshape + use nf_layer_constructors, only: batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape use nf_loss, only: quadratic_derivative use nf_optimizers, only: optimizer_base_type, sgd use nf_parallel, only: tile_indices diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b8f7a091..b4ee8202 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,7 +16,7 @@ foreach(execid cnn_from_keras conv2d_network optimizers - batch_norm_layer + batchnorm_layer ) 
add_executable(test_${execid} test_${execid}.f90) target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS}) diff --git a/test/test_batch_norm_layer.f90 b/test/test_batchnorm_layer.f90 similarity index 66% rename from test/test_batch_norm_layer.f90 rename to test/test_batchnorm_layer.f90 index 59b81245..253ad6f7 100644 --- a/test/test_batch_norm_layer.f90 +++ b/test/test_batchnorm_layer.f90 @@ -1,9 +1,9 @@ -program test_batch_norm_layer +program test_batchnorm_layer use iso_fortran_env, only: stderr => error_unit - use nf, only: batch_norm, input, layer + use nf, only: batchnorm, input, layer use nf_input3d_layer, only: input3d_layer - use nf_batch_norm_layer, only: batch_norm_layer + use nf_batchnorm_layer, only: batchnorm_layer implicit none @@ -17,16 +17,16 @@ program test_batch_norm_layer real, parameter :: tolerance = 1e-7 logical :: ok = .true. - bn_layer = batch_norm(num_features) + bn_layer = batchnorm(num_features) - if (.not. bn_layer % name == 'batch_norm') then + if (.not. bn_layer % name == 'batchnorm') then ok = .false. - write(stderr, '(a)') 'batch_norm layer has its name set correctly.. failed' + write(stderr, '(a)') 'batchnorm layer has its name set correctly.. failed' end if if (bn_layer % initialized) then ok = .false. - write(stderr, '(a)') 'batch_norm layer should not be marked as initialized yet.. failed' + write(stderr, '(a)') 'batchnorm layer should not be marked as initialized yet.. failed' end if input_layer = input(input_shape) @@ -34,12 +34,12 @@ program test_batch_norm_layer if (.not. bn_layer % initialized) then ok = .false. - write(stderr, '(a)') 'batch_norm layer should now be marked as initialized.. failed' + write(stderr, '(a)') 'batchnorm layer should now be marked as initialized.. failed' end if if (.not. all(bn_layer % input_layer_shape == [num_features])) then ok = .false. - write(stderr, '(a)') 'batch_norm layer input layer shape should be correct.. failed' + write(stderr, '(a)') 'batchnorm layer input layer shape should be correct.. failed' end if ! Initialize sample input and gradient @@ -54,7 +54,7 @@ program test_batch_norm_layer end select ! Initialize the batch normalization layer - bn_layer = batch_norm(num_features) + bn_layer = batchnorm(num_features) call bn_layer % init(input_layer) ! Perform forward and backward passes @@ -65,7 +65,7 @@ program test_batch_norm_layer call bn_layer % get_output(output) if (.not. all(abs(output - sample_input) < tolerance)) then ok = .false. - write(stderr, '(a)') 'batch_norm layer output should be close to input.. failed' + write(stderr, '(a)') 'batchnorm layer output should be close to input.. failed' end if ! Retrieve gamma and beta gradients @@ -75,15 +75,15 @@ program test_batch_norm_layer if (.not. all(beta_grad == sum(gradient))) then ok = .false. - write(stderr, '(a)') 'batch_norm layer beta gradients are incorrect.. failed' + write(stderr, '(a)') 'batchnorm layer beta gradients are incorrect.. failed' end if ! Report test results if (ok) then - print '(a)', 'test_batch_norm_layer: All tests passed.' + print '(a)', 'test_batchnorm_layer: All tests passed.' else - write(stderr, '(a)') 'test_batch_norm_layer: One or more tests failed.' + write(stderr, '(a)') 'test_batchnorm_layer: One or more tests failed.' 
stop 1 end if -end program test_batch_norm_layer +end program test_batchnorm_layer From b1e0d399def9d05cd0e6673ae547c51083c9f35f Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 24 Aug 2023 11:41:37 -0400 Subject: [PATCH 6/9] Just creating the batchnorm layer for now; actual tests TODO --- test/test_batchnorm_layer.f90 | 66 +++++++++++------------------------ 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/test/test_batchnorm_layer.f90 b/test/test_batchnorm_layer.f90 index 253ad6f7..473b22de 100644 --- a/test/test_batchnorm_layer.f90 +++ b/test/test_batchnorm_layer.f90 @@ -1,13 +1,12 @@ program test_batchnorm_layer use iso_fortran_env, only: stderr => error_unit - use nf, only: batchnorm, input, layer - use nf_input3d_layer, only: input3d_layer + use nf, only: batchnorm, layer use nf_batchnorm_layer, only: batchnorm_layer implicit none - type(layer) :: bn_layer, input_layer + type(layer) :: bn_layer integer, parameter :: num_features = 64 real, allocatable :: sample_input(:,:) real, allocatable :: output(:,:) @@ -29,54 +28,31 @@ program test_batchnorm_layer write(stderr, '(a)') 'batchnorm layer should not be marked as initialized yet.. failed' end if - input_layer = input(input_shape) - call bn_layer % init(input_layer) - - if (.not. bn_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'batchnorm layer should now be marked as initialized.. failed' - end if - - if (.not. all(bn_layer % input_layer_shape == [num_features])) then - ok = .false. - write(stderr, '(a)') 'batchnorm layer input layer shape should be correct.. failed' - end if - ! Initialize sample input and gradient allocate(sample_input(num_features, 1)) allocate(gradient(num_features, 1)) sample_input = 1.0 gradient = 2.0 - ! Set input for the input layer - select type(this_layer => input_layer % p); type is(input3d_layer) - call this_layer % set(sample_input) - end select - - ! Initialize the batch normalization layer - bn_layer = batchnorm(num_features) - call bn_layer % init(input_layer) - - ! Perform forward and backward passes - call bn_layer % forward(input_layer) - call bn_layer % backward(input_layer, gradient) - - ! Retrieve output and check normalization - call bn_layer % get_output(output) - if (.not. all(abs(output - sample_input) < tolerance)) then - ok = .false. - write(stderr, '(a)') 'batchnorm layer output should be close to input.. failed' - end if - - ! Retrieve gamma and beta gradients - allocate(gamma_grad(num_features)) - allocate(beta_grad(num_features)) - call bn_layer % get_gradients(gamma_grad, beta_grad) - - if (.not. all(beta_grad == sum(gradient))) then - ok = .false. - write(stderr, '(a)') 'batchnorm layer beta gradients are incorrect.. failed' - end if + !TODO run forward and backward passes directly on the batchnorm_layer instance + !TODO since we don't yet support tiying in with the input layer. + + !TODO Retrieve output and check normalization + !call bn_layer % get_output(output) + !if (.not. all(abs(output - sample_input) < tolerance)) then + ! ok = .false. + ! write(stderr, '(a)') 'batchnorm layer output should be close to input.. failed' + !end if + + !TODO Retrieve gamma and beta gradients + !allocate(gamma_grad(num_features)) + !allocate(beta_grad(num_features)) + !call bn_layer % get_gradients(gamma_grad, beta_grad) + + !if (.not. all(beta_grad == sum(gradient))) then + ! ok = .false. + ! write(stderr, '(a)') 'batchnorm layer beta gradients are incorrect.. failed' + !end if ! 
Report test results if (ok) then From e8d040a5c02f9ae5a493984644069a38d0776926 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 24 Aug 2023 12:03:27 -0400 Subject: [PATCH 7/9] Make epsilon a batchnorm variable --- src/nf/nf_batchnorm_layer.f90 | 1 + src/nf/nf_batchnorm_layer_submodule.f90 | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_batchnorm_layer.f90 b/src/nf/nf_batchnorm_layer.f90 index aa2826af..193d5ef3 100644 --- a/src/nf/nf_batchnorm_layer.f90 +++ b/src/nf/nf_batchnorm_layer.f90 @@ -20,6 +20,7 @@ module nf_batchnorm_layer real, allocatable :: gamma_grad(:) real, allocatable :: beta_grad(:) real, allocatable :: input_grad(:,:) + real :: epsilon = 1e-5 contains diff --git a/src/nf/nf_batchnorm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90 index 63a5398f..23998489 100644 --- a/src/nf/nf_batchnorm_layer_submodule.f90 +++ b/src/nf/nf_batchnorm_layer_submodule.f90 @@ -48,12 +48,12 @@ pure module subroutine forward(self, input) self % input = input ! Calculate the normalized input - normalized_input = (input - reshape(self % running_mean, shape(input, 1))) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8) + normalized_input = (input - reshape(self % running_mean, shape(input, 1))) & + / sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon) ! Batch normalization forward pass - self % output = (reshape(self % gamma, shape(input, 1)) * & - normalized_input) + reshape(self % beta, shape(input, 1)) + self % output = reshape(self % gamma, shape(input, 1)) * normalized_input & + + reshape(self % beta, shape(input, 1)) ! Deallocate temporary array deallocate(normalized_input) @@ -67,13 +67,13 @@ pure module subroutine backward(self, input, gradient) real, intent(in) :: gradient(:,:) ! Calculate gradients for gamma, beta - self % gamma_grad = sum(gradient * (input - reshape(self % running_mean, shape(input, 1))) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8), dim=2) + self % gamma_grad = sum(gradient * (input - reshape(self % running_mean, shape(input, 1))) & + / sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon), dim=2) self % beta_grad = sum(gradient, dim=2) ! Calculate gradients for input - self % input_grad = gradient * reshape(self % gamma, shape(input, 1)) / & - sqrt(reshape(self % running_var, shape(input, 1)) + 1e-8) + self % input_grad = gradient * reshape(self % gamma, shape(input, 1)) & + / sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon) end subroutine backward From 17b06101fcc94f8d5f635f42e86a8d56cfa068bf Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 24 Aug 2023 12:09:17 -0400 Subject: [PATCH 8/9] Use associate for normalized input in batchnorm forward pass --- src/nf/nf_batchnorm_layer_submodule.f90 | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_batchnorm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90 index 23998489..d31fca88 100644 --- a/src/nf/nf_batchnorm_layer_submodule.f90 +++ b/src/nf/nf_batchnorm_layer_submodule.f90 @@ -42,21 +42,25 @@ pure module subroutine forward(self, input) implicit none class(batchnorm_layer), intent(in out) :: self real, intent(in) :: input(:,:) - real, allocatable :: normalized_input(:,:) + !real, allocatable :: normalized_input(:,:) ! Store input for backward pass self % input = input - ! 
Calculate the normalized input - normalized_input = (input - reshape(self % running_mean, shape(input, 1))) & - / sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon) + associate( & + ! Normalize the input + normalized_input => (input - reshape(self % running_mean, shape(input, 1))) & + / sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon) & + ) + + ! Batch normalization forward pass + self % output = reshape(self % gamma, shape(input, 1)) * normalized_input & + + reshape(self % beta, shape(input, 1)) - ! Batch normalization forward pass - self % output = reshape(self % gamma, shape(input, 1)) * normalized_input & - + reshape(self % beta, shape(input, 1)) + end associate ! Deallocate temporary array - deallocate(normalized_input) + !deallocate(normalized_input) end subroutine forward From 7fb69f28c7f47d38feede10d6498e398a0aceb7a Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 24 Aug 2023 12:15:02 -0400 Subject: [PATCH 9/9] Remove unused code --- src/nf/nf_batchnorm_layer_submodule.f90 | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/nf/nf_batchnorm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90 index d31fca88..9f3d2a82 100644 --- a/src/nf/nf_batchnorm_layer_submodule.f90 +++ b/src/nf/nf_batchnorm_layer_submodule.f90 @@ -42,7 +42,6 @@ pure module subroutine forward(self, input) implicit none class(batchnorm_layer), intent(in out) :: self real, intent(in) :: input(:,:) - !real, allocatable :: normalized_input(:,:) ! Store input for backward pass self % input = input @@ -59,9 +58,6 @@ pure module subroutine forward(self, input) end associate - ! Deallocate temporary array - !deallocate(normalized_input) - end subroutine forward pure module subroutine backward(self, input, gradient)
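
Reference note on the formulas implemented above (not part of the patches themselves): with mu_i = running_mean(i), sigma_i^2 = running_var(i), and eps = epsilon, the forward pass computes

    $y_{ib} = \gamma_i \, \frac{x_{ib} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i$

and the backward pass, treating $\mu_i$ and $\sigma_i^2$ as constants with respect to the input, accumulates

    $\frac{\partial L}{\partial \gamma_i} = \sum_b \frac{\partial L}{\partial y_{ib}} \, \frac{x_{ib} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad \frac{\partial L}{\partial \beta_i} = \sum_b \frac{\partial L}{\partial y_{ib}}, \qquad \frac{\partial L}{\partial x_{ib}} = \frac{\gamma_i}{\sqrt{\sigma_i^2 + \epsilon}} \, \frac{\partial L}{\partial y_{ib}}$

where i indexes features and b indexes samples in the batch. These correspond to gamma_grad, beta_grad, and input_grad in nf_batchnorm_layer_submodule.f90.

For illustration only, the self-contained sketch below applies the same per-feature normalization to a small array, computing the batch statistics directly rather than using running statistics as the layer does; the program name, array sizes, and variable names are hypothetical and not part of the library.

program batchnorm_sketch
  ! Minimal illustration of the batch normalization transform implemented above.
  implicit none
  integer, parameter :: num_features = 3, batch_size = 4
  real :: x(num_features, batch_size), y(num_features, batch_size)
  real :: mean(num_features), var(num_features)
  real :: gamma(num_features), beta(num_features)
  real, parameter :: eps = 1e-5
  integer :: j

  call random_number(x)
  gamma = 1.0
  beta = 0.0

  ! Per-feature batch statistics (the layer above uses stored running statistics instead)
  mean = sum(x, dim=2) / batch_size
  var = sum((x - spread(mean, dim=2, ncopies=batch_size))**2, dim=2) / batch_size

  ! y = gamma * (x - mean) / sqrt(var + eps) + beta, applied column by column
  do j = 1, batch_size
    y(:, j) = gamma * (x(:, j) - mean) / sqrt(var + eps) + beta
  end do

  print *, 'per-feature mean of y (expected ~0):', sum(y, dim=2) / batch_size

end program batchnorm_sketch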