forked from ItzikMalkiel/DeepNanoDesign
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloadDatasetForDirectFunction.lua
45 lines (35 loc) · 1.25 KB
/
loadDatasetForDirectFunction.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
-- Load the data for training the direct network, which predicts the spectra.
require 'mattorch'
-- Tensor geometry shared by the train and test splits. These stay global:
-- they are visible to scripts run later via dofile (NOTE(review): confirm
-- whether any other script reads them before making them local).
columns = 1
rows = 35
labelRows = 43

--- Load one dataset split from disk and build its input/label tensors.
-- Reads the globals DATASET_NAME and filterNoPairedData and calls the
-- project helpers loadDataFromPath, filterNoPairDirectData and initData.
-- @tparam string split subdirectory under dataset/<DATASET_NAME>/ ('train' or 'test')
-- @return samples table (replaced by the filtered one when filterNoPairedData is set)
-- @return files table as filled by loadDataFromPath (never filtered)
-- @return second result of filterNoPairDirectData, or nil when not filtering
-- @return FloatTensor of size (#samples, rows, columns), filled by initData
-- @return FloatTensor of size (#samples, labelRows), filled by initData
local function loadSplit(split)
   print('==> load ' .. split .. ' data')
   local dir = 'dataset/' .. DATASET_NAME .. '/' .. split .. '/'
   local samples, files = {}, {}
   -- loadDataFromPath fills both tables in place.
   loadDataFromPath(dir, samples, files)
   local fileNames
   if filterNoPairedData then
      samples, fileNames = filterNoPairDirectData(samples, files)
   end
   -- Fail early with a message naming the directory; an empty split would
   -- otherwise only surface later as a much less readable tensor error.
   assert(#samples > 0, 'no ' .. split .. ' samples found in ' .. dir)
   print('==> init ' .. split .. ' data')
   local data = torch.FloatTensor(#samples, rows, columns)
   local labels = torch.FloatTensor(#samples, labelRows)
   initData(data, labels, samples)
   return samples, files, fileNames, data, labels
end

-- Everything below is published as a global because later code (the
-- transpose/normalization step and normalizeData.lua) reads these names.
-- NOTE(review): when filtering is on, only trainFileNames/testFileNames
-- receive the filtered names; trainFiles/testFiles keep the unfiltered list
-- and may no longer line up index-for-index with the samples. Confirm which
-- list downstream code indexes.
local trainNames
trainSamples, trainFiles, trainNames, trainData, trainLabels = loadSplit('train')
if filterNoPairedData then
   trainFileNames = trainNames
end

local testNames
testSamples, testFiles, testNames, testData, testLabels = loadSplit('test')
if filterNoPairedData then
   testFileNames = testNames
end
-- Swap axes 2 and 3 of both splits. The loading step builds these tensors as
-- (N, rows, columns), so afterwards they are (N, columns, rows) and the
-- former `rows` axis is last.
trainData, testData = trainData:transpose(2, 3), testData:transpose(2, 3)

-- Feature count = length of the (new) last axis of the training inputs.
numOfFeatures = trainData:size(3)
print '==> number of features: '
print(numOfFeatures)

-- normalizeData.lua runs in the global environment and works on the
-- globals set above.
print '==> normalizing the data...'
dofile('normalizeData.lua')
print '==> normalization is DONE!'