diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8d19186ed..ab53ce146 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,7 +34,6 @@ repos:
     rev: 0.7.9
     hooks:
       - id: mdformat
-        language_version: python3.7
         args: ["--number", "--table-width", "200"]
         additional_dependencies:
           - mdformat-openmmlab
diff --git a/mmgen/models/architectures/ddpm/modules.py b/mmgen/models/architectures/ddpm/modules.py
index ba876fd1a..39a4b4a49 100644
--- a/mmgen/models/architectures/ddpm/modules.py
+++ b/mmgen/models/architectures/ddpm/modules.py
@@ -32,36 +32,39 @@ def forward(self, x, y):
         return x
 
 
-@ACTIVATION_LAYERS.register_module()
-class SiLU(nn.Module):
-    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
-    The SiLU function is also known as the swish function.
-    Args:
-        input (bool, optional): Use inplace operation or not.
-            Defaults to `False`.
-    """
+if 'SiLU' not in ACTIVATION_LAYERS:
 
-    def __init__(self, inplace=False):
-        super().__init__()
-        if digit_version(
-                torch.__version__) < digit_version('1.7.0') and inplace:
-            mmcv.print_log('Inplace version of \'SiLU\' is not supported for '
-                           f'torch < 1.7.0, found \'{torch.version}\'.')
-        self.inplace = inplace
-
-    def forward(self, x):
-        """Forward function for SiLU.
+    @ACTIVATION_LAYERS.register_module()
+    class SiLU(nn.Module):
+        r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+        The SiLU function is also known as the swish function.
         Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Tensor after activation.
+            input (bool, optional): Use inplace operation or not.
+                Defaults to `False`.
         """
 
-        if digit_version(torch.__version__) < digit_version('1.7.0'):
-            return x * torch.sigmoid(x)
+        def __init__(self, inplace=False):
+            super().__init__()
+            if digit_version(
+                    torch.__version__) < digit_version('1.7.0') and inplace:
+                mmcv.print_log('Inplace version of \'SiLU\' is not supported '
+                               'for torch < 1.7.0, found '
+                               f'\'{torch.version}\'.')
+            self.inplace = inplace
+
+        def forward(self, x):
+            """Forward function for SiLU.
+            Args:
+                x (torch.Tensor): Input tensor.
+
+            Returns:
+                torch.Tensor: Tensor after activation.
+            """
+
+            if digit_version(torch.__version__) < digit_version('1.7.0'):
+                return x * torch.sigmoid(x)
 
-        return F.silu(x, inplace=self.inplace)
+            return F.silu(x, inplace=self.inplace)
 
 
 @MODULES.register_module()