From 9d1a3218e3949fd55b379b0f08c81930b4dfef05 Mon Sep 17 00:00:00 2001 From: Ajinkya Indulkar <26824103+AjinkyaIndulkar@users.noreply.github.com> Date: Sun, 12 Jun 2022 14:36:27 +0200 Subject: [PATCH] Verify camera ID when running dot in camera mode (#18) * :sparkles: add cam_exists * :white_check_mark: fix unit test * :recycle: refactor cam_exists() logic * :bug: fix pre-commit errors --- dot/__main__.py | 4 +++- dot/commons/camera_utils.py | 23 +++++++++++++++++++++-- dot/commons/model_option.py | 4 +++- tests/pipeline_test.py | 2 +- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/dot/__main__.py b/dot/__main__.py index c9fea9b..b32ac0b 100644 --- a/dot/__main__.py +++ b/dot/__main__.py @@ -121,7 +121,9 @@ def main( ): """CLI entrypoint for dot.""" # initialize dot - _dot = DOT(use_video=use_video, use_image=use_image, save_folder=save_folder) + _dot = DOT( + use_video=use_video, use_image=use_image, save_folder=save_folder, target=target + ) # build dot option = _dot.build_option( diff --git a/dot/commons/camera_utils.py b/dot/commons/camera_utils.py index 830ea92..e154423 100644 --- a/dot/commons/camera_utils.py +++ b/dot/commons/camera_utils.py @@ -14,7 +14,26 @@ from .video.videocaptureasync import VideoCaptureAsync +def fetch_camera(target: int) -> VideoCaptureAsync: + """Fetches a VideoCaptureAsync object. + + Args: + target (int): Camera ID descriptor. + + Raises: + ValueError: If camera ID descriptor is not valid. + + Returns: + VideoCaptureAsync: VideoCaptureAsync object. + """ + try: + return VideoCaptureAsync(target) + except RuntimeError: + raise ValueError(f"Camera {target} does not exist.") + + def camera_pipeline( + cap: VideoCaptureAsync, source: str, target: int, change_option: Callable[[np.ndarray], None], @@ -22,11 +41,12 @@ def camera_pipeline( post_process_image: Callable[[np.ndarray], np.ndarray], crop_size: int = 224, show_fps: bool = False, - **kwargs: Dict + **kwargs: Dict, ) -> None: """Open a webcam stream `target` and performs face-swap based on `source` image by frame. Args: + cap (VideoCaptureAsync): VideoCaptureAsync object. source (str): Path to source image folder. target (int): Camera ID descriptor. change_option (Callable[[np.ndarray], None]): Set `source` arg as faceswap source image. @@ -51,7 +71,6 @@ def camera_pipeline( img_a_align_crop = process_image(img_a_whole) img_a_align_crop = post_process_image(img_a_align_crop) - cap = VideoCaptureAsync(target) cap.start() ret, frame = cap.read() cv2.namedWindow("cam", cv2.WINDOW_GUI_NORMAL) diff --git a/dot/commons/model_option.py b/dot/commons/model_option.py index ce5807b..569f6bb 100644 --- a/dot/commons/model_option.py +++ b/dot/commons/model_option.py @@ -12,7 +12,7 @@ import torch from ..gpen.face_enhancement import FaceEnhancement -from .camera_utils import camera_pipeline +from .camera_utils import camera_pipeline, fetch_camera from .utils import find_images_from_path, generate_random_file_idx, rand_idx_tuple from .video.video_utils import video_pipeline @@ -180,8 +180,10 @@ def generate_from_camera( show_fps (bool, optional): Show FPS. Defaults to False. """ with torch.no_grad(): + cap = fetch_camera(target) self.create_model(opt_crop_size=opt_crop_size, **kwargs) camera_pipeline( + cap, source, target, self.change_option, diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index a8695d2..de578ac 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -17,7 +17,7 @@ def fake_generate(self, option, source, target, show_fps=False, **kwargs): @mock.patch.object(DOT, "generate", fake_generate) class TestDotOptions(unittest.TestCase): def setUp(self): - self._dot = DOT(use_cam=False, save_folder="./tests") + self._dot = DOT(use_image=True, save_folder="./tests") self.faceswap_cv2_option = self._dot.faceswap_cv2(False, False, None)