-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
54 lines (48 loc) · 2.09 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import pickle
import numpy as np
def create_patches(num_v, patches, dimensions, overlap):
    """
    Split a 3-D visible layer into equally sized, optionally overlapping patches.

    num_v: total number of visible units; must equal the product of dimensions
    patches: (x, y, z) tuple of # patches to split each dimension into (each
        entry must be 1 or 2), or None to put every unit in a single patch
    dimensions: (x, y, z) tuple of size of each dimension
    overlap: size of overlap between adjacent patches (integer number of units)

    Return an array of patch arrays, each of which is a set of flat (C-order)
    indices into the visible layer. Two special cases return a different form:
    patches=None gives a (1, num_v) boolean mask with every unit switched on,
    and a patches tuple summing to 3 (i.e. no split) gives [slice(None)].

    Raises AssertionError when the arguments are inconsistent.
    """
    # None has to be handled before sum(patches), which raises TypeError on None.
    if patches is None:
        # Boolean mask so each row can be used directly for logical indexing.
        return np.ones((1, num_v), dtype=bool)
    if sum(patches) == 3:
        return [slice(None, None, None)]

    assert dimensions, "dimensions are required when patches are given"
    assert num_v == dimensions[0] * dimensions[1] * dimensions[2], \
        "num_v must equal the product of dimensions"
    # Range check comes first so that a 0 entry fails here rather than as a
    # ZeroDivisionError in the divisibility check below.
    assert all(1 <= d <= 2 for d in patches), \
        "each dimension can only be split into 1 or 2 patches"
    assert (dimensions[0] % patches[0] == 0 and
            dimensions[1] % patches[1] == 0 and
            dimensions[2] % patches[2] == 0), \
        "patches must split every dimension evenly"

    # Floor division keeps the slice bounds integral (true division would give
    # floats on Python 3, which are not valid slice indices); it is exact here
    # because divisibility was asserted above.
    steps = (dimensions[0] // patches[0],
             dimensions[1] // patches[1],
             dimensions[2] // patches[2])

    def window(index, axis):
        # Extent of patch number `index` along `axis`, widened by `overlap`
        # on both sides and clipped to the bounds of the layer.
        start = max(0, index * steps[axis] - overlap)
        stop = min(dimensions[axis], (index + 1) * steps[axis] + overlap)
        return slice(start, stop)

    final = []
    for x in range(patches[0]):
        for y in range(patches[1]):
            for z in range(patches[2]):
                patch = np.zeros(dimensions, dtype=bool)
                patch[window(x, 0), window(y, 1), window(z, 2)] = True
                final.append(np.flatnonzero(patch))

    # Check that all patches have same number of on units
    assert all(p.size == final[0].size for p in final)
    return np.asanyarray(final)
def save_model(m, name):
    """
    Pickle the object m to the file called name.

    When name carries no file extension, ".model" is appended to it first.
    An existing file at the target path is overwritten.
    """
    _, extension = os.path.splitext(name)
    target = name if extension else name + ".model"
    with open(target, 'wb') as handle:
        pickle.dump(m, handle)
def load_model(name):
    """
    Load and return the pickled object stored in the file called name.

    Returns None when name does not refer to an existing regular file.
    Errors raised while reading or unpickling an existing file propagate
    to the caller.
    """
    if not os.path.isfile(name):
        return None
    # NOTE: unpickling can execute arbitrary code; only load files from a
    # trusted source.
    # The with-block closes the file handle even if unpickling fails.
    with open(name, "rb") as f:
        return pickle.load(f)