adds option to write predicted classes
jeffreyling committed Dec 29, 2016
1 parent 8eed6ec commit badd1bc
Showing 4 changed files with 31 additions and 13 deletions.
README.md: 1 change, 1 addition & 0 deletions
@@ -103,6 +103,7 @@ The following is a list of complete parameters allowed by the torch code.
 * `train_only`: Set to 1 to only train (no testing)
 * `test_only`: Given a `.t7` file with model, test on testing data
 * `dump_feature_maps_file`: Filename for dumping feature maps of convolution at test time. This will be a `.hdf5` file with fields `feature_maps` for the features at each time step and `word_idxs` for the word indexes (aligned with the last word of the filter). This currently only works for models with a single filter size. This is saved for the best model on fold 1.
+* `preds_file`: Filename for writing predictions (with `test_only` set to 1). Output is zero indexed.
 
 Training hyperparameters:
 * `num_epochs`: Number of training epochs.
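
For illustration (not part of this commit), here is a minimal Python sketch of consuming a predictions file written via `preds_file`. The filename `preds.txt` is a placeholder; the only assumption is the format documented above, one zero-indexed integer class per line.

from collections import Counter

# Read the zero-indexed predictions (one integer class per line).
with open('preds.txt') as f:  # placeholder filename
    preds = [int(line) for line in f if line.strip()]

print 'num predictions:', len(preds)
print 'class counts:', Counter(preds)

The zero-indexing comes from trainer.lua below, which writes `preds[j][1] - 1` for each test example rather than Torch's 1-indexed class.
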
main.lua: 1 change, 1 addition & 0 deletions
@@ -38,6 +38,7 @@ cmd:text()
 -- Training own dataset
 cmd:option('-train_only', 0, 'Set to 1 to only train on data. Default is cross-validation')
 cmd:option('-test_only', 0, 'Set to 1 to only do testing. Must have a -warm_start_model')
+cmd:option('-preds_file', '', 'On test data, write predictions to an output file. Set test_only to 1 to use')
 cmd:option('-warm_start_model', '', 'Path to .t7 file with pre-trained model. Should contain a table with key \'model\'')
 cmd:text()
 
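Taken together, the options above mean that `-test_only 1` requires a `-warm_start_model`, and `-preds_file` only takes effect in that mode. A hedged driver sketch follows; the `th main.lua` invocation is the standard Torch launcher, and the file paths are placeholders, not part of the repo.

import subprocess

# Run a test-only pass over a pre-trained model and write its predictions.
# 'model.t7' and 'preds.txt' are placeholder paths.
subprocess.check_call([
    'th', 'main.lua',
    '-test_only', '1',
    '-warm_start_model', 'model.t7',
    '-preds_file', 'preds.txt',
])
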
preprocess.py: 5 changes, 0 additions & 5 deletions
@@ -202,12 +202,7 @@ def main():
 for word, vec in w2v.items():
     embed[word_to_idx[word] - 1] = vec
 
-# Shuffle train
 print 'train size:', train.shape
-N = train.shape[0]
-perm = np.random.permutation(N)
-train = train[perm]
-train_label = train_label[perm]
 
 filename = dataset + '.hdf5'
 with h5py.File(filename, "w") as f:
trainer.lua: 37 changes, 29 additions & 8 deletions
@@ -134,6 +134,12 @@ function Trainer:test(test_data, test_labels, model, criterion, layers, dump_fea
 local confusion = optim.ConfusionMatrix(classes)
 confusion:zero()
 
+local preds_file
+if opt.test_only == 1 and opt.preds_file ~= '' then
+  print('Writing predictions to ' .. opt.preds_file)
+  preds_file = io.open(opt.preds_file, 'w')
+end
+
 -- dump feature maps
 local feature_maps
 local conv_layer = get_layer(model, 'convolution')
@@ -156,15 +162,26 @@
 local outputs = model:forward(inputs)
 -- dump feature maps from model forward
 local cur_feature_maps
-if opt.cudnn == 1 then
-  cur_feature_maps = conv_layer.output:squeeze(4)
-else
-  cur_feature_maps = conv_layer.output
+if dump_features then
+  if opt.cudnn == 1 then
+    cur_feature_maps = conv_layer.output:squeeze(4)
+  else
+    cur_feature_maps = conv_layer.output
+  end
+  if feature_maps == nil then
+    feature_maps = cur_feature_maps
+  else
+    feature_maps = torch.cat(feature_maps, cur_feature_maps, 1)
+  end
 end
-if feature_maps == nil then
-  feature_maps = cur_feature_maps
-else
-  feature_maps = torch.cat(feature_maps, cur_feature_maps, 1)
+
+if opt.test_only == 1 and opt.preds_file ~= '' then
+  -- write predictions to file
+  local _,preds = torch.max(outputs, 2)
+  for j = 1, preds:size(1) do
+    -- zero index
+    preds_file:write((preds[j][1] - 1) .. '\n')
+  end
 end
 
 local err = criterion:forward(outputs, targets)
@@ -194,6 +211,10 @@ function Trainer:test(test_data, test_labels, model, criterion, layers, dump_fea
   f:close()
 end
 
+if opt.test_only == 1 and opt.preds_file ~= '' then
+  preds_file:close()
+end
+
 -- return error percent
 confusion:updateValids()
 return confusion.totalValid
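
For completeness, the feature-map dump that the hunks above guard behind `dump_features` can be inspected in the same way. A minimal sketch, assuming `dump_feature_maps_file` was set to the placeholder name `maps.hdf5`; the `feature_maps` and `word_idxs` field names come from the README.

import h5py

# Inspect the convolution feature maps dumped at test time.
# 'maps.hdf5' is a placeholder; the field names are documented in the README.
with h5py.File('maps.hdf5', 'r') as f:
    feature_maps = f['feature_maps'][:]
    word_idxs = f['word_idxs'][:]

print 'feature_maps shape:', feature_maps.shape
print 'word_idxs shape:', word_idxs.shape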
