From 843475d936607b001378e21e322939820533087b Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 11 Oct 2022 11:31:57 +0800
Subject: [PATCH 1/2] fix SiLU activation

---
 .pre-commit-config.yaml                    |  1 -
 mmgen/models/architectures/ddpm/modules.py | 53 ++++++++++++----------
 2 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8d19186ed..ab53ce146 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,7 +34,6 @@ repos:
     rev: 0.7.9
     hooks:
       - id: mdformat
-        language_version: python3.7
         args: ["--number", "--table-width", "200"]
         additional_dependencies:
           - mdformat-openmmlab
diff --git a/mmgen/models/architectures/ddpm/modules.py b/mmgen/models/architectures/ddpm/modules.py
index ba876fd1a..3fe56f6b3 100644
--- a/mmgen/models/architectures/ddpm/modules.py
+++ b/mmgen/models/architectures/ddpm/modules.py
@@ -32,36 +32,39 @@ def forward(self, x, y):
         return x
 
 
-@ACTIVATION_LAYERS.register_module()
-class SiLU(nn.Module):
-    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
-    The SiLU function is also known as the swish function.
-    Args:
-        input (bool, optional): Use inplace operation or not.
-            Defaults to `False`.
-    """
+if digit_version(mmcv.__version__) < digit_version('1.6.2'):
 
-    def __init__(self, inplace=False):
-        super().__init__()
-        if digit_version(
-                torch.__version__) < digit_version('1.7.0') and inplace:
-            mmcv.print_log('Inplace version of \'SiLU\' is not supported for '
-                           f'torch < 1.7.0, found \'{torch.version}\'.')
-        self.inplace = inplace
-
-    def forward(self, x):
-        """Forward function for SiLU.
+    @ACTIVATION_LAYERS.register_module()
+    class SiLU(nn.Module):
+        r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+        The SiLU function is also known as the swish function.
         Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Tensor after activation.
+            input (bool, optional): Use inplace operation or not.
+                Defaults to `False`.
         """
 
-        if digit_version(torch.__version__) < digit_version('1.7.0'):
-            return x * torch.sigmoid(x)
+        def __init__(self, inplace=False):
+            super().__init__()
+            if digit_version(
+                    torch.__version__) < digit_version('1.7.0') and inplace:
+                mmcv.print_log('Inplace version of \'SiLU\' is not supported '
+                               'for torch < 1.7.0, found '
+                               f'\'{torch.version}\'.')
+            self.inplace = inplace
+
+        def forward(self, x):
+            """Forward function for SiLU.
+            Args:
+                x (torch.Tensor): Input tensor.
+
+            Returns:
+                torch.Tensor: Tensor after activation.
+            """
+
+            if digit_version(torch.__version__) < digit_version('1.7.0'):
+                return x * torch.sigmoid(x)
 
-        return F.silu(x, inplace=self.inplace)
+            return F.silu(x, inplace=self.inplace)
 
 
 @MODULES.register_module()

From 433858bf9b8c0b410c1d442b15e1745012500838 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Tue, 11 Oct 2022 12:03:01 +0800
Subject: [PATCH 2/2] revise registry condition of SiLU

---
 mmgen/models/architectures/ddpm/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmgen/models/architectures/ddpm/modules.py b/mmgen/models/architectures/ddpm/modules.py
index 3fe56f6b3..39a4b4a49 100644
--- a/mmgen/models/architectures/ddpm/modules.py
+++ b/mmgen/models/architectures/ddpm/modules.py
@@ -32,7 +32,7 @@ def forward(self, x, y):
         return x
 
 
-if digit_version(mmcv.__version__) < digit_version('1.6.2'):
+if 'SiLU' not in ACTIVATION_LAYERS:
 
     @ACTIVATION_LAYERS.register_module()
     class SiLU(nn.Module):
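
For readers outside the mmgen codebase, below is a minimal, self-contained sketch of the guarded-registration pattern that PATCH 2/2 adopts. The `Registry` class here is a toy stand-in for mmcv's registry, not the mmcv implementation; the names `ACTIVATION_LAYERS` and `SiLU` mirror the patch, and the usage at the bottom is illustrative only.

# Toy illustration of "register only if not already registered".
import torch
import torch.nn as nn
import torch.nn.functional as F


class Registry:
    """Tiny name -> class registry that refuses duplicate registrations."""

    def __init__(self):
        self._modules = {}

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._modules:
                # This is the failure the guard avoids: if a 'SiLU' entry
                # already exists (as it does in newer mmcv), registering a
                # second one raises instead of silently overwriting it.
                raise KeyError(f'{cls.__name__} is already registered')
            self._modules[cls.__name__] = cls
            return cls

        return _register

    def __contains__(self, name):
        return name in self._modules

    def build(self, cfg):
        cfg = dict(cfg)
        return self._modules[cfg.pop('type')](**cfg)


ACTIVATION_LAYERS = Registry()

# The membership check from PATCH 2/2: only register a fallback SiLU when the
# registry does not already provide one, whatever version put it there.
if 'SiLU' not in ACTIVATION_LAYERS:

    @ACTIVATION_LAYERS.register_module()
    class SiLU(nn.Module):

        def __init__(self, inplace=False):
            super().__init__()
            self.inplace = inplace

        def forward(self, x):
            return F.silu(x, inplace=self.inplace)


if __name__ == '__main__':
    # Build the activation from a config dict and run it on a small tensor.
    act = ACTIVATION_LAYERS.build(dict(type='SiLU', inplace=False))
    print(act(torch.tensor([-1.0, 0.0, 1.0])))

Compared with the mmcv version comparison introduced in PATCH 1/2, the membership test keys off the actual state of the registry, so the fallback stays correct even if a different mmcv release starts or stops shipping its own SiLU.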