Implementing basic RNN #162

Draft: wants to merge 31 commits into base: main

Commits (31), changes from all commits:
64e1b69  Prototyping RNN layer based on Dense (castelao, Oct 26, 2023)
8c11911  Extenting uses (castelao, Oct 26, 2023)
adef7d7  Reading coefficients from h5f model (castelao, Oct 26, 2023)
b51d66f  feat: get_params() (castelao, Oct 26, 2023)
a797502  feat: set_params() (castelao, Oct 26, 2023)
ff1c392  feat: get_num_params() (castelao, Oct 26, 2023)
f686950  Initializing recurrent kernel and states (castelao, Oct 26, 2023)
acf1afd  feat: forward() (castelao, Oct 26, 2023)
69fed32  More informative error messages (castelao, Oct 27, 2023)
fd24e16  Minor adjustments on rnn_layer (castelao, Oct 27, 2023)
7415081  Constructor for RNN (castelao, Oct 28, 2023)
6f56863  Loading rnn constructor in the root (castelao, Oct 28, 2023)
ad598a8  Back to 1D concept (castelao, Oct 31, 2023)
0ae7af1  fix: Recurrent is actually a square matrix (castelao, Oct 31, 2023)
c164924  Apply loss function if RNN is the output layer (castelao, Nov 1, 2023)
55ad96d  fix: Getting biases (castelao, Nov 1, 2023)
b345865  Allowing backward 1D from dense to RNN (castelao, Nov 1, 2023)
91b85e0  Allowing backward 1D from RNN (castelao, Nov 1, 2023)
5e197f0  Allowing forward from dense to RNN (castelao, Nov 1, 2023)
7f671c8  Allowing forward from RNN (castelao, Nov 1, 2023)
c27f59c  Getting output from RNN (castelao, Nov 1, 2023)
524d2c4  feat: Implementing reset state for RNN (castelao, Nov 1, 2023)
598f9e7  refactor: set_state() on layer level (castelao, Nov 6, 2023)
b7bead6  wip: A simple RNN example (castelao, Nov 14, 2023)
088e4f3  feat: layer getting gradient from RNN (castelao, Nov 14, 2023)
4d0a4fd  feat: layer setting params for RNN (castelao, Nov 14, 2023)
ee516a8  Might not use set_state at rnn_layer level (castelao, Nov 14, 2023)
07f7587  fix: New access point to 'loss % derivative' (castelao, Jun 25, 2024)
9b22826  Define set_state as pure (castelao, Jun 30, 2024)
5bc9bc5  fix: pure interface for set_state (castelao, Jun 30, 2024)
4c7c0b9  fix: Conciliating with latest main state (castelao, Oct 21, 2024)
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -49,6 +49,8 @@ add_library(neural-fortran
src/nf/nf_random.f90
src/nf/nf_reshape_layer.f90
src/nf/nf_reshape_layer_submodule.f90
src/nf/nf_rnn_layer.f90
src/nf/nf_rnn_layer_submodule.f90
src/nf/io/nf_io_binary.f90
src/nf/io/nf_io_binary_submodule.f90
)
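
The two files registered here, src/nf/nf_rnn_layer.f90 and src/nf/nf_rnn_layer_submodule.f90, are new in this PR but their contents are not part of the hunks shown below. As orientation, here is a sketch of what the rnn_layer derived type plausibly contains, inferred only from how it is used later in this diff; any component not referenced below (weights, recurrent, biases, activation) is an assumption.

  ! Sketch only: inferred from how rnn_layer is used in the hunks below;
  ! NOT the actual contents of src/nf/nf_rnn_layer.f90.
  module nf_rnn_layer_sketch

    use nf_activation, only: activation_function

    implicit none

    type :: rnn_layer
      integer :: input_size
      integer :: output_size
      real, allocatable :: weights(:,:)    ! input kernel (assumed)
      real, allocatable :: recurrent(:,:)  ! recurrent kernel; a square matrix (commit 0ae7af1)
      real, allocatable :: biases(:)       ! assumed
      real, allocatable :: state(:)        ! hidden state; set via layer % set_state()
      real, allocatable :: output(:)       ! read by get_output_1d()
      real, allocatable :: gradient(:)     ! assumed gradient w.r.t. the layer input
      real, allocatable :: dw(:,:)         ! weight gradients, zeroed after update()
      real, allocatable :: db(:)           ! bias gradients, zeroed after update()
      class(activation_function), allocatable :: activation
    end type rnn_layer

    ! Type-bound procedures referenced by the hunks below (interfaces omitted,
    ! and the real type likely extends base_layer): forward, backward,
    ! get_num_params, get_params, get_gradients, set_params

  end module nf_rnn_layer_sketch
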
48 changes: 48 additions & 0 deletions example/simple_rnn.f90
@@ -0,0 +1,48 @@
program simple_rnn
  use nf, only: dense, input, network, rnn, sgd
  implicit none
  type(network) :: net
  real, allocatable :: x(:), y(:), p(:)
  integer, parameter :: num_iterations = 1000
  integer :: n, l

  allocate(p(2))

  print '("Simple RNN")'
  print '(60("="))'

  net = network([ &
    input(3), &
    rnn(5), &
    rnn(1) &
  ])

  call net % print_info()

  x = [0.2, 0.4, 0.6]
  y = [0.123456, 0.246802]

  do n = 0, num_iterations

    ! Reset the hidden state of every RNN layer before each pass
    do l = 1, size(net % layers)
      if (net % layers(l) % name == 'rnn') call net % layers(l) % set_state()
    end do

    if (mod(n, 100) == 0) then
      ! Report progress: the hidden state persists between the two predict
      ! calls, so they yield the first- and second-step predictions
      p(1:1) = net % predict(x)
      p(2:2) = net % predict(x)
      print '(i4,2(3x,f8.6))', n, p

    else
      ! Train on the two-step target sequence, one
      ! forward/backward/update per step
      call net % forward(x)
      call net % backward(y(1:1))
      call net % update(optimizer=sgd(learning_rate=.001))
      call net % forward(x)
      call net % backward(y(2:2))
      call net % update(optimizer=sgd(learning_rate=.001))
    end if

  end do

end program simple_rnn
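
For context on what this example trains: a simple (Elman-style) recurrent layer keeps a hidden state between calls and updates it from the current input and the previous state on each forward pass. The sketch below shows one such step; none of these names come from the PR's rnn_layer implementation, which is not shown in this diff.

  ! Sketch of one Elman-style forward step (assumed, not the PR's code).
  pure subroutine rnn_forward_step(weights, recurrent, biases, input, state, output)
    real, intent(in) :: weights(:,:)     ! input kernel
    real, intent(in) :: recurrent(:,:)   ! recurrent kernel (square)
    real, intent(in) :: biases(:)
    real, intent(in) :: input(:)
    real, intent(in out) :: state(:)     ! hidden state carried between calls
    real, intent(out) :: output(:)

    ! The new state mixes the current input with the previous state.
    state = tanh(matmul(weights, input) + matmul(recurrent, state) + biases)
    output = state

  end subroutine rnn_forward_step
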
2 changes: 1 addition & 1 deletion src/nf.f90
@@ -3,7 +3,7 @@ module nf
use nf_datasets_mnist, only: label_digits, load_mnist
use nf_layer, only: layer
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
conv2d, dense, flatten, input, maxpool2d, reshape, rnn
use nf_loss, only: mse, quadratic
use nf_metrics, only: corr, maxabs
use nf_network, only: network
6 changes: 6 additions & 0 deletions src/nf/nf_layer.f90
@@ -30,6 +30,7 @@ module nf_layer
procedure :: get_params
procedure :: get_gradients
procedure :: set_params
procedure :: set_state
procedure :: init
procedure :: print_info

@@ -153,6 +154,11 @@ module subroutine set_params(self, params)
!! Parameters of this layer
end subroutine set_params

pure module subroutine set_state(self, state)
  !! Set the internal state of a recurrent layer.
  class(layer), intent(inout) :: self
    !! Layer instance
  real, intent(in), optional :: state(:)
    !! State values; if absent, the state is reset to zeros
end subroutine set_state

end interface

end module nf_layer
25 changes: 24 additions & 1 deletion src/nf/nf_layer_constructors.f90
@@ -8,7 +8,7 @@ module nf_layer_constructors
implicit none

private
public :: conv2d, dense, flatten, input, maxpool2d, reshape
public :: conv2d, dense, flatten, input, maxpool2d, reshape, rnn

interface input

@@ -166,6 +166,29 @@ module function reshape(output_shape) result(res)
!! Resulting layer instance
end function reshape

pure module function rnn(layer_size, activation) result(res)
  !! Recurrent (fully-connected) layer constructor.
  !!
  !! This layer is a building block for recurrent, fully-connected
  !! networks, or for an output layer of a convolutional network.
  !! A recurrent layer must not be the first layer in the network.
  !!
  !! Example:
  !!
  !! ```
  !! use nf, only: rnn, layer, relu
  !! type(layer) :: rnn_layer
  !! rnn_layer = rnn(10)
  !! rnn_layer = rnn(10, activation=relu())
  !! ```
  integer, intent(in) :: layer_size
    !! The number of neurons in the recurrent layer
  class(activation_function), intent(in), optional :: activation
    !! Activation function instance (default tanh)
  type(layer) :: res
    !! Resulting layer instance
end function rnn

end interface

end module nf_layer_constructors
26 changes: 25 additions & 1 deletion src/nf/nf_layer_constructors_submodule.f90
@@ -8,7 +8,8 @@
use nf_input3d_layer, only: input3d_layer
use nf_maxpool2d_layer, only: maxpool2d_layer
use nf_reshape_layer, only: reshape3d_layer
use nf_activation, only: activation_function, relu, sigmoid
use nf_rnn_layer, only: rnn_layer
use nf_activation, only: activation_function, relu, sigmoid, tanhf

implicit none

@@ -134,4 +135,27 @@ module function reshape(output_shape) result(res)

end function reshape

pure module function rnn(layer_size, activation) result(res)
  integer, intent(in) :: layer_size
  class(activation_function), intent(in), optional :: activation
  type(layer) :: res

  class(activation_function), allocatable :: activation_tmp

  res % name = 'rnn'
  res % layer_shape = [layer_size]

  if (present(activation)) then
    allocate(activation_tmp, source=activation)
  else
    allocate(activation_tmp, source=tanhf())
  end if

  res % activation = activation_tmp % get_name()

  allocate(res % p, source=rnn_layer(layer_size, activation_tmp))

end function rnn

end submodule nf_layer_constructors_submodule
64 changes: 60 additions & 4 deletions src/nf/nf_layer_submodule.f90
@@ -8,6 +8,7 @@
use nf_input3d_layer, only: input3d_layer
use nf_maxpool2d_layer, only: maxpool2d_layer
use nf_reshape_layer, only: reshape3d_layer
use nf_rnn_layer, only: rnn_layer
use nf_optimizers, only: optimizer_base_type

contains
@@ -32,6 +33,8 @@ pure module subroutine backward_1d(self, previous, gradient)
call this_layer % backward(prev_layer % output, gradient)
type is(flatten_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(rnn_layer)
call this_layer % backward(prev_layer % output, gradient)
end select

type is(flatten_layer)
@@ -46,6 +49,19 @@ pure module subroutine backward_1d(self, previous, gradient)
call this_layer % backward(prev_layer % output, gradient)
end select

type is(rnn_layer)

select type(prev_layer => previous % p)
type is(input1d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(dense_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(flatten_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(rnn_layer)
call this_layer % backward(prev_layer % output, gradient)
end select

end select

end subroutine backward_1d
@@ -123,6 +139,8 @@ pure module subroutine forward(self, input)
call this_layer % forward(prev_layer % output)
type is(flatten_layer)
call this_layer % forward(prev_layer % output)
type is(rnn_layer)
call this_layer % forward(prev_layer % output)
end select

type is(conv2d_layer)
@@ -179,6 +197,19 @@ pure module subroutine forward(self, input)
call this_layer % forward(prev_layer % output)
end select

type is(rnn_layer)

! Upstream layers permitted: input1d, dense, rnn
select type(prev_layer => input % p)
type is(input1d_layer)
call this_layer % forward(prev_layer % output)
type is(dense_layer)
call this_layer % forward(prev_layer % output)
type is(rnn_layer)
call this_layer % forward(prev_layer % output)
end select


end select

end subroutine forward
@@ -197,6 +228,8 @@ pure module subroutine get_output_1d(self, output)
allocate(output, source=this_layer % output)
type is(flatten_layer)
allocate(output, source=this_layer % output)
type is(rnn_layer)
allocate(output, source=this_layer % output)
class default
error stop '1-d output can only be read from an input1d, dense, flatten, or rnn layer.'

@@ -292,8 +325,10 @@ elemental module function get_num_params(self) result(num_params)
num_params = 0
type is (reshape3d_layer)
num_params = 0
type is (rnn_layer)
num_params = this_layer % get_num_params()
class default
error stop 'Unknown layer type.'
error stop 'get_num_params() with unknown layer type.'
end select

end function get_num_params
@@ -317,8 +352,10 @@ module function get_params(self) result(params)
! No parameters to get.
type is (reshape3d_layer)
! No parameters to get.
type is (rnn_layer)
params = this_layer % get_params()
class default
error stop 'Unknown layer type.'
error stop 'get_params() with unknown layer type.'
end select

end function get_params
@@ -342,8 +379,10 @@ module function get_gradients(self) result(gradients)
! No gradients to get.
type is (reshape3d_layer)
! No gradients to get.
type is (rnn_layer)
gradients = this_layer % get_gradients()
class default
error stop 'Unknown layer type.'
error stop 'get_gradients() with unknown layer type.'
end select

end function get_gradients
@@ -399,10 +438,27 @@ module subroutine set_params(self, params)
write(stderr, '(a)') 'Warning: calling set_params() ' &
// 'on a zero-parameter layer; nothing to do.'

type is (rnn_layer)
call this_layer % set_params(params)

class default
error stop 'Unknown layer type.'
error stop 'set_params() with unknown layer type.'
end select

end subroutine set_params

pure module subroutine set_state(self, state)
  class(layer), intent(inout) :: self
  real, intent(in), optional :: state(:)

  ! Only recurrent layers carry a state; for other layer types this is a no-op.
  select type (this_layer => self % p)
    type is (rnn_layer)
      if (present(state)) then
        this_layer % state = state
      else
        this_layer % state = 0
      end if
  end select

end subroutine set_state

end submodule nf_layer_submodule
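
Since set_state() silently ignores non-recurrent layers, callers can loop over all layers and reset or seed only the RNN ones, as the example program above does. A short usage sketch; the layer index and seed values below are illustrative only and must match the layer size:

  ! Reset the hidden state of every RNN layer to zero before a new sequence
  do l = 1, size(net % layers)
    if (net % layers(l) % name == 'rnn') call net % layers(l) % set_state()
  end do

  ! Or seed a specific RNN layer with known values (size must match the layer)
  call net % layers(2) % set_state([0.1, 0.2, 0.3, 0.4, 0.5])
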
64 changes: 63 additions & 1 deletion src/nf/nf_network_submodule.f90
@@ -7,8 +7,9 @@
use nf_input3d_layer, only: input3d_layer
use nf_maxpool2d_layer, only: maxpool2d_layer
use nf_reshape_layer, only: reshape3d_layer
use nf_rnn_layer, only: rnn_layer
use nf_layer, only: layer
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape, rnn
use nf_loss, only: quadratic
use nf_optimizers, only: optimizer_base_type, sgd
use nf_parallel, only: tile_indices
@@ -93,6 +94,59 @@ module function network_from_layers(layers) result(res)
end function network_from_layers


pure function get_activation_by_name(activation_name) result(res)
! Workaround to construct an activation_function instance with
! hardcoded default parameters from its name. Needed because Keras
! model files provide only the activation name.
character(len=*), intent(in) :: activation_name
class(activation_function), allocatable :: res

select case(trim(activation_name))
case('elu')
allocate ( res, source = elu(alpha = 0.1) )

case('exponential')
allocate ( res, source = exponential() )

case('gaussian')
allocate ( res, source = gaussian() )

case('linear')
allocate ( res, source = linear() )

case('relu')
allocate ( res, source = relu() )

case('leaky_relu')
allocate ( res, source = leaky_relu(alpha = 0.1) )

case('sigmoid')
allocate ( res, source = sigmoid() )

case('softmax')
allocate ( res, source = softmax() )

case('softplus')
allocate ( res, source = softplus() )

case('step')
allocate ( res, source = step() )

case('tanh')
allocate ( res, source = tanhf() )

case('celu')
allocate ( res, source = celu() )

case default
error stop 'activation_name must be one of: ' // &
'"elu", "exponential", "gaussian", "linear", "relu", ' // &
'"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh" or "celu".'
end select

end function get_activation_by_name
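
! Usage sketch (assumption, not part of this PR): constructing an activation
! instance by name inside this submodule, where the activation types are in scope:
!
!   class(activation_function), allocatable :: act
!   allocate(act, source=get_activation_by_name('tanh'))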

module subroutine backward(self, output, loss)
class(network), intent(in out) :: self
real, intent(in) :: output(:)
@@ -128,6 +182,11 @@ module subroutine backward(self, output, loss)
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
type is(rnn_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
end select
else
! Hidden layer; take the gradient from the next layer
@@ -523,6 +582,9 @@ module subroutine update(self, optimizer, batch_size)
type is(conv2d_layer)
this_layer % dw = 0
this_layer % db = 0
type is(rnn_layer)
this_layer % dw = 0
this_layer % db = 0
end select
end do
