[CodeCamp #1538] Adopt mmeval's GradientError in mmediting #1598

Closed · wants to merge 11 commits
79 changes: 34 additions & 45 deletions mmagic/evaluation/metrics/gradient_error.py
@@ -1,20 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Sequence
+from typing import Optional, Sequence
 
-import cv2
-import numpy as np
+import torch.nn as nn
+from mmengine.model import is_model_wrapper
+from mmeval import GradientError as _GradientError
+from torch.utils.data.dataloader import DataLoader
 
 from mmagic.registry import METRICS
-from ..functional import gauss_gradient
-from .base_sample_wise_metric import BaseSampleWiseMetric
-from .metrics_utils import _fetch_data_and_check, average
+from .metrics_utils import _fetch_data_and_check
 
 
 @METRICS.register_module()
-class GradientError(BaseSampleWiseMetric):
+class GradientError(_GradientError):
     """Gradient error for evaluating alpha matte prediction.
 
     .. note::
@@ -28,10 +25,14 @@ class GradientError(BaseSampleWiseMetric):
         into this metric
 
     Args:
-        sigma (float): Standard deviation of the gaussian kernel.
-            Defaults to 1.4 .
-        norm_const (int): Divide the result to reduce its magnitude.
-            Defaults to 1000 .
+        scaling (float, optional): Scaling factor for final metric.
+            E.g. scaling=100 means the final metric will be amplified by 100
+            for output. Default: 1
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Default: None.
 
     Default prefix: ''
 
@@ -40,15 +41,17 @@ class GradientError(BaseSampleWiseMetric):
     """
 
     metric = 'GradientError'
+    SAMPLER_MODE = 'normal'
+    sample_model = 'orig'
 
     def __init__(
         self,
-        sigma=1.4,
-        norm_constant=1000,
+        scaling: float = 1,
+        prefix: Optional[str] = None,
         **kwargs,
     ) -> None:
-        self.sigma = sigma
-        self.norm_constant = norm_constant
+        self.prefix = prefix
+        self.scaling = scaling
         super().__init__(**kwargs)
 
+    def prepare(self, module: nn.Module, dataloader: DataLoader):
@@ -68,39 +71,25 @@ def process(self, data_batch: Sequence[dict],
             predictions (Sequence[dict]): A batch of outputs from
                 the model.
         """
-
+        pred_alphas, gt_alphas, trimaps = [], [], []
         for data_sample in data_samples:
             pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
+            pred_alphas.append(pred_alpha)
+            gt_alphas.append(gt_alpha)
+            trimaps.append(trimap)
 
-            gt_alpha_normed = np.zeros_like(gt_alpha)
-            pred_alpha_normed = np.zeros_like(pred_alpha)
-
-            cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
-            cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
-                          cv2.NORM_MINMAX)
-
-            gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
-            pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
-            # this is the sum over n samples
-            grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
-                         (trimap == 128)).sum()
-
-            # divide by 1000 to reduce the magnitude of the result
-            grad_loss /= self.norm_constant
-
-            self.results.append({'grad_err': grad_loss})
+        self.add(pred_alphas, gt_alphas, trimaps)
 
-    def compute_metrics(self, results: List):
-        """Compute the metrics from processed results.
+    def evaluate(self, *args, **kwargs):
+        """Returns metric results and print pretty table of metrics per class.
 
-        Args:
-            results (dict): The processed results of each batch.
-
-        Returns:
-            Dict: The computed metrics. The keys are the names of the metrics,
-            and the values are corresponding results.
+        This method would be invoked by ``mmengine.Evaluator``.
         """
-
-        grad_err = average(results, 'grad_err')
-
-        return {'GradientError': grad_err}
+        metric_results = self.compute(*args, **kwargs)
+        self.reset()
+
+        key_template = f'{self.prefix}/{{}}' if self.prefix else '{}'
+        return {
+            key_template.format(k): v * self.scaling
+            for k, v in metric_results.items()
+        }
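
For reviewers who want to see the adopted API end to end, here is a minimal usage sketch. It is a sketch only: the sample data is made up, direct instantiation outside an mmengine runner is assumed to work, and it relies on the `add()`, `compute()`, and `evaluate()` calls visible in the diff above rather than any documented contract.

```python
# Minimal sketch of driving the wrapped metric by hand (hypothetical data;
# relies on the add()/compute()/evaluate() calls shown in the diff above).
import numpy as np

from mmagic.evaluation.metrics.gradient_error import GradientError

metric = GradientError(prefix='matting')

# One fake sample: predicted alpha, ground-truth alpha, and a trimap in
# which the value 128 marks the unknown region the metric is computed on.
pred_alphas = [np.random.randint(0, 256, (64, 64), dtype=np.uint8)]
gt_alphas = [np.random.randint(0, 256, (64, 64), dtype=np.uint8)]
trimaps = [np.full((64, 64), 128, dtype=np.uint8)]

# process() is normally fed data_samples by the evaluation loop; add() is
# the mmeval entry point it forwards them to.
metric.add(pred_alphas, gt_alphas, trimaps)
print(metric.evaluate())  # e.g. {'matting/gradient_error': ...}
```

In a full mmagic setup the metric would instead be built through the registry from a config entry such as `dict(type='GradientError', prefix='matting')` inside `val_evaluator`; the exact config keys here are illustrative, not taken from this PR.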