From 97aae46b66d4b4612df00a13049714cc2a293548 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Thu, 12 Jan 2023 21:04:36 +0800 Subject: [PATCH 01/16] add gradient error --- docs/en/api/metrics.rst | 1 + docs/zh_cn/api/metrics.rst | 1 + mmeval/metrics/__init__.py | 3 +- mmeval/metrics/gradient_error.py | 176 ++++++++++++++++++++++ tests/test_metrics/test_gradient_error.py | 19 +++ 5 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 mmeval/metrics/gradient_error.py create mode 100644 tests/test_metrics/test_gradient_error.py diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index a7a275d7..16ee6f01 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -42,3 +42,4 @@ Metrics MAE MSE BLEU + GradientError diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index a7a275d7..16ee6f01 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -42,3 +42,4 @@ Metrics MAE MSE BLEU + GradientError diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 31d2b5bd..1b973299 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -6,6 +6,7 @@ from .coco_detection import COCODetectionMetric from .end_point_error import EndPointError from .f_metric import F1Metric +from .gradient_error import GradientError from .hmean_iou import HmeanIoU from .mae import MAE from .mean_iou import MeanIoU @@ -25,5 +26,5 @@ 'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall', 'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric', - 'AveragePrecision', 'AVAMeanAP', 'BLEU' + 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'GradientError' ] diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py new file mode 100644 index 00000000..6e8ec9ba --- /dev/null +++ b/mmeval/metrics/gradient_error.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np +from typing import Dict, List, Sequence + +from mmeval.core import BaseMetric + + +def gaussian(x, sigma): + """Gaussian function. + + Args: + x (array_like): The independent variable. + sigma (float): Standard deviation of the gaussian function. + + Return: + np.ndarray or scalar: Gaussian value of `x`. + """ + return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) + + +def dgaussian(x, sigma): + """Gradient of gaussian. + + Args: + x (array_like): The independent variable. + sigma (float): Standard deviation of the gaussian function. + + Return: + np.ndarray or scalar: Gradient of gaussian of `x`. + """ + return -x * gaussian(x, sigma) / sigma**2 + + +def gauss_filter(sigma, epsilon=1e-2): + """Gradient of gaussian. + + Args: + sigma (float): Standard deviation of the gaussian kernel. + epsilon (float): Small value used when calculating kernel size. + Default: 1e-2. + + Return: + filter_x (np.ndarray): Gaussian filter along x axis. + filter_y (np.ndarray): Gaussian filter along y axis. + """ + half_size = np.ceil( + sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) + size = int(2 * half_size + 1) + + # create filter in x axis + filter_x = np.zeros((size, size)) + for i in range(size): + for j in range(size): + filter_x[i, j] = gaussian(i - half_size, sigma) * dgaussian( + j - half_size, sigma) + + # normalize filter + norm = np.sqrt((filter_x**2).sum()) + filter_x = filter_x / norm + filter_y = np.transpose(filter_x) + + return filter_x, filter_y + + +def gauss_gradient(img, sigma): + """Gaussian gradient. + + From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ + submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ + index.html + + Args: + img (np.ndarray): Input image. + sigma (float): Standard deviation of the gaussian kernel. + + Return: + np.ndarray: Gaussian gradient of input `img`. + """ + filter_x, filter_y = gauss_filter(sigma) + img_filtered_x = cv2.filter2D( + img, -1, filter_x, borderType=cv2.BORDER_REPLICATE) + img_filtered_y = cv2.filter2D( + img, -1, filter_y, borderType=cv2.BORDER_REPLICATE) + return np.sqrt(img_filtered_x**2 + img_filtered_y**2) + + +class GradientError(BaseMetric): + """Gradient error for evaluating alpha matte prediction. + + Args: + sigma (float): Standard deviation of the gaussian kernel. + Defaults to 1.4 . + norm_const (int): Divide the result to reduce its magnitude. + Defaults to 1000 . + **kwargs: Keyword parameters passed to :class:`BaseMetric`. + Note: + The current implementation assumes the image / alpha / trimap + a numpy array with pixel values ranging from 0 to 255. + + The pred_alpha should be masked by trimap before passing + into this metric. + + The trimap is the most commonly used prior knowledge. As the + name implies, trimap is a ternary graph and each pixel + takes one of {0, 128, 255}, representing the foreground, the + unknown and the background respectively. + + Examples: + >>> from mmeval import GradientError + >>> import numpy as np + >>> + >>> gradienterror = GradientError() + >>> pred_alpha = np.zeros((32, 32), dtype=np.uint8) + >>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 + >>> trimap = np.zeros((32, 32), dtype=np.uint8) + >>> trimap[:16, :16] = 128 + >>> trimap[16:, 16:] = 255 + >>> gradienterror(pred_alpha, gt_alpha, trimap) # doctest: +ELLIPSIS + {'GradientError': ...} + """ + + def __init__(self, + sigma: float = 1.4, + norm_const: int = 1000, + **kwargs) -> None: + super().__init__(**kwargs) + self.sigma = sigma + self.norm_const = norm_const + + def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 + """Add GradientError score of batch to ``self._results`` + + Args: + pred_alpha(Sequence[np.ndarray]): Pred_alpha data of predictions. + ori_alpha(Sequence[np.ndarray]): Ori_alpha data of data_batch. + ori_trimap(Sequence[np.ndarray]): Ori_trimap data of data_batch. + """ + + for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas, + trimaps): + assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \ + '`pred_alpha` and `gt_alpha` should be the same, but got: ' \ + f'{pred_alpha.shape} and {gt_alpha.shape}' + + gt_alpha_normed = np.zeros_like(gt_alpha) + pred_alpha_normed = np.zeros_like(pred_alpha) + + cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX) + cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0, + cv2.NORM_MINMAX) + + gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma) + pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma) + # this is the sum over n samples + grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 * + (trimap == 128)).sum() + + # divide by 1000 to reduce the magnitude of the result + grad_loss /= self.norm_const + + self._results.append(grad_loss) + + def compute_metric(self, results: List) -> Dict[str, float]: + """Compute the GradientError metric. + + Args: + results (List): A list that consisting the GradientError score. + This list has already been synced across all ranks. + Returns: + Dict[str, float]: The computed GradientError metric. + The keys are the names of the metrics, + and the values are corresponding results. + """ + + return {'GradientError': float(np.array(results).mean())} diff --git a/tests/test_metrics/test_gradient_error.py b/tests/test_metrics/test_gradient_error.py new file mode 100644 index 00000000..ab833c97 --- /dev/null +++ b/tests/test_metrics/test_gradient_error.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +from mmeval.metrics import GradientError + + +def test_gradient_error(): + np.random.seed(0) + pred_alpha = np.random.randn(32, 32).astype('uint8') + gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 + trimap = np.zeros((32, 32), dtype=np.uint8) + trimap[:16, :16] = 128 + trimap[16:, 16:] = 255 + + gradienterror = GradientError() + gradienterror_results = gradienterror(pred_alpha, gt_alpha, trimap) + assert isinstance(gradienterror_results, dict) + np.testing.assert_almost_equal(gradienterror_results['GradientError'], + 0.0935) From 4aab935e948428a902ffcf1b1b030135bd48d039 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 13:53:24 +0800 Subject: [PATCH 02/16] update docstring and type hints --- mmeval/metrics/gradient_error.py | 47 ++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 6e8ec9ba..9c2865d6 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -6,44 +6,46 @@ from mmeval.core import BaseMetric -def gaussian(x, sigma): +def gaussian(x: np.ndarray, sigma: float): """Gaussian function. Args: x (array_like): The independent variable. sigma (float): Standard deviation of the gaussian function. - Return: + Returns: np.ndarray or scalar: Gaussian value of `x`. """ + return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) -def dgaussian(x, sigma): - """Gradient of gaussian. +def dgaussian(x: np.ndarray, sigma: float): + """Derivative of Gaussian function. Args: x (array_like): The independent variable. sigma (float): Standard deviation of the gaussian function. - Return: + Returns: np.ndarray or scalar: Gradient of gaussian of `x`. """ + return -x * gaussian(x, sigma) / sigma**2 -def gauss_filter(sigma, epsilon=1e-2): - """Gradient of gaussian. +def gauss_filter(sigma: float, epsilon=1e-2): + """Gaussian Filter. Args: sigma (float): Standard deviation of the gaussian kernel. epsilon (float): Small value used when calculating kernel size. Default: 1e-2. - Return: - filter_x (np.ndarray): Gaussian filter along x axis. - filter_y (np.ndarray): Gaussian filter along y axis. + Returns: + tuple(np.ndarray, np.ndarray): Gaussian filter along x and y axis. """ + half_size = np.ceil( sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) size = int(2 * half_size + 1) @@ -63,10 +65,13 @@ def gauss_filter(sigma, epsilon=1e-2): return filter_x, filter_y -def gauss_gradient(img, sigma): - """Gaussian gradient. +def gauss_gradient(img: np.ndarray, sigma: float): + """Gaussian gradient.The first order Gaussian derivative convolution + calculation is carried out by using Gaussian filter. Calculate their + gradients separately, make a difference, and then accumulate their squares. + The more similar the two, the smaller the Gradient error. - From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ + Reference: https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ index.html @@ -74,9 +79,10 @@ def gauss_gradient(img, sigma): img (np.ndarray): Input image. sigma (float): Standard deviation of the gaussian kernel. - Return: + Returns: np.ndarray: Gaussian gradient of input `img`. """ + filter_x, filter_y = gauss_filter(sigma) img_filtered_x = cv2.filter2D( img, -1, filter_x, borderType=cv2.BORDER_REPLICATE) @@ -110,13 +116,13 @@ class GradientError(BaseMetric): >>> from mmeval import GradientError >>> import numpy as np >>> - >>> gradienterror = GradientError() + >>> gradient_error = GradientError() >>> pred_alpha = np.zeros((32, 32), dtype=np.uint8) >>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 >>> trimap = np.zeros((32, 32), dtype=np.uint8) >>> trimap[:16, :16] = 128 >>> trimap[16:, 16:] = 255 - >>> gradienterror(pred_alpha, gt_alpha, trimap) # doctest: +ELLIPSIS + >>> gradient_error(pred_alpha, gt_alpha, trimap) # doctest: +ELLIPSIS {'GradientError': ...} """ @@ -132,9 +138,9 @@ def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray] """Add GradientError score of batch to ``self._results`` Args: - pred_alpha(Sequence[np.ndarray]): Pred_alpha data of predictions. - ori_alpha(Sequence[np.ndarray]): Ori_alpha data of data_batch. - ori_trimap(Sequence[np.ndarray]): Ori_trimap data of data_batch. + pred_alphas (Sequence[np.ndarray]): Pred_alpha data of predictions. + ori_alphas (Sequence[np.ndarray]): Ori_alpha data of data_batch. + ori_trimaps (Sequence[np.ndarray]): Ori_trimap data of data_batch. """ for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas, @@ -156,7 +162,7 @@ def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray] grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 * (trimap == 128)).sum() - # divide by 1000 to reduce the magnitude of the result + # divide by self.norm_const to reduce the magnitude of the result grad_loss /= self.norm_const self._results.append(grad_loss) @@ -167,6 +173,7 @@ def compute_metric(self, results: List) -> Dict[str, float]: Args: results (List): A list that consisting the GradientError score. This list has already been synced across all ranks. + Returns: Dict[str, float]: The computed GradientError metric. The keys are the names of the metrics, From f2de3aeca703615139cf416490c9e35c3d87b8c1 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 13:57:22 +0800 Subject: [PATCH 03/16] update --- mmeval/metrics/gradient_error.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 9c2865d6..6cbe8659 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -100,6 +100,7 @@ class GradientError(BaseMetric): norm_const (int): Divide the result to reduce its magnitude. Defaults to 1000 . **kwargs: Keyword parameters passed to :class:`BaseMetric`. + Note: The current implementation assumes the image / alpha / trimap a numpy array with pixel values ranging from 0 to 255. From 7963b49bc553da66e8e1a65779258bfcbdf34093 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 15:06:16 +0800 Subject: [PATCH 04/16] use try_import --- mmeval/metrics/gradient_error.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 6cbe8659..1aa21a10 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -1,16 +1,21 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 +# Copyright (c) OpenMMLab. All rights reserved.Dict import numpy as np -from typing import Dict, List, Sequence +from typing import TYPE_CHECKING, Dict, List, Sequence from mmeval.core import BaseMetric +from mmeval.utils import try_import + +if TYPE_CHECKING: + import cv2 +else: + cv2 = try_import('cv2') def gaussian(x: np.ndarray, sigma: float): """Gaussian function. Args: - x (array_like): The independent variable. + x (np.ndarray): The independent variable. sigma (float): Standard deviation of the gaussian function. Returns: @@ -24,7 +29,7 @@ def dgaussian(x: np.ndarray, sigma: float): """Derivative of Gaussian function. Args: - x (array_like): The independent variable. + x (np.ndarray): The independent variable. sigma (float): Standard deviation of the gaussian function. Returns: @@ -66,10 +71,7 @@ def gauss_filter(sigma: float, epsilon=1e-2): def gauss_gradient(img: np.ndarray, sigma: float): - """Gaussian gradient.The first order Gaussian derivative convolution - calculation is carried out by using Gaussian filter. Calculate their - gradients separately, make a difference, and then accumulate their squares. - The more similar the two, the smaller the Gradient error. + """Gaussian gradient. Reference: https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ @@ -135,6 +137,10 @@ def __init__(self, self.sigma = sigma self.norm_const = norm_const + if cv2 is None: + raise ImportError(f'For availability of {self.__class__.__name__},' + ' please install cv2 first.') + def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 """Add GradientError score of batch to ``self._results`` From 4503e53e315000416465046ee858950cd3a1d266 Mon Sep 17 00:00:00 2001 From: TKAN <40657922+xuan07472@users.noreply.github.com> Date: Fri, 13 Jan 2023 15:23:29 +0800 Subject: [PATCH 05/16] Update gradient_error.py update --- mmeval/metrics/gradient_error.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 1aa21a10..a90a6fb4 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -116,6 +116,7 @@ class GradientError(BaseMetric): unknown and the background respectively. Examples: + >>> from mmeval import GradientError >>> import numpy as np >>> From 651168abc72ef6a145d35db76383adbf5cdcc8fd Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 15:40:38 +0800 Subject: [PATCH 06/16] update --- mmeval/metrics/gradient_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index a90a6fb4..d34f1395 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -116,7 +116,7 @@ class GradientError(BaseMetric): unknown and the background respectively. Examples: - + >>> from mmeval import GradientError >>> import numpy as np >>> From afed4942778ac5345b5ae6ae5fd33cc3384a1a20 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 15:47:49 +0800 Subject: [PATCH 07/16] update Examples --- mmeval/metrics/gradient_error.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index d34f1395..3f33769f 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -121,7 +121,8 @@ class GradientError(BaseMetric): >>> import numpy as np >>> >>> gradient_error = GradientError() - >>> pred_alpha = np.zeros((32, 32), dtype=np.uint8) + >>> np.random.seed(0) + >>> pred_alpha = np.random.randn(32, 32).astype('uint8') >>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 >>> trimap = np.zeros((32, 32), dtype=np.uint8) >>> trimap[:16, :16] = 128 From 8c6e46ef44727e5c705115699b3961e32fcc899f Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 16:32:44 +0800 Subject: [PATCH 08/16] update test --- tests/test_metrics/test_gradient_error.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_metrics/test_gradient_error.py b/tests/test_metrics/test_gradient_error.py index ab833c97..92dbe44d 100644 --- a/tests/test_metrics/test_gradient_error.py +++ b/tests/test_metrics/test_gradient_error.py @@ -12,8 +12,7 @@ def test_gradient_error(): trimap[:16, :16] = 128 trimap[16:, 16:] = 255 - gradienterror = GradientError() - gradienterror_results = gradienterror(pred_alpha, gt_alpha, trimap) - assert isinstance(gradienterror_results, dict) - np.testing.assert_almost_equal(gradienterror_results['GradientError'], - 0.0935) + gradient_error = GradientError() + metric_results = gradient_error(pred_alpha, gt_alpha, trimap) + assert isinstance(metric_results, dict) + np.testing.assert_almost_equal(metric_results['GradientError'], 0.0935) From 086c77943aa6e4adca356274416a5729f7e7e11f Mon Sep 17 00:00:00 2001 From: TKAN <40657922+xuan07472@users.noreply.github.com> Date: Fri, 13 Jan 2023 18:13:18 +0800 Subject: [PATCH 09/16] Apply suggestions from code review update Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> --- mmeval/metrics/gradient_error.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 3f33769f..18b42c2a 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -45,7 +45,7 @@ def gauss_filter(sigma: float, epsilon=1e-2): Args: sigma (float): Standard deviation of the gaussian kernel. epsilon (float): Small value used when calculating kernel size. - Default: 1e-2. + Default to 1e-2. Returns: tuple(np.ndarray, np.ndarray): Gaussian filter along x and y axis. @@ -100,7 +100,7 @@ class GradientError(BaseMetric): sigma (float): Standard deviation of the gaussian kernel. Defaults to 1.4 . norm_const (int): Divide the result to reduce its magnitude. - Defaults to 1000 . + Defaults to 1000. **kwargs: Keyword parameters passed to :class:`BaseMetric`. Note: From 412d91f0abacfb39c76797e4166e93d29a63bd79 Mon Sep 17 00:00:00 2001 From: TKAN <40657922+xuan07472@users.noreply.github.com> Date: Fri, 13 Jan 2023 21:00:37 +0800 Subject: [PATCH 10/16] Update mmeval/metrics/gradient_error.py Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> --- mmeval/metrics/gradient_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 18b42c2a..3e61c2e6 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -39,7 +39,7 @@ def dgaussian(x: np.ndarray, sigma: float): return -x * gaussian(x, sigma) / sigma**2 -def gauss_filter(sigma: float, epsilon=1e-2): +def gauss_filter(sigma: float, epsilon: float = 1e-2): """Gaussian Filter. Args: From 6e48201ff858ab79ef97325905ccebe9f7908d9c Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 21:38:57 +0800 Subject: [PATCH 11/16] Correct the Args --- mmeval/metrics/gradient_error.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 3e61c2e6..041447d1 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -147,9 +147,12 @@ def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray] """Add GradientError score of batch to ``self._results`` Args: - pred_alphas (Sequence[np.ndarray]): Pred_alpha data of predictions. - ori_alphas (Sequence[np.ndarray]): Ori_alpha data of data_batch. - ori_trimaps (Sequence[np.ndarray]): Ori_trimap data of data_batch. + pred_alphas (Sequence[np.ndarray]): Predict the probability + that pixels belong to the foreground. + gt_alphas (Sequence[np.ndarray]): Probability that the actual + pixel belongs to the foreground. + trimaps (Sequence[np.ndarray]): Broadly speaking, the trimap + consists of foreground and unknown region. """ for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas, @@ -189,4 +192,4 @@ def compute_metric(self, results: List) -> Dict[str, float]: and the values are corresponding results. """ - return {'GradientError': float(np.array(results).mean())} + return {'gradient_error': float(np.array(results).mean())} From 60dbe9a580500345e71ada2cebaab1420f387ac7 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Fri, 13 Jan 2023 23:50:36 +0800 Subject: [PATCH 12/16] update test --- tests/test_metrics/test_gradient_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_metrics/test_gradient_error.py b/tests/test_metrics/test_gradient_error.py index 92dbe44d..e6e9bfe0 100644 --- a/tests/test_metrics/test_gradient_error.py +++ b/tests/test_metrics/test_gradient_error.py @@ -15,4 +15,4 @@ def test_gradient_error(): gradient_error = GradientError() metric_results = gradient_error(pred_alpha, gt_alpha, trimap) assert isinstance(metric_results, dict) - np.testing.assert_almost_equal(metric_results['GradientError'], 0.0935) + np.testing.assert_almost_equal(metric_results['gradient_error'], 0.0935) From 016c5eb0459e34a1ec2221a4ba09d1b2c569a85b Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Sat, 14 Jan 2023 17:13:39 +0800 Subject: [PATCH 13/16] Change metrics to lowercase --- mmeval/metrics/gradient_error.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 041447d1..704532b6 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -128,7 +128,7 @@ class GradientError(BaseMetric): >>> trimap[:16, :16] = 128 >>> trimap[16:, 16:] = 255 >>> gradient_error(pred_alpha, gt_alpha, trimap) # doctest: +ELLIPSIS - {'GradientError': ...} + {'gradient_error': ...} """ def __init__(self, @@ -144,7 +144,7 @@ def __init__(self, ' please install cv2 first.') def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 - """Add GradientError score of batch to ``self._results`` + """Add gradient_error score of batch to ``self._results`` Args: pred_alphas (Sequence[np.ndarray]): Predict the probability @@ -180,14 +180,14 @@ def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray] self._results.append(grad_loss) def compute_metric(self, results: List) -> Dict[str, float]: - """Compute the GradientError metric. + """Compute the gradient_error metric. Args: - results (List): A list that consisting the GradientError score. + results (List): A list that consisting the gradient_error score. This list has already been synced across all ranks. Returns: - Dict[str, float]: The computed GradientError metric. + Dict[str, float]: The computed gradient_error metric. The keys are the names of the metrics, and the values are corresponding results. """ From 23d4014b521d86010d308f0e0bc3737afdcab279 Mon Sep 17 00:00:00 2001 From: xuan07472 Date: Sat, 14 Jan 2023 18:38:14 +0800 Subject: [PATCH 14/16] update --- mmeval/metrics/gradient_error.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index 704532b6..e5376f4e 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -144,7 +144,7 @@ def __init__(self, ' please install cv2 first.') def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 - """Add gradient_error score of batch to ``self._results`` + """Add GradientError score of batch to ``self._results`` Args: pred_alphas (Sequence[np.ndarray]): Predict the probability @@ -180,14 +180,14 @@ def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray] self._results.append(grad_loss) def compute_metric(self, results: List) -> Dict[str, float]: - """Compute the gradient_error metric. + """Compute the GradientError metric. Args: - results (List): A list that consisting the gradient_error score. + results (List): A list that consisting the GradientError score. This list has already been synced across all ranks. Returns: - Dict[str, float]: The computed gradient_error metric. + Dict[str, float]: The computed GradientError metric. The keys are the names of the metrics, and the values are corresponding results. """ From 44ed2578d217b6453babc935e92448868475370b Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 15 Jan 2023 11:36:14 +0800 Subject: [PATCH 15/16] Update mmeval/metrics/gradient_error.py --- mmeval/metrics/gradient_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmeval/metrics/gradient_error.py b/mmeval/metrics/gradient_error.py index e5376f4e..99b559fc 100644 --- a/mmeval/metrics/gradient_error.py +++ b/mmeval/metrics/gradient_error.py @@ -141,7 +141,7 @@ def __init__(self, if cv2 is None: raise ImportError(f'For availability of {self.__class__.__name__},' - ' please install cv2 first.') + ' please install opencv-python first.') def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501 """Add GradientError score of batch to ``self._results`` From 9267768ce286c650e548d014b3d4c6a74171c441 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 15 Jan 2023 11:48:23 +0800 Subject: [PATCH 16/16] Apply suggestions from code review --- docs/en/api/metrics.rst | 2 +- docs/zh_cn/api/metrics.rst | 2 +- mmeval/metrics/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index 583b8a0d..00d24b3b 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -43,4 +43,4 @@ Metrics MSE BLEU SAD - GradientError \ No newline at end of file + GradientError diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index 583b8a0d..00d24b3b 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -43,4 +43,4 @@ Metrics MSE BLEU SAD - GradientError \ No newline at end of file + GradientError diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index f94eb4c3..2479d788 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -27,5 +27,5 @@ 'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall', 'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric', - 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'SAD', 'GradientError + 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'SAD', 'GradientError' ]