style_transfer.py
import numpy as np
import torch
import torch.nn.functional as F


def content_loss(content_weight, content_current, content_target):
    """
    Compute the content loss for style transfer.

    Inputs:
    - content_weight: Scalar giving the weighting for the content loss.
    - content_current: features of the current image; this is a PyTorch Tensor
      of shape (1, C_l, H_l, W_l).
    - content_target: features of the content image, Tensor with shape
      (1, C_l, H_l, W_l).

    Returns:
    - scalar content loss
    """
    ##############################################################################
    # YOUR CODE HERE                                                             #
    ##############################################################################
    loss = content_weight * torch.sum((content_current - content_target) ** 2)
    return loss
    ##############################################################################
    # END OF YOUR CODE                                                           #
    ##############################################################################
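
# Illustrative usage (a sketch, not part of the assignment): with identical
# current and target features the content loss should be exactly 0, and with a
# unit gap at every position it reduces to content_weight * numel. The shapes
# below are made up for the example.
#
#   feats = torch.ones(1, 8, 4, 4)
#   assert content_loss(0.5, feats, feats).item() == 0.0
#   assert content_loss(0.5, feats, torch.zeros_like(feats)).item() == 0.5 * feats.numel()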


def gram_matrix(features, normalize=True):
    """
    Compute the Gram matrix from features.

    Inputs:
    - features: PyTorch Tensor of shape (N, C, H, W) giving features for
      a batch of N images.
    - normalize: optional, whether to normalize the Gram matrix.
      If True, divide the Gram matrix by the number of neurons (C * H * W).

    Returns:
    - gram: PyTorch Tensor of shape (N, C, C) giving the (optionally
      normalized) Gram matrices for the N input images.
    """
    ##############################################################################
    # YOUR CODE HERE                                                             #
    ##############################################################################
    N, C, H, W = features.size()
    # Flatten each feature map so every image gives a (C, H * W) matrix.
    features_flat = features.view(N, C, -1)
    # Batched matrix multiplication: gram[n] = F_n @ F_n^T, shape (N, C, C).
    gram = torch.bmm(features_flat, features_flat.transpose(1, 2))
    if normalize:
        return gram / (C * H * W)
    return gram
    ##############################################################################
    # END OF YOUR CODE                                                           #
    ##############################################################################
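
# Illustrative check (a sketch with made-up shapes): gram_matrix maps
# (N, C, H, W) features to an (N, C, C) batch of Gram matrices, and each
# G = F @ F^T is symmetric by construction.
#
#   f = torch.randn(2, 16, 8, 8)
#   g = gram_matrix(f)
#   assert g.shape == (2, 16, 16)
#   assert torch.allclose(g, g.transpose(1, 2))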


def style_loss(feats, style_layers, style_targets, style_weights):
    """
    Computes the style loss at a set of layers.

    Inputs:
    - feats: list of the features at every layer of the current image, as
      produced by the extract_features function.
    - style_layers: List of layer indices into feats giving the layers to
      include in the style loss.
    - style_targets: List of the same length as style_layers, where
      style_targets[i] is a PyTorch Tensor giving the Gram matrix of the
      source style image computed at layer style_layers[i].
    - style_weights: List of the same length as style_layers, where
      style_weights[i] is a scalar giving the weight for the style loss at
      layer style_layers[i].

    Returns:
    - style_loss: A PyTorch Tensor holding a scalar giving the style loss.
    """
    # Hint: you can do this with one for loop over the style layers, and it
    # should not be very much code (~5 lines). You will need to use your
    # gram_matrix function.
    ##############################################################################
    # YOUR CODE HERE                                                             #
    ##############################################################################
    loss = 0
    # Sum the weighted squared Frobenius distance between the current and
    # target Gram matrices at each chosen layer.
    for layer, target, weight in zip(style_layers, style_targets, style_weights):
        gram = gram_matrix(feats[layer])
        loss += weight * torch.sum((gram - target) ** 2)
    return loss
    ##############################################################################
    # END OF YOUR CODE                                                           #
    ##############################################################################
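
# Illustrative check (hypothetical two-layer setup): when the style targets
# are the Gram matrices of the current features themselves, every term in the
# sum vanishes and the style loss is 0.
#
#   feats = [torch.randn(1, 4, 8, 8), torch.randn(1, 8, 4, 4)]
#   targets = [gram_matrix(f) for f in feats]
#   assert style_loss(feats, [0, 1], targets, [1.0, 2.0]).item() == 0.0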


def tv_loss(img, tv_weight):
    """
    Compute total variation loss.

    Inputs:
    - img: PyTorch Tensor of shape (1, 3, H, W) holding an input image.
    - tv_weight: Scalar giving the weight w_t to use for the TV loss.

    Returns:
    - loss: PyTorch Tensor holding a scalar giving the total variation loss
      for img weighted by tv_weight.
    """
    # Your implementation should be vectorized and not require any loops!
    ##############################################################################
    # YOUR CODE HERE                                                             #
    ##############################################################################
    # Sum of squared differences between vertically adjacent pixels and
    # between horizontally adjacent pixels, computed with shifted slices.
    h_variance = torch.sum((img[:, :, :-1, :] - img[:, :, 1:, :]) ** 2)
    w_variance = torch.sum((img[:, :, :, :-1] - img[:, :, :, 1:]) ** 2)
    return tv_weight * (h_variance + w_variance)
    ##############################################################################
    # END OF YOUR CODE                                                           #
    ##############################################################################
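

if __name__ == "__main__":
    # Minimal smoke test (a sketch with arbitrary shapes and weights, not part
    # of the original assignment): run each loss once on random inputs and
    # print the resulting scalars.
    torch.manual_seed(0)

    cur = torch.randn(1, 3, 16, 16)
    tgt = torch.randn(1, 3, 16, 16)
    print("content loss:", content_loss(1e-2, cur, tgt).item())

    feats = [torch.randn(1, 4, 8, 8), torch.randn(1, 8, 4, 4)]
    targets = [gram_matrix(torch.randn_like(f)) for f in feats]
    print("style loss:", style_loss(feats, [0, 1], targets, [1.0, 0.5]).item())

    img = torch.randn(1, 3, 32, 32)
    print("tv loss:", tv_loss(img, 1e-2).item())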