-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path2_preprocess_data.py
executable file
·137 lines (117 loc) · 5.96 KB
/
2_preprocess_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
This code
1) splits train/validation/test places,
2) generates MCC chart area mask (mask.png),
3) subtracts default black level for each camera,
4) then crops & resizes the LSMI dataset
"""
import cv2,os,shutil,json
import rawpy
import numpy as np
from tqdm import tqdm
from utils import *
# ---------------------------------------------------------------------------
# User options
# ---------------------------------------------------------------------------
# Trim left/right sides of the image so that it is square.
# (for train/val only, Always True for test)
SQUARE_CROP = True
# Size of train/val image. If None, keep the original resolution.
SIZE = 512
# Size of test image. If None, keep the original resolution.
TEST_SIZE = 256
# LSMI subset camera; also the name of the input directory and of the
# "<CAMERA>.dng" template file read below.
CAMERA = "galaxy"
# Apply MCC masking to these subsets. options: {'train', 'val', 'test'}
MCC_MASKING_SUBSET = {'train'}

# Output root directory: "<CAMERA>_<SIZE>", or "<CAMERA>_fullres" when the
# train/val images keep their original resolution.
DST_ROOT = CAMERA + "_" + (str(SIZE) if SIZE is not None else "fullres")

# Zero mask value for black pixels
ZERO_MASK = -1

# ---------------------------------------------------------------------------
# Sensor levels
# ---------------------------------------------------------------------------
RAW = "{}.dng".format(CAMERA)
TEMPLETE = rawpy.imread(RAW)
if CAMERA != "sony":
    # Take black level and saturation from the DNG template's metadata.
    BLACK_LEVEL = min(TEMPLETE.black_level_per_channel)
    SATURATION = TEMPLETE.white_level
else:
    # The "sony" subset uses hard-coded levels instead of the template's.
    BLACK_LEVEL, SATURATION = 128, 4095

# ---------------------------------------------------------------------------
# Dataset metadata and split definition
# ---------------------------------------------------------------------------
with open(os.path.join(CAMERA, "meta.json"), 'r') as meta_json:
    meta_data = json.loads(meta_json.read())

with open(os.path.join(CAMERA, "split.json"), 'r') as split_json:
    split_data = json.loads(split_json.read())
# Preprocess every place of every split listed in split.json.
# For each TIFF: subtract the black level, build the per-pixel illuminant map,
# white-balance the image into a ground-truth image, optionally mask out the
# MCC charts, crop/resize, and write the results to DST_ROOT/<split>.
for key, places in split_data.items():
    # The split name is the last "_"-separated token of the split.json key.
    split = key.split("_")[-1]
    print("Processing "+key)

    # Everything here is constant for the whole split, so compute it once.
    dst_path = os.path.join(DST_ROOT, split)
    os.makedirs(dst_path, exist_ok=True)
    is_test = split == 'test'
    # SQUARE_CROP is documented as "always True for test": test images are
    # cropped to a square even when the flag is off, otherwise the square
    # resize below would distort them.
    square_crop = SQUARE_CROP or is_test
    resize_len = TEST_SIZE if is_test else SIZE
    apply_mcc_mask = split in MCC_MASKING_SUBSET
    # Largest valid pixel value once the black level has been subtracted.
    max_value = SATURATION - BLACK_LEVEL

    for place in tqdm(places):
        place_dir = os.path.join(CAMERA, place)
        files = [f for f in os.listdir(place_dir) if f.endswith("tiff")]

        for file in files:
            # Disabled filters, kept for reference:
            # if "two_illum" in key and "_12.tiff" not in file:
            #     continue
            # if "three_illum" in key and ("_12.tiff" in file or "_13.tiff" in file):
            #     continue

            fname = os.path.splitext(file)[0]
            # Second "_"-separated token of the file name; each character is
            # the 1-based index of a light that is on (e.g. "12").
            illum_count = fname.split("_")[1]

            # open tiff image & subtract black level
            img_path = os.path.join(place_dir, file)
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                # cv2.imread signals failure by returning None; fail with a
                # message that names the offending file.
                raise IOError("Failed to read image: " + img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype('float32')
            img = np.clip(img - BLACK_LEVEL, 0, max_value)

            # make pixel-level illumination map
            if len(illum_count) == 1:
                # Single illuminant: it contributes fully at every pixel.
                mixmap = np.ones_like(img[:,:,0:1], dtype=float)
            else:
                mixmap = np.load(os.path.join(place_dir, fname+".npy"))
            # One chromaticity slot per light; lights that are off stay zero.
            illum_chroma = [[0,0,0],[0,0,0],[0,0,0]]
            for i in illum_count:
                illum_chroma[int(i)-1] = meta_data[place]["Light"+i]
            illum_map = mix_chroma(mixmap, illum_chroma, illum_count)
            # Pixels flagged with ZERO_MASK get a neutral illuminant so the
            # white-balance division below leaves them unchanged.
            x_zero, y_zero = np.where(mixmap[:,:,0] == ZERO_MASK)
            illum_map[x_zero, y_zero, :] = [1.,1.,1.]

            # white balance original image
            img_wb = np.clip(img / illum_map, 0, max_value)

            # Build the MCC mask: 1 everywhere, 0 inside each chart polygon.
            # Coordinates are halved before drawing - presumably they were
            # annotated at twice the TIFF resolution; TODO confirm.
            mask = np.ones_like(img[:,:,0:1], dtype='float32')
            for mcc_key in ("mcc1", "mcc2", "mcc3"):
                mcc = (np.float32(meta_data[place]["MCCCoord"][mcc_key]) / 2).astype(int).tolist()
                contour = np.array([[mcc[0]],[mcc[1]],[mcc[2]],[mcc[3]]]).astype(int)
                cv2.drawContours(mask, [contour], 0, (0), -1)

            # apply MCC mask to original image, GT image
            if apply_mcc_mask:
                img = img * mask
                img_wb = img_wb * mask

            # Crop original image, GT image, mixmap, illum_map, mask
            if square_crop:
                # Keep the horizontally centred height x height window.
                height, width, _ = img.shape
                w_start = int(width/2) - int(height/2)
                w_end = w_start + height
                img = img[:,w_start:w_end,:]
                img_wb = img_wb[:,w_start:w_end,:]
                mixmap = mixmap[:,w_start:w_end,:]
                illum_map = illum_map[:,w_start:w_end,:]
                mask = mask[:,w_start:w_end,:]

            # prevent negative mask value interpolation if ZERO_MASK is negative value
            mixmap = np.where(mixmap == ZERO_MASK, 0, mixmap)

            # resize & save
            if resize_len is not None:
                dsize = (resize_len, resize_len)
                img = cv2.resize(img, dsize=dsize, interpolation=cv2.INTER_LINEAR).astype('uint16')
                img_wb = cv2.resize(img_wb, dsize=dsize, interpolation=cv2.INTER_LINEAR).astype('uint16')
                mixmap = cv2.resize(mixmap, dsize=dsize, interpolation=cv2.INTER_LINEAR)
                illum_map = cv2.resize(illum_map, dsize=dsize, interpolation=cv2.INTER_LINEAR)
                mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_LINEAR)
            else:
                img = img.astype('uint16')
                img_wb = img_wb.astype('uint16')

            # save image, GT image, illum_map, MCC mask, mixmap
            cv2.imwrite(os.path.join(dst_path, file), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(dst_path, fname+"_gt.tiff"), cv2.cvtColor(img_wb, cv2.COLOR_RGB2BGR))
            np.save(os.path.join(dst_path, fname+"_illum.npy"), illum_map)
            # One mask file per place; every image of the place rewrites it.
            cv2.imwrite(os.path.join(dst_path, place+"_mask.png"), mask)
            if len(illum_count) != 1:
                # np.save appends ".npy", giving "<fname>.npy".
                np.save(os.path.join(dst_path, fname), mixmap)

        # delete original MCC coordinates in meta json
        # (done once per place, after all of its images have used them)
        meta_data[place].pop("MCCCoord")
# Write the dataset-level files next to the processed images.
# Make sure the output root exists even if no split produced a directory;
# otherwise the open() below fails, and shutil.copy would create a *file*
# named DST_ROOT instead of copying into a directory.
os.makedirs(DST_ROOT, exist_ok=True)

# meta.json is re-serialised from memory rather than copied from the source
# directory, because the in-memory metadata has had the "MCCCoord" entry
# removed for every processed place.
with open(os.path.join(DST_ROOT,"meta.json"), 'w') as out_file:
    json.dump(meta_data, out_file, indent=4)

# The split definition is not modified, so it is copied verbatim.
shutil.copy(os.path.join(CAMERA,"split.json"),DST_ROOT)