From d8c009c60dd97163df87ac0d248cab05de53436f Mon Sep 17 00:00:00 2001
From: awkrail
Date: Tue, 22 Oct 2024 15:27:10 +0900
Subject: [PATCH] timm: add optional feature-map output to TimmEncoder

Add a `feature_map` flag to TimmEncoder (default False, so existing
callers are unaffected). When True, encode_video returns the spatial
feature map via `forward_features`; otherwise it returns the pooled
features from the headless (`num_classes=0`) model, as before.

The model is still created with `num_classes=0` — timm's supported way
to get pooled features — rather than truncating `children()[:-1]`,
which silently drops pooling/flatten on many architectures.
---
 firefly/vision/timm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/firefly/vision/timm.py b/firefly/vision/timm.py
index 7c0897d..c0abaab 100644
--- a/firefly/vision/timm.py
+++ b/firefly/vision/timm.py
@@ -9,21 +9,23 @@ from firefly.frame_extractor.frame import VideoFrame
 from firefly.model_config import _available_timm_models
 
 class TimmEncoder(BaseVisionEncoder):
     def __init__(
         self,
         device: str,
         model_path: str,
+        feature_map: bool = False,
         preprocess_transforms: Optional[Compose] = None):
         self._model_path: str = model_path
         self._device: str = device
+        self._use_feature_map: bool = feature_map
         self._available_models = _available_timm_models()
 
         if model_path not in self._available_models:
             raise ValueError(f'{model_path} are not in {self._available_models}.')
 
         self._model = timm.create_model(model_path, pretrained=True, num_classes=0).eval().to(self._device)
         self._transforms = preprocess_transforms
 
     @torch.no_grad()
     def encode_video(
         self,
@@ -37,7 +39,10 @@
         st_idx = i * batch_size
         ed_idx = (i+1) * batch_size
         _frames = preprocessed_frames[st_idx:ed_idx].to(self._device)
-        _video_features = self._model.forward_features(_frames)
+        if self._use_feature_map:
+            _video_features = self._model.forward_features(_frames)
+        else:
+            _video_features = self._model(_frames)
         video_features.append(_video_features)
         video_feature_tensor = torch.cat(video_features, dim=0)
         return video_feature_tensor
\ No newline at end of file