diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py
index 284e62e95..1d6b7f4ba 100644
--- a/mmselfsup/models/utils/__init__.py
+++ b/mmselfsup/models/utils/__init__.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .clip import build_clip_model
-from .dall_e import Encoder
 from .data_preprocessor import (CAEDataPreprocessor,
                                 RelativeLocDataPreprocessor,
                                 RotationPredDataPreprocessor,
@@ -27,8 +26,8 @@
 __all__ = [
     'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
     'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
-    'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
-    'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
+    'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA',
+    'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
     'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm',
     'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor',
     'PromptTransformerEncoderLayer', 'build_clip_model'
diff --git a/mmselfsup/models/utils/dall_e.py b/mmselfsup/models/utils/dall_e.py
deleted file mode 100644
index a026d8071..000000000
--- a/mmselfsup/models/utils/dall_e.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright (c)
-# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
-# Copied from BEiT
-import math
-from collections import OrderedDict
-from functools import partial
-
-import attr
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-@attr.s(eq=False)
-class Conv2d(nn.Module):
-    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
-    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
-    kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
-
-    use_float16: bool = attr.ib(default=True)
-    device: torch.device = attr.ib(default=torch.device('cpu'))
-    requires_grad: bool = attr.ib(default=False)
-
-    def __attrs_post_init__(self) -> None:
-        super().__init__()
-
-        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw),
-                        dtype=torch.float32,
-                        device=self.device,
-                        requires_grad=self.requires_grad)
-        w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2))
-
-        b = torch.zeros((self.n_out, ),
-                        dtype=torch.float32,
-                        device=self.device,
-                        requires_grad=self.requires_grad)
-        self.w, self.b = nn.Parameter(w), nn.Parameter(b)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.use_float16 and 'cuda' in self.w.device.type:
-            if x.dtype != torch.float16:
-                x = x.half()
-
-            w, b = self.w.half(), self.b.half()
-        else:
-            if x.dtype != torch.float32:
-                x = x.float()
-
-            w, b = self.w, self.b
-
-        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
-
-
-@attr.s(eq=False, repr=False)
-class EncoderBlock(nn.Module):
-    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
-    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
-    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
-
-    device: torch.device = attr.ib(default=None)
-    requires_grad: bool = attr.ib(default=False)
-
-    def __attrs_post_init__(self) -> None:
-        super().__init__()
-        self.n_hid = self.n_out // 4
-        self.post_gain = 1 / (self.n_layers**2)
-
-        make_conv = partial(
-            Conv2d, device=self.device, requires_grad=self.requires_grad)
-        self.id_path = make_conv(
-            self.n_in, self.n_out,
-            1) if self.n_in != self.n_out else nn.Identity()
-        self.res_path = nn.Sequential(
-            OrderedDict([
-                ('relu_1', nn.ReLU()),
-                ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
-                ('relu_2', nn.ReLU()),
-                ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
-                ('relu_3', nn.ReLU()),
-                ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
-                ('relu_4', nn.ReLU()),
-                ('conv_4', make_conv(self.n_hid, self.n_out, 1)),
-            ]))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.id_path(x) + self.post_gain * self.res_path(x)
-
-
-@attr.s(eq=False, repr=False)
-class Encoder(nn.Module):
-    group_count: int = 4
-    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
-    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
-    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
-    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
-
-    device: torch.device = attr.ib(default=torch.device('cpu'))
-    requires_grad: bool = attr.ib(default=False)
-    use_mixed_precision: bool = attr.ib(default=True)
-
-    def __attrs_post_init__(self) -> None:
-        super().__init__()
-
-        blk_range = range(self.n_blk_per_group)
-        n_layers = self.group_count * self.n_blk_per_group
-        make_conv = partial(
-            Conv2d, device=self.device, requires_grad=self.requires_grad)
-        make_blk = partial(
-            EncoderBlock,
-            n_layers=n_layers,
-            device=self.device,
-            requires_grad=self.requires_grad)
-
-        self.blocks = nn.Sequential(
-            OrderedDict([
-                ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
-                ('group_1',
-                 nn.Sequential(
-                     OrderedDict([
-                         *[(f'block_{i + 1}',
-                            make_blk(1 * self.n_hid, 1 * self.n_hid))
-                           for i in blk_range],
-                         ('pool', nn.MaxPool2d(kernel_size=2)),
-                     ]))),
-                ('group_2',
-                 nn.Sequential(
-                     OrderedDict([
-                         *[(f'block_{i + 1}',
-                            make_blk(
-                                1 * self.n_hid if i == 0 else 2 * self.n_hid,
-                                2 * self.n_hid)) for i in blk_range],
-                         ('pool', nn.MaxPool2d(kernel_size=2)),
-                     ]))),
-                ('group_3',
-                 nn.Sequential(
-                     OrderedDict([
-                         *[(f'block_{i + 1}',
-                            make_blk(
-                                2 * self.n_hid if i == 0 else 4 * self.n_hid,
-                                4 * self.n_hid)) for i in blk_range],
-                         ('pool', nn.MaxPool2d(kernel_size=2)),
-                     ]))),
-                ('group_4',
-                 nn.Sequential(
-                     OrderedDict([
-                         *[(f'block_{i + 1}',
-                            make_blk(
-                                4 * self.n_hid if i == 0 else 8 * self.n_hid,
-                                8 * self.n_hid)) for i in blk_range],
-                     ]))),
-                ('output',
-                 nn.Sequential(
-                     OrderedDict([
-                         ('relu', nn.ReLU()),
-                         ('conv',
-                          make_conv(
-                              8 * self.n_hid,
-                              self.vocab_size,
-                              1,
-                              use_float16=False)),
-                     ]))),
-            ]))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.float()
-        if len(x.shape) != 4:
-            raise ValueError(f'input shape {x.shape} is not 4d')
-        if x.shape[1] != self.input_channels:
-            raise ValueError(f'input has {x.shape[1]} channels but model \
-                built for {self.input_channels}')
-        if x.dtype != torch.float32:
-            raise ValueError('input must have dtype torch.float32')
-
-        return self.blocks(x)
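
For reference, a minimal sketch of how the removed `Encoder` was typically consumed: as a frozen visual tokenizer producing discrete token targets (the role it plays for CAE/BEiT-style pretraining). This is illustrative only, not part of the change: the local module name `dall_e_encoder` and the weight file `dalle_encoder.pth` are hypothetical placeholders, and code that previously did `from mmselfsup.models.utils import Encoder` now has to obtain the class elsewhere, e.g. the upstream BEiT copy linked in the deleted file's header.

```python
import torch

# Hypothetical local copy of the class removed in this diff (originally from
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py).
from dall_e_encoder import Encoder

# Build the encoder with its defaults (n_hid=256, vocab_size=8192) and keep it frozen.
encoder = Encoder(n_hid=256, vocab_size=8192, requires_grad=False)
# Placeholder path to pretrained DALL-E dVAE encoder weights.
encoder.load_state_dict(torch.load('dalle_encoder.pth', map_location='cpu'))
encoder.eval()

# forward() expects a 4-D float32 NCHW tensor with `input_channels` channels.
images = torch.rand(2, 3, 112, 112)
with torch.no_grad():
    logits = encoder(images)       # (2, 8192, 14, 14): per-position vocab logits (3x MaxPool2d -> /8)
    tokens = logits.argmax(dim=1)  # (2, 14, 14): discrete visual-token targets
```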