-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmnist.R
33 lines (22 loc) · 1.16 KB
/
mnist.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
library( ANTsRNet )
library( keras )
# Train an AlexNet-style classifier (ANTsRNet createAlexNetModel2D) on the
# MNIST handwritten-digit data, then evaluate and predict on the test split.

# Reset Keras' global state so layers/models left over from earlier runs in
# this R session do not clutter the new model.
keras::backend()$clear_session()

mnistData <- dataset_mnist()
numberOfLabels <- length( unique( mnistData$train$y ) )

# Convert an ( images, rows, columns ) array of raw MNIST pixel intensities
# (stored as 0-255 values) into network input: intensities rescaled to
# [0, 1] and a trailing channel axis of length 1 appended, giving
# ( images, rows, columns, 1 ). Appending a trailing dimension in R's
# column-major layout leaves every pixel at the same index.
# Used for BOTH splits so training and test inputs are scaled identically.
prepareMnistImages <- function( images, maximumIntensity = 255 ) {
  array( images / maximumIntensity, dim = c( dim( images ), 1L ) )
}

X_train <- prepareMnistImages( mnistData$train$x )
Y_train <- keras::to_categorical( mnistData$train$y, numberOfLabels )

# Per-sample input shape (no sample axis): rows x columns x 1 channel.
inputImageSize <- c( dim( mnistData$train$x )[2:3], 1 )

alexNetModel <- createAlexNetModel2D( inputImageSize = inputImageSize,
  numberOfClassificationLabels = numberOfLabels, numberOfDenseUnits = 4096,
  dropoutRate = 0.0 )

# `learning_rate` replaces the deprecated `lr` argument of the keras
# optimizers (Keras 3 no longer accepts `lr`).
alexNetModel %>% compile( loss = 'categorical_crossentropy',
  optimizer = optimizer_adam( learning_rate = 0.0001 ),
  metrics = c( 'categorical_crossentropy', 'accuracy' ) )

# validation_split holds out the last 20% of the training samples (taken
# before shuffling) to report validation metrics after every epoch.
track <- alexNetModel %>% fit( X_train, Y_train, epochs = 40, batch_size = 32,
  verbose = 1, shuffle = TRUE, validation_split = 0.2 )

# Now test the model on the held-out split, preprocessed like the training set.
X_test <- prepareMnistImages( mnistData$test$x )
Y_test <- keras::to_categorical( mnistData$test$y, numberOfLabels )
testingMetrics <- alexNetModel %>% evaluate( X_test, Y_test )
# Raw model outputs, one row per test image (column meaning depends on the
# model's final layer -- presumably per-label scores; confirm in ANTsRNet).
predictedData <- alexNetModel %>% predict( X_test, verbose = 1 )