Skip to content

Commit

Permalink
Verify camera ID when running dot in camera mode (#18)
Browse files Browse the repository at this point in the history
* ✨ add cam_exists

* ✅ fix unit test

* ♻️ refactor cam_exists() logic

* 🐛 fix pre-commit errors
  • Loading branch information
ajndkr authored Jun 12, 2022
1 parent e3f4b40 commit 9d1a321
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 5 deletions.
4 changes: 3 additions & 1 deletion dot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def main(
):
"""CLI entrypoint for dot."""
# initialize dot
_dot = DOT(use_video=use_video, use_image=use_image, save_folder=save_folder)
_dot = DOT(
use_video=use_video, use_image=use_image, save_folder=save_folder, target=target
)

# build dot
option = _dot.build_option(
Expand Down
23 changes: 21 additions & 2 deletions dot/commons/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,39 @@
from .video.videocaptureasync import VideoCaptureAsync


def fetch_camera(target: int) -> VideoCaptureAsync:
"""Fetches a VideoCaptureAsync object.
Args:
target (int): Camera ID descriptor.
Raises:
ValueError: If camera ID descriptor is not valid.
Returns:
VideoCaptureAsync: VideoCaptureAsync object.
"""
try:
return VideoCaptureAsync(target)
except RuntimeError:
raise ValueError(f"Camera {target} does not exist.")


def camera_pipeline(
cap: VideoCaptureAsync,
source: str,
target: int,
change_option: Callable[[np.ndarray], None],
process_image: Callable[[np.ndarray], np.ndarray],
post_process_image: Callable[[np.ndarray], np.ndarray],
crop_size: int = 224,
show_fps: bool = False,
**kwargs: Dict
**kwargs: Dict,
) -> None:
"""Open a webcam stream `target` and performs face-swap based on `source` image by frame.
Args:
cap (VideoCaptureAsync): VideoCaptureAsync object.
source (str): Path to source image folder.
target (int): Camera ID descriptor.
change_option (Callable[[np.ndarray], None]): Set `source` arg as faceswap source image.
Expand All @@ -51,7 +71,6 @@ def camera_pipeline(
img_a_align_crop = process_image(img_a_whole)
img_a_align_crop = post_process_image(img_a_align_crop)

cap = VideoCaptureAsync(target)
cap.start()
ret, frame = cap.read()
cv2.namedWindow("cam", cv2.WINDOW_GUI_NORMAL)
Expand Down
4 changes: 3 additions & 1 deletion dot/commons/model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch

from ..gpen.face_enhancement import FaceEnhancement
from .camera_utils import camera_pipeline
from .camera_utils import camera_pipeline, fetch_camera
from .utils import find_images_from_path, generate_random_file_idx, rand_idx_tuple
from .video.video_utils import video_pipeline

Expand Down Expand Up @@ -180,8 +180,10 @@ def generate_from_camera(
show_fps (bool, optional): Show FPS. Defaults to False.
"""
with torch.no_grad():
cap = fetch_camera(target)
self.create_model(opt_crop_size=opt_crop_size, **kwargs)
camera_pipeline(
cap,
source,
target,
self.change_option,
Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def fake_generate(self, option, source, target, show_fps=False, **kwargs):
@mock.patch.object(DOT, "generate", fake_generate)
class TestDotOptions(unittest.TestCase):
def setUp(self):
self._dot = DOT(use_cam=False, save_folder="./tests")
self._dot = DOT(use_image=True, save_folder="./tests")

self.faceswap_cv2_option = self._dot.faceswap_cv2(False, False, None)

Expand Down

0 comments on commit 9d1a321

Please # to comment.