diff --git a/piq/ssim.py b/piq/ssim.py index 74814ce..4861917 100644 --- a/piq/ssim.py +++ b/piq/ssim.py @@ -19,7 +19,7 @@ def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5, data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False, - downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor]: + downsample: bool = True, k1: float = 0.01, k2: float = 0.03) -> List[torch.Tensor] | torch.Tensor: r"""Interface of Structural Similarity (SSIM) index. Inputs supposed to be in range ``[0, data_range]``. To match performance with skimage and tensorflow set ``'downsample' = True``.