[Refactor] Support resizing pos_embed while loading ckpt and format output #1488

Merged · 2 commits · Apr 14, 2023
89 changes: 86 additions & 3 deletions mmpretrain/models/backbones/vit_sam.py
@@ -344,6 +344,14 @@ class ViTSAM(BaseBackbone):
            channel reduction layer is disabled. Defaults to 256.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        out_type (str): The type of output features. Please choose from

            - ``"raw"`` or ``"featmap"``: The feature map tensor from the
              patch tokens with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).

            Defaults to ``"raw"``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
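
For reference, a minimal usage sketch of the new `out_type` switch. The import path and the 1024×1024 input size are assumed from this repo's layout and the defaults in the diff, not shown here:

```python
import torch
from mmpretrain.models.backbones import ViTSAM  # import path assumed

# arch='base' with patch_size=16 on a 1024x1024 input gives a 64x64 patch grid.
model = ViTSAM(arch='base', img_size=1024, out_channels=256, out_type='avg_featmap')
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 1024, 1024))

# 'raw' / 'featmap' would yield (1, 256, 64, 64); 'avg_featmap' pools to (1, 256).
print(feats[0].shape)  # torch.Size([1, 256])
```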
@@ -392,6 +400,7 @@ class ViTSAM(BaseBackbone):
                'global_attn_indexes': [7, 15, 23, 31]
            }),
    }
    OUT_TYPES = {'raw', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch: str = 'base',
@@ -400,6 +409,7 @@ def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 256,
                 out_indices: int = -1,
                 out_type: str = 'raw',
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 qkv_bias: bool = True,
@@ -444,7 +454,12 @@ def __init__(self,
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        self.use_abs_pos = use_abs_pos
        self.interpolate_mode = interpolate_mode
@@ -453,6 +468,11 @@ def __init__(self,
            self.pos_embed = nn.Parameter(
                torch.zeros(1, *self.patch_resolution, self.embed_dims))
            self.drop_after_pos = nn.Dropout(p=drop_rate)
            self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        if use_rel_pos:
            self._register_load_state_dict_pre_hook(
                self._prepare_relative_position)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
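
Both hooks lean on PyTorch's `_register_load_state_dict_pre_hook`, which fires before checkpoint tensors are copied into the module, so a hook may rewrite `state_dict` entries in place. A self-contained sketch of that mechanism; the toy module and its names are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Toy(nn.Module):
    def __init__(self, length=8):
        super().__init__()
        self.table = nn.Parameter(torch.zeros(length, 4))
        self._register_load_state_dict_pre_hook(self._resize_table)

    def _resize_table(self, state_dict, prefix, *args, **kwargs):
        # Runs before weight copying; rewrite the entry if shapes mismatch.
        key = prefix + 'table'
        if key in state_dict and state_dict[key].shape != self.table.shape:
            src = state_dict[key].t().unsqueeze(0)                    # (1, C, L1)
            dst = F.interpolate(src, size=self.table.shape[0], mode='linear')
            state_dict[key] = dst.squeeze(0).t()                      # (L2, C)

ckpt = {'table': torch.randn(5, 4)}   # mismatched length
model = Toy(length=8)
model.load_state_dict(ckpt)           # hook resizes 5 -> 8 before copying
```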
@@ -565,8 +585,71 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
            x = layer(x)

            if i in self.out_indices:
                # (B, H, W, C) -> (B, C, H, W)
                x = x.permute(0, 3, 1, 2)

                if self.out_channels > 0:
                    x = self.channel_reduction(x)
                outs.append(self._format_output(x))

        return tuple(outs)

    def _format_output(self, x) -> torch.Tensor:
        if self.out_type == 'raw' or self.out_type == 'featmap':
            return x
        elif self.out_type == 'avg_featmap':
            # (B, C, H, W) -> (B, C, N) -> (B, N, C)
            x = x.flatten(2).permute(0, 2, 1)
            return x.mean(dim=1)
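
The `'avg_featmap'` branch is ordinary global average pooling over the spatial grid. A quick sanity check against `F.adaptive_avg_pool2d`, with arbitrarily chosen shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 64, 64)
avg = x.flatten(2).permute(0, 2, 1).mean(dim=1)  # as in _format_output: (B, C)
ref = F.adaptive_avg_pool2d(x, 1).flatten(1)     # reference pooling:    (B, C)
assert torch.allclose(avg, ref, atol=1e-5)
```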

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3]
            pos_embed_shape = self.patch_embed.init_out_size

            flattened_pos_embed = state_dict[name].flatten(1, 2)
            resized_pos_embed = resize_pos_embed(flattened_pos_embed,
                                                 ckpt_pos_embed_shape,
                                                 pos_embed_shape,
                                                 self.interpolate_mode, 0)
            state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape,
                                                      self.embed_dims)
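
A shape walkthrough of the hook above. ViTSAM stores `pos_embed` as `(1, H, W, C)` while `resize_pos_embed` works on flattened `(1, L, C)` tokens, hence the `flatten(1, 2)` / `view` round trip; the final positional argument is `num_extra_tokens=0` because ViT-SAM has no class token. The sizes and the import path below are assumptions for illustration:

```python
import torch
from mmpretrain.models.utils import resize_pos_embed  # import path assumed

# Checkpoint trained at a 64x64 patch grid, model built for 32x32, embed_dims=768.
ckpt_pos = torch.randn(1, 64, 64, 768)   # (1, H, W, C) as stored by ViTSAM
flat = ckpt_pos.flatten(1, 2)            # (1, 4096, 768), layout resize_pos_embed expects
resized = resize_pos_embed(flat, (64, 64), (32, 32), 'bicubic', num_extra_tokens=0)
print(resized.view(1, 32, 32, 768).shape)  # torch.Size([1, 32, 32, 768])
```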

    def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs):
        state_dict_model = self.state_dict()
        all_keys = list(state_dict_model.keys())
        for key in all_keys:
            if 'rel_pos_' in key:
                ckpt_key = prefix + key
                if ckpt_key not in state_dict:
                    continue
                relative_position_pretrained = state_dict[ckpt_key]
                relative_position_current = state_dict_model[key]
                L1, _ = relative_position_pretrained.size()
                L2, _ = relative_position_current.size()
                if L1 != L2:
                    new_rel_pos = F.interpolate(
                        relative_position_pretrained.reshape(
                            1, L1, -1).permute(0, 2, 1),
                        size=L2,
                        mode='linear',
                    )
                    new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0)
                    from mmengine.logging import MMLogger
                    logger = MMLogger.get_current_instance()
                    logger.info(f'Resize the {ckpt_key} from '
                                f'{state_dict[ckpt_key].shape} to '
                                f'{new_rel_pos.shape}')
                    state_dict[ckpt_key] = new_rel_pos
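
The `rel_pos_*` tables have shape `(L, head_dim)`, so the hook treats the length axis as a 1-D signal and resizes it with linear interpolation. A standalone sketch of the same reshuffle, with hypothetical sizes (e.g. a 64×64 global grid gives L1 = 2·64−1 = 127, a 32×32 grid gives L2 = 63):

```python
import torch
import torch.nn.functional as F

pretrained = torch.randn(127, 64)                     # (L1, head_dim)
L2 = 63                                               # target length

resized = F.interpolate(
    pretrained.reshape(1, 127, -1).permute(0, 2, 1),  # (1, head_dim, L1)
    size=L2,
    mode='linear',
).reshape(-1, L2).permute(1, 0)                       # back to (L2, head_dim)

print(resized.shape)  # torch.Size([63, 64])
```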