video_feature_extractor.py
import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
import time

class VideoFeatureExtractor:
    def __init__(self):
        # Load a pretrained MobileNetV2 and the matching preprocessing pipeline.
        self.model = models.mobilenet_v2(pretrained=True)
        # Drop the final fully-connected layer so the model outputs
        # 1280-dim features instead of ImageNet class logits.
        self.model.classifier = nn.Sequential(*list(self.model.classifier.children())[:-1])
        # Inference mode: disables dropout in the remaining classifier head.
        self.model.eval()
        self.preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet mean/std, as expected by the pretrained weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    def get_features(self, video_path, start_time, end_time, N_sample_frames=5):
        # Sample N_sample_frames frames evenly from the [start_time, end_time]
        # window (in seconds) and average their MobileNetV2 features.
        cap = cv2.VideoCapture(video_path)
        total_frame_N = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        start_frame_id = int(start_time * fps)
        # Clamp to the last valid frame index.
        end_frame_id = min(int(end_time * fps), total_frame_N - 1)
        sampled_frames = []
        sample_frame_ids = np.linspace(start_frame_id, end_frame_id, N_sample_frames)
        sample_frame_ids = sample_frame_ids.astype(int)
        for i in sample_frame_ids:
            # Seek to the target frame first, then read it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV reads images in BGR order; convert to RGB for torchvision.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            sampled_frames.append(frame)
        cap.release()
        input_imgs_tensors = [self.preprocess(img) for img in sampled_frames]
        input_imgs_tensors = torch.stack(input_imgs_tensors)
        with torch.no_grad():
            out_features = self.model(input_imgs_tensors)  # (N_sample_frames, 1280)
        out_features = out_features.mean(dim=0)  # (1280,) clip-level feature
        return out_features
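
# --- Illustrative sketch (hypothetical helper, not part of the original) ---
# Clip-level features like these are commonly compared with cosine similarity,
# e.g. for near-duplicate clip detection. Paths and time window are up to the
# caller; nothing here is assumed by the class above.
def clip_cosine_similarity(extractor, path_a, path_b, start, end):
    feat_a = extractor.get_features(path_a, start, end)
    feat_b = extractor.get_features(path_b, start, end)
    return torch.nn.functional.cosine_similarity(feat_a, feat_b, dim=0).item()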

if __name__ == "__main__":
    video_feature_extractor = VideoFeatureExtractor()
    s_t = time.time()
    # start_time/end_time (in seconds) are required arguments; the 0-5 s
    # window here is an illustrative placeholder.
    features = video_feature_extractor.get_features(
        video_path='./data/index.mp4', start_time=0, end_time=5)
    print(features)
    print("features.shape, type(features):", features.shape, type(features))
    print("Feature extraction time for one video:", time.time() - s_t)
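
# Hedged note: on a CUDA-capable machine, extraction could be sped up by
# moving the model and the stacked frame batch to the GPU. A minimal sketch,
# assuming the class above is adapted accordingly:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   self.model.to(device)                               # in __init__
#   input_imgs_tensors = input_imgs_tensors.to(device)  # in get_features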