-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathinit.lua
64 lines (52 loc) · 1.93 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
--Copyright (C) 2016 Hani Altwaijry
--Released under MIT License
--license available in LICENSE file
require 'torch'
require 'libnpy4th'
require 'xlua'
-- Module table; the public functions below are attached to it and it is
-- returned at the end of the file.
local npy4th = {}

-- Usage strings passed to xlua.error when an argument check fails.
-- savenpy also accepts an optional third `mode` argument (defaulting to 'w'),
-- so its usage line advertises it.
local help = {
   loadnpy = [[loadnpy(filepath) -- Loads a numpy .npy file to a torch.Tensor]],
   loadnpz = [[loadnpz(filepath) -- Loads a numpy .npz file to a table]],
   savenpy = [[savenpy(filepath, tensor[, mode]) -- Saves a torch tensor in .npy format (mode defaults to 'w')]],
}

-- Numeric type codes handed to libnpy4th.savenpy, keyed by tensor:type().
-- A tensor type missing from this map is rejected by savenpy.
-- NOTE(review): the meaning of each code is defined by the native libnpy4th
-- library; these values must stay in sync with it.
local typeIds = {
   ['torch.DoubleTensor'] = 0,
   ['torch.FloatTensor']  = 1,
   ['torch.IntTensor']    = 2,
   ['torch.ByteTensor']   = 3,
   ['torch.LongTensor']   = 4,
   ['torch.ShortTensor']  = 5,
   -- CUDA tensors are converted with :float() before saving, so they share
   -- the FloatTensor code.
   ['torch.CudaTensor']   = 1,
}
--- Read a numpy .npy file through the native libnpy4th binding.
-- Raises (via xlua.error) when no path is given.
-- @tparam string filepath path of the .npy file to read
-- @return whatever libnpy4th.loadnpy returns (the usage text describes it
--   as a torch.Tensor)
npy4th.loadnpy = function(filepath)
   local pathMissing = not filepath
   if pathMissing then
      xlua.error('file path must be supplied', 'npy4th.loadnpy', help.loadnpy)
   end
   -- Tail call: every value produced by the native loader is passed through.
   return libnpy4th.loadnpy(filepath)
end
--- Read a numpy .npz archive through the native libnpy4th binding.
-- Raises (via xlua.error) when no path is given.
-- @tparam string filepath path of the .npz file to read
-- @return whatever libnpy4th.loadnpz returns (the usage text describes it
--   as a table)
npy4th.loadnpz = function(filepath)
   local pathMissing = not filepath
   if pathMissing then
      xlua.error('file path must be supplied', 'npy4th.loadnpz', help.loadnpz)
   end
   -- Tail call: every value produced by the native loader is passed through.
   return libnpy4th.loadnpz(filepath)
end
--- Save a torch tensor in numpy .npy format through the native libnpy4th binding.
-- A torch.CudaTensor is first copied to a FloatTensor with :float().
-- Raises (via xlua.error) when the path is missing, when `tensor` is not a
-- torch tensor, or when its type has no entry in typeIds.
-- @tparam string filepath destination path
-- @param tensor a torch.*Tensor whose type appears in typeIds
-- @tparam[opt='w'] string mode forwarded unchanged to libnpy4th.savenpy
--   (presumably the file open mode -- confirm in libnpy4th)
-- @return whatever libnpy4th.savenpy returns
npy4th.savenpy = function(filepath, tensor, mode)
   if not filepath then
      xlua.error('file path must be supplied',
                 'npy4th.savenpy',
                 help.savenpy)
   end
   -- Check for a real tensor up front. The old check only looked at userdata,
   -- so a number/string/table got through and then failed at tensor:type()
   -- with an unrelated "attempt to index" error.
   if not torch.isTensor(tensor) then
      xlua.error('Must pass a torch.*Tensor, got ' .. torch.type(tensor),
                 'npy4th.savenpy',
                 help.savenpy)
   end
   local tensorType = tensor:type()
   if typeIds[tensorType] == nil then
      xlua.error('unsupported tensor type ' .. tensorType,
                 'npy4th.savenpy',
                 help.savenpy)
   end
   if tensorType == 'torch.CudaTensor' then
      -- Copy to host memory as float; the type code is looked up again below.
      tensor = tensor:float()
      tensorType = tensor:type()
   end
   return libnpy4th.savenpy(filepath, tensor, typeIds[tensorType], mode or 'w')
end
-- Export the module table (loadnpy, loadnpz, savenpy) to `require` callers.
return npy4th