Skip to content

Commit

Permalink
timm
Browse files Browse the repository at this point in the history
  • Loading branch information
awkrail committed Oct 22, 2024
1 parent 4a35ad7 commit d8c009c
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions firefly/vision/timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,30 @@
from firefly.frame_extractor.frame import VideoFrame
from firefly.model_config import _available_timm_models

import torchvision.models as models

class TimmEncoder(BaseVisionEncoder):
def __init__(
        self,
        device: str,
        model_path: str,
        feature_map: bool = False,
        preprocess_transforms: Optional[Compose] = None) -> None:
    """Load a pretrained timm model and prepare it for inference.

    Args:
        device: Device string the model is moved to (passed to ``.to()``).
        model_path: timm model name; must be one of the names returned by
            ``_available_timm_models()``.
        feature_map: Stored as ``self._use_feature_map``. The model itself
            is built the same way regardless of this flag.
        preprocess_transforms: Optional transform pipeline, stored as-is
            in ``self._transforms``.

    Raises:
        ValueError: If ``model_path`` is not in the available model list.
    """
    self._model_path: str = model_path
    self._device: str = device
    self._use_feature_map: bool = feature_map
    self._available_models = _available_timm_models()
    if model_path not in self._available_models:
        raise ValueError(f'{model_path} are not in {self._available_models}.')

    # Build the model exactly once. Previously a first model was created
    # (and optionally wrapped in nn.Sequential) only to be overwritten by
    # this call, loading the pretrained weights twice for no effect.
    # NOTE(review): per timm's docs, num_classes=0 removes the classifier
    # head so that calling the model returns pooled features -- confirm
    # against the installed timm version.
    self._model = (
        timm.create_model(model_path, pretrained=True, num_classes=0)
        .eval()  # inference mode; no second .eval() call is needed
        .to(self._device)
    )
    self._transforms = preprocess_transforms


@torch.no_grad()
def encode_video(
self,
Expand All @@ -37,7 +46,11 @@ def encode_video(
st_idx = i * batch_size
ed_idx = (i+1) * batch_size
_frames = preprocessed_frames[st_idx:ed_idx].to(self._device)
_video_features = self._model.forward_features(_frames)
if self._use_feature_map:
_video_features = self._model.forward_features(_frames)
else:
_video_features = self._model(_frames)
import ipdb; ipdb.set_trace()
video_features.append(_video_features)
video_feature_tensor = torch.cat(video_features, dim=0)
return video_feature_tensor

0 comments on commit d8c009c

Please sign in to comment.