A better pytorch-based implementation for the mean structural similarity. Differentiable simpler SSIM and MS-SSIM.

mssim.pytorch

$$
\begin{align}
l(\mathbf{x}, \mathbf{y}) & = \frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, \quad C_1=(K_1L)^2, \quad K_1=0.01, \\
c(\mathbf{x}, \mathbf{y}) & = \frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, \quad C_2=(K_2L)^2, \quad K_2=0.03, \\
s(\mathbf{x}, \mathbf{y}) & = \frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, \quad C_3=C_2/2, \\
\text{SSIM}(\mathbf{x}, \mathbf{y}) & = [l(\mathbf{x}, \mathbf{y})]^\alpha \cdot [c(\mathbf{x}, \mathbf{y})]^\beta \cdot [s(\mathbf{x}, \mathbf{y})]^\gamma \\
& = \frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}, \quad \alpha=\beta=\gamma=1, \\
\text{MS-SSIM}(\mathbf{x}, \mathbf{y}) & = [l_M(\mathbf{x}, \mathbf{y})]^{\alpha_M} \cdot \prod_{j=1}^{M} [c_j(\mathbf{x}, \mathbf{y})]^{\beta_j} [s_j(\mathbf{x}, \mathbf{y})]^{\gamma_j}, \quad M=5, \\
& \beta_1=\gamma_1=0.0448, \\
& \beta_2=\gamma_2=0.2856, \\
& \beta_3=\gamma_3=0.3001, \\
& \beta_4=\gamma_4=0.2363, \\
& \alpha_5=\beta_5=\gamma_5=0.1333.
\end{align}
$$
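The relation between the three terms and the single-fraction form of SSIM can be checked numerically. The following is a minimal sketch in plain Python, not the code of this repository: it uses global statistics over a short list instead of a sliding Gaussian window, the function names `ssim_terms` and `ssim_closed_form` are made up for illustration, and `K2=0.03` is the value from the original SSIM paper.

```python
import math


def _stats(x, y):
    """Global means, population variances, and covariance of two sequences."""
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    return mu_x, mu_y, var_x, var_y, cov_xy


def ssim_terms(x, y, L=1.0, K1=0.01, K2=0.03):
    """Return the luminance, contrast, and structure terms (l, c, s)."""
    mu_x, mu_y, var_x, var_y, cov_xy = _stats(x, y)
    sd_x, sd_y = math.sqrt(var_x), math.sqrt(var_y)
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2
    C3 = C2 / 2
    l = (2 * mu_x * mu_y + C1) / (mu_x**2 + mu_y**2 + C1)
    c = (2 * sd_x * sd_y + C2) / (var_x + var_y + C2)
    s = (cov_xy + C3) / (sd_x * sd_y + C3)
    return l, c, s


def ssim_closed_form(x, y, L=1.0, K1=0.01, K2=0.03):
    """Single-fraction SSIM, valid for alpha = beta = gamma = 1 and C3 = C2 / 2."""
    mu_x, mu_y, var_x, var_y, cov_xy = _stats(x, y)
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2
    numerator = (2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)
    denominator = (mu_x**2 + mu_y**2 + C1) * (var_x + var_y + C2)
    return numerator / denominator


x = [0.1, 0.3, 0.5, 0.7]
y = [0.2, 0.2, 0.6, 0.6]
l, c, s = ssim_terms(x, y)
product = l * c * s
closed = ssim_closed_form(x, y)
print(product, closed)
```

The two values agree because `2 * sd_x * sd_y + C2` equals `2 * (sd_x * sd_y + C3)` when `C3 = C2 / 2`, so the denominator of the structure term cancels against the numerator of the contrast term.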

A better PyTorch-based implementation for the mean structural similarity (MSSIM).

Starting from the widely used implementation at https://github.com/Po-Hsun-Su/pytorch-ssim, I further optimized and refactored the code.

This implementation also addresses the mismatch between results computed in fp16 mode and results computed in fp32 mode: typecasting is used to ensure that the computation is always done in fp32. This may also avoid unexpected results when SSIM is used as a loss.
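To illustrate why half precision can give different results, here is a small sketch in plain Python rather than PyTorch. It assumes the variance is formed as E[x^2] - (E[x])^2, which is common in SSIM implementations, and it emulates fp16 by rounding every intermediate through the `"e"` format of the standard `struct` module. Python's built-in float (double precision) stands in for the fp32 path. The helper names `to_half` and `variance` are made up for this example.

```python
import struct


def to_half(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


def variance(values, cast):
    """Population variance as E[x^2] - (E[x])^2, rounding each step with `cast`."""
    n = len(values)
    total = 0.0
    total_sq = 0.0
    for v in values:
        total = cast(total + v)
        total_sq = cast(total_sq + cast(v * v))
    mean = cast(total / n)
    mean_sq = cast(total_sq / n)
    return cast(mean_sq - cast(mean * mean))


# A low-contrast patch whose values are already representable in fp16.
patch = [to_half(0.5), to_half(0.51)]

var_half = variance(patch, to_half)  # every intermediate kept in fp16
var_full = variance(patch, float)    # inputs upcast, intermediates in full precision
print(var_half, var_full)
```

In the emulated fp16 path the two nearly equal quantities cancel and the small variance of this patch is lost, while the upcast path keeps it. Casting the inputs to a wider type before computing the statistics avoids this kind of cancellation.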

Note

2024-12-04: SSIM for 1D, 2D and 3D data, and MS-SSIM calculation for 2D and 3D data are now supported simultaneously.

| Setting | SSIM1d | SSIM2d | SSIM3d | MS-SSIM2d | MS-SSIM3d (only pooling in the spatial domain) |
| --- | --- | --- | --- | --- | --- |
| data_dim | 1 | 2 (Default) | 3 | 2 | 3 |
| return_msssim | False | False | False | True | True |
| window_size | int, [int] | int, [int, int] | int, [int, int, int] | int, [int, int] | int, [int, int, int] |
| padding | int, [int] | int, [int, int] | int, [int, int, int] | int, [int, int] | int, [int, int, int] |
| sigma | float, [float] | float, [float, float] | float, [float, float, float] | float, [float, float] | float, [float, float, float] |
| in_channels | int | int | int | int | int |
| L | 1, 255 | 1, 255 | 1, 255 | 1, 255 | 1, 255 |
| keep_batch_dim | | | | | |
| return_log | | | | | |
| ensemble_kernel | | | | | |
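The `window_size` and `sigma` settings describe a Gaussian weighting window. As a rough sketch of how such a window is commonly built (not necessarily how this library builds it), the plain-Python example below creates a normalized 1D Gaussian and forms a 2D window as its outer product. The values 11 and 1.5 come from the original SSIM paper and are only assumed here, and the helper names are hypothetical.

```python
import math


def gaussian_window_1d(window_size=11, sigma=1.5):
    """Normalized 1D Gaussian weights centred on the middle tap."""
    center = (window_size - 1) / 2
    weights = [
        math.exp(-((i - center) ** 2) / (2 * sigma**2)) for i in range(window_size)
    ]
    total = sum(weights)
    return [w / total for w in weights]


def outer(u, v):
    """Outer product of two 1D windows, giving a separable 2D window."""
    return [[a * b for b in v] for a in u]


w1d = gaussian_window_1d()
w2d = outer(w1d, w1d)
print(len(w1d), sum(w1d), sum(sum(row) for row in w2d))
```

Because the 2D Gaussian is separable, the outer product of two normalized 1D windows is itself normalized, so local means computed with it stay within the data range.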

Structural similarity index

When comparing images, the mean squared error (MSE), while simple to implement, is not highly indicative of perceived similarity. Structural similarity aims to address this shortcoming by taking texture into account. For more details, see https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity
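A toy example makes the difference concrete. The sketch below is plain Python with a global, unwindowed SSIM (a simplification, not this library's implementation) and hypothetical helper names. It builds two distortions of the same signal that have the same MSE: a constant brightness shift, and a perturbation of the same magnitude with alternating sign.

```python
def mse(x, y):
    """Mean squared error between two equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def global_ssim(x, y, L=1.0, K1=0.01, K2=0.03):
    """SSIM from global statistics (no sliding window), for illustration only."""
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2
    numerator = (2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)
    denominator = (mu_x**2 + mu_y**2 + C1) * (var_x + var_y + C2)
    return numerator / denominator


reference = [0.1, 0.3, 0.5, 0.7]
# Same magnitude of change everywhere, so both distortions share one MSE.
shifted = [v + 0.1 for v in reference]
noisy = [v + 0.1 * sign for v, sign in zip(reference, [1, -1, 1, -1])]

mse_shifted, mse_noisy = mse(reference, shifted), mse(reference, noisy)
ssim_shifted = global_ssim(reference, shifted)
ssim_noisy = global_ssim(reference, noisy)
print(mse_shifted, mse_noisy, ssim_shifted, ssim_noisy)
```

MSE cannot tell the two distortions apart, while SSIM scores the brightness shift higher because the shift leaves the variance and covariance unchanged and only affects the luminance term.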

results

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from lartpang_ssim import SSIM
from po_hsun_su_ssim import SSIM as PoHsunSuSSIM
from vainf_ssim import MS_SSIM as VainFMSSSIM
from vainf_ssim import SSIM as VainFSSIM
from skimage import data, img_as_float

img = img_as_float(data.camera())
rows, cols = img.shape

noise = np.ones_like(img) * 0.3 * (img.max() - img.min())
rng = np.random.default_rng()
noise[rng.random(size=noise.shape) > 0.5] *= -1

img_noise = img + noise
img_const = np.zeros_like(img)

img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
img_noise_tensor = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).float()
img_const_tensor = torch.from_numpy(img_const).unsqueeze(0).unsqueeze(0).float()

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 7))
ax = axes.ravel()

mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean")
mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean")
mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean")

# https://github.com/VainF/pytorch-msssim
vainf_ssim_none = VainFSSIM(channel=1, data_range=1)(img_tensor, img_tensor)
vainf_ssim_noise = VainFSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor)
vainf_ssim_const = VainFSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor)
vainf_ms_ssim_none = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_tensor)
vainf_ms_ssim_noise = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor)
vainf_ms_ssim_const = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor)

# use the settings of https://github.com/VainF/pytorch-msssim
ssim_none_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor)
ssim_noise_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor)
ssim_const_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor)
ms_ssim_none_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor)
ms_ssim_noise_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor)
ms_ssim_const_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor)

# https://github.com/Po-Hsun-Su/pytorch-ssim
pohsunsu_ssim_none = PoHsunSuSSIM()(img_tensor, img_tensor)
pohsunsu_ssim_noise = PoHsunSuSSIM()(img_tensor, img_noise_tensor)
pohsunsu_ssim_const = PoHsunSuSSIM()(img_tensor, img_const_tensor)

# use the settings of https://github.com/Po-Hsun-Su/pytorch-ssim
ssim_none_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_tensor)
ssim_noise_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_noise_tensor)
ssim_const_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_const_tensor)


ax[0].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_xlabel(
    f"MSE: {mse_none:.6f}\n"
    f"SSIM {ssim_none_0:.6f}, MS-SSIM {ms_ssim_none_0:.6f}\n"
    f"(VainF) SSIM: {vainf_ssim_none:.6f}, MS-SSIM {vainf_ms_ssim_none:.6f}\n"
    f"SSIM {ssim_none_1:.6f}\n"
    f"(PoHsunSu) SSIM: {pohsunsu_ssim_none:.6f}\n"
)
ax[0].set_title("Original image")

ax[1].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(
    f"MSE: {mse_noise:.6f}\n"
    f"SSIM {ssim_noise_0:.6f}, MS-SSIM {ms_ssim_noise_0:.6f}\n"
    f"(VainF) SSIM: {vainf_ssim_noise:.6f}, MS-SSIM {vainf_ms_ssim_noise:.6f}\n"
    f"SSIM {ssim_noise_1:.6f}\n"
    f"(PoHsunSu) SSIM: {pohsunsu_ssim_noise:.6f}\n"
)
ax[1].set_title("Image with noise")

ax[2].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(
    f"MSE: {mse_const:.6f}\n"
    f"SSIM {ssim_const_0:.6f}, MS-SSIM {ms_ssim_const_0:.6f}\n"
    f"(VainF) SSIM: {vainf_ssim_const:.6f}, MS-SSIM {vainf_ms_ssim_const:.6f}\n"
    f"SSIM {ssim_const_1:.6f}\n"
    f"(PoHsunSu) SSIM: {pohsunsu_ssim_const:.6f}\n"
)
ax[2].set_title("Constant image")


# hide ticks and tick labels on every panel
for a in ax:
    a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

plt.tight_layout()
plt.savefig("results.png")

More Examples

import torch
from lartpang_ssim import SSIM  # same import as in the comparison script above

# setting 4: for 4d float tensors with the data range [0, 1] and 1 channel, return the logarithmic form and keep the batch dim
ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()

# two 4d tensors
x = torch.randn(3, 1, 100, 100).cuda()
y = torch.randn(3, 1, 100, 100).cuda()
ssim_score_0 = ssim_caller(x, y)
# or in the fp16 mode (the computation is fixed to float32 internally to avoid unexpected results)
with torch.cuda.amp.autocast(enabled=True):
    ssim_score_1 = ssim_caller(x, y)
# both modes should give matching scores
assert torch.allclose(ssim_score_0, ssim_score_1)
print(ssim_score_0.shape, ssim_score_1.shape)

As A Loss

As the stopping thresholds of the two cases below suggest (0.999 when maximizing and -0.94 when minimizing), it is easier to optimize towards MSSIM=1 than towards MSSIM=-1.

Optimize towards MSSIM=1

prediction

import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch import optim

original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
predicted_image = torch.zeros_like(
    target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
)
initial_image = predicted_image.clone()

ssim = SSIM().cuda()
# store the starting score as a plain float so it does not keep the autograd graph alive
initial_ssim_value = ssim(predicted_image, target_image).item()

ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
while ssim_value < 0.999:
    ssim_out = 1 - ssim(predicted_image, target_image)
    loss_curves.append(ssim_out.item())
    ssim_value = 1 - ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    optimizer.zero_grad()

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

ax[4].set_title("Original Image")
ax[4].hist(original_image.ravel(), bins=256)
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[4].set_xlabel("Pixel Intensity")

ax[5].set_title("Initial Image")
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[5].set_xlabel("Pixel Intensity")

ax[6].set_title("Predicted Image")
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[6].set_xlabel("Pixel Intensity")

plt.tight_layout()
plt.savefig("prediction.png")

Optimize towards MSSIM=-1

prediction

import matplotlib.pyplot as plt
import torch
from pytorch_ssim import SSIM
from skimage import data
from torch import optim

original_image = data.moon() / 255
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
predicted_image = torch.zeros_like(
    target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
)
initial_image = predicted_image.clone()

ssim = SSIM(L=original_image.max() - original_image.min()).cuda()
# store the starting score as a plain float so it does not keep the autograd graph alive
initial_ssim_value = ssim(predicted_image, target_image).item()

ssim_value = initial_ssim_value
optimizer = optim.Adam([predicted_image], lr=0.01)
loss_curves = []
while ssim_value > -0.94:
    ssim_out = ssim(predicted_image, target_image)
    loss_curves.append(ssim_out.item())
    ssim_value = ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    optimizer.zero_grad()

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))
ax = axes.ravel()

ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
ax[0].set_title("Original Image")

ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}")
ax[1].set_title("Initial Image")

ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}")
ax[2].set_title("Predicted Image")

ax[3].plot(loss_curves)
ax[3].set_title("SSIM Loss Curve")

ax[4].set_title("Original Image")
ax[4].hist(original_image.ravel(), bins=256)
ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[4].set_xlabel("Pixel Intensity")

ax[5].set_title("Initial Image")
ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[5].set_xlabel("Pixel Intensity")

ax[6].set_title("Predicted Image")
ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)
ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0))
ax[6].set_xlabel("Pixel Intensity")

plt.tight_layout()
plt.savefig("prediction.png")

Cite

If you find this library useful, please cite it using the following BibTeX entry:

@online{mssim.pytorch,
    author="lartpang",
    title="{A better pytorch-based implementation for the mean structural similarity. Differentiable simpler SSIM and MS-SSIM.}",
    url="https://github.com/lartpang/mssim.pytorch",
    note="(Jun 21, 2022)",
}