diff --git a/lighthouse/feature_extractor/vision_encoders/slowfast.py b/lighthouse/feature_extractor/vision_encoders/slowfast.py index d1edda6..1949617 100644 --- a/lighthouse/feature_extractor/vision_encoders/slowfast.py +++ b/lighthouse/feature_extractor/vision_encoders/slowfast.py @@ -93,8 +93,8 @@ def __call__( slowfast_frames: torch.Tensor, bsz: int = 45): n_chunk = len(slowfast_frames) - features = torch.HalfTensor(n_chunk, self.SLOWFAST_FEATURE_DIM, - device=self._device).fill_(0) + features = torch.zeros([n_chunk, self.SLOWFAST_FEATURE_DIM], + device=self._device, dtype=torch.float16) n_batch = int(math.ceil(n_chunk / bsz)) for i in range(n_batch): st_idx = i * bsz @@ -106,5 +106,4 @@ def __call__( inputs = self._pack_pathway_output(fast_clip) batch_features = self._slowfast_extractor(inputs) features[st_idx:ed_idx] = batch_features.half() - slowfast_features = features.cpu() - return slowfast_features \ No newline at end of file + return features \ No newline at end of file