-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
76 lines (57 loc) · 1.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from filter import FastGuidedFilter
def get_image(path):
    """
    Read an image file and return it as a normalized (1, 3, H, W) CUDA tensor.

    The file is decoded with PIL and forced to 3-channel RGB. It is then
    converted to float32 and divided by its own maximum value, so the
    brightest sample becomes 1.0. An all-zero image is returned unscaled
    (all zeros) instead of being divided by zero.

    Args:
        path: Anything ``PIL.Image.open`` accepts (typically a file path).

    Returns:
        float32 tensor of shape (1, 3, H, W) on the current CUDA device.
    """
    # The context manager closes the file handle as soon as the pixels are read.
    # convert("RGB") makes grayscale / palette / RGBA files yield the documented
    # 3 channels. It also keeps an alpha channel out of the max-normalization.
    with Image.open(path) as pil_image:
        pixels = np.array(pil_image.convert("RGB"))
    image = torch.from_numpy(pixels).float()
    peak = torch.max(image)
    # Dividing an all-black image by its zero maximum would fill it with NaN.
    if peak > 0:
        image = image / peak
    # (H, W, C) -> (C, H, W) -> (1, C, H, W)
    image = torch.movedim(image, -1, 0).unsqueeze(0).cuda()
    return image
def get_v_component(img_hsv):
    """
    Return the V (value) channel of an HSV image batch.

    Args:
        img_hsv: Tensor of shape (N, 3, H, W) whose last channel is V.

    Returns:
        Tensor view of shape (N, 1, H, W) that shares storage with ``img_hsv``.
        For N == 1 the result equals ``img_hsv[:, -1].unsqueeze(0)``.
    """
    # Slicing with `-1:` keeps the channel axis in place, so the batch axis
    # stays first for any batch size N.
    return img_hsv[:, -1:]
def replace_v_component(img_hsv, v_new):
    """
    Overwrite the last (V) channel of an HSV image tensor in place.

    Args:
        img_hsv: Tensor shaped like (1, 3, H, W); it is modified in place.
        v_new: Values written into channel -1. They must be assignable to
            ``img_hsv[:, -1]`` (for example (1, H, W) or (1, 1, H, W)).

    Returns:
        The same ``img_hsv`` object, now holding the new V channel.
    """
    value_channel = -1  # V is the last channel of an HSV tensor
    img_hsv[:, value_channel] = v_new
    return img_hsv
def interpolate_image(img, H, W):
    """
    Resample ``img`` to the spatial size (H, W).

    Uses nearest-neighbour sampling, which is ``F.interpolate``'s default
    mode. Because two spatial sizes are given, ``img`` must be a 4-D tensor
    (N, C, H_in, W_in).
    """
    target_size = (H, W)
    return F.interpolate(img, size=target_size, mode="nearest")
def get_coords(H, W, device="cuda"):
    """
    Create a normalized coordinate grid for the implicit neural function (INF).

    Args:
        H: Number of grid rows.
        W: Number of grid columns.
        device: Target device for the grid. The default ``"cuda"`` matches
            the historical ``.cuda()`` behaviour.

    Returns:
        float32 tensor of shape (H, W, 2). ``coords[r, c, 0]`` is the column
        coordinate (varies along W, from 0 to 1) and ``coords[r, c, 1]`` is
        the row coordinate (varies along H, from 0 to 1).
    """
    xs = np.linspace(0, 1, W)
    ys = np.linspace(0, 1, H)
    # np.meshgrid uses 'xy' indexing by default: its first argument indexes
    # columns. Passing the W-length axis first gives (H, W) arrays, so the grid
    # lines up with (..., H, W) images. For H == W the values are the same
    # either way.
    coords = np.dstack(np.meshgrid(xs, ys))
    return torch.from_numpy(coords).float().to(device)
def get_patches(img, KERNEL_SIZE):
    """
    Gather each pixel's KERNEL_SIZE x KERNEL_SIZE neighbourhood into the last axis.

    Args:
        img: Tensor of shape (N, 1, H, W). The extraction kernel has a single
            input channel, so ``img`` must have exactly one channel. The final
            ``squeeze(0)`` assumes N == 1.
        KERNEL_SIZE: Side length of the square window. ``KERNEL_SIZE // 2``
            must be smaller than H and W (a ``ReflectionPad2d`` requirement).
            With an even size the output is one pixel larger than the input
            in each spatial dimension.

    Returns:
        For N == 1 and odd KERNEL_SIZE, a tensor of shape
        (H, W, KERNEL_SIZE ** 2). ``out[y, x, i * KERNEL_SIZE + j]`` is the
        reflect-padded pixel at offset (i, j) of the window centred on (y, x).
        The tensor has the same dtype and device as ``img``.
    """
    n_taps = KERNEL_SIZE ** 2
    # One-hot kernel bank: filter k has a single 1 at window position
    # (k // KERNEL_SIZE, k % KERNEL_SIZE) in row-major order. conv2d therefore
    # copies every window entry into its own output channel. The kernel is built
    # on img's device/dtype so CPU, multi-GPU and non-float32 inputs work.
    kernel = torch.eye(n_taps, dtype=img.dtype, device=img.device).view(
        n_taps, 1, KERNEL_SIZE, KERNEL_SIZE
    )
    pad = nn.ReflectionPad2d(KERNEL_SIZE // 2)
    im_padded = pad(img)
    extracted = F.conv2d(im_padded, kernel, padding=0).squeeze(0)
    # (K*K, H, W) -> (H, W, K*K)
    return torch.movedim(extracted, 0, -1)
def filter_up(x_lr, y_lr, x_hr, r=1):
    """
    Upscale a predicted image with a fast guided filter and clamp it to [0, 1].

    ``x_lr``, ``y_lr`` and ``x_hr`` are forwarded positionally, in that order,
    to a ``FastGuidedFilter`` built with radius ``r``. By the naming, they are
    presumably the low-res guide, the low-res prediction and the high-res
    guide (confirm against the filter module).

    Returns:
        The filter output with every value clamped into [0, 1].
    """
    upsampled = FastGuidedFilter(r=r)(x_lr, y_lr, x_hr)
    return upsampled.clamp(min=0, max=1)