Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

improve get_crop_region #14709

Merged
merged 1 commit into from
Jan 21, 2024
Merged

improve get_crop_region #14709

merged 1 commit into from
Jan 21, 2024

Conversation

w-e-w
Copy link
Collaborator

@w-e-w w-e-w commented Jan 20, 2024

Description

simplify and improve get_crop_region

test code

performance comparison code

from PIL import Image, ImageDraw
import numpy as np
import timeit


def get_crop_region(mask, pad=0):
    """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""

    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def current(in_image, pad=0):
    return get_crop_region(np.array(in_image), pad)


def new_method(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    box = mask_img.getbbox()
    if box:
        x1, y1, x2, y2 = box
    else:  # when no box is found
        x1, y1 = mask_img.size
        x2 = y2 = 0
    return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])

def new_method_with_np_input(mask, pad=0):
    # simulate if input is numpy array
    return new_method(np.array(mask), pad)


if __name__ == '__main__':
    img = Image.new('L', (1000, 1000), color='black')
    np_mask = np.array(img)

    padding = 10

    print(get_crop_region(np_mask, padding))            # (990, 990, 10, 10)
    print(current(img, padding))                        # (990, 990, 10, 10)
    print(new_method(img, padding))                     # (990, 990, 10, 10)
    print(new_method_with_np_input(img, padding))       # (990, 990, 10, 10)

    draw = ImageDraw.Draw(img)
    draw.ellipse((300, 400, 500, 600), fill='white')
    np_mask = np.array(img)

    print(get_crop_region(np_mask, padding))            # (290, 390, 511, 611)
    print(current(img, padding))                        # (290, 390, 511, 611)
    print(new_method(img, padding))                     # (290, 390, 511, 611)
    print(new_method_with_np_input(img, padding))       # (290, 390, 511, 611)

    iterations = 1000

    timeit_0 = timeit.timeit(lambda: get_crop_region(np_mask, padding), number=iterations)
    timeit_1 = timeit.timeit(lambda: current(img, padding), number=iterations)
    timeit_2 = timeit.timeit(lambda: new_method(img, padding), number=iterations)
    timeit_3 = timeit.timeit(lambda: new_method_with_np_input(img, padding), number=iterations)
    print(f"method_0 took {timeit_0:.6f} seconds for {iterations} iterations.")
    print(f"method_1 took {timeit_1:.6f} seconds for {iterations} iterations.")
    print(f"method_2 took {timeit_2:.6f} seconds for {iterations} iterations.")
    print(f"method_3 took {timeit_3:.6f} seconds for {iterations} iterations.")
    # method_0 took 3.534935 seconds for 1000 iterations.
    # method_1 took 4.385469 seconds for 1000 iterations.
    # method_2 took 0.224629 seconds for 1000 iterations.
    # method_3 took 0.967969 seconds for 1000 iterations.

conclusion
all round faster even in compatibility mode

Checklist:

@w-e-w w-e-w marked this pull request as draft January 20, 2024 22:45
@w-e-w w-e-w marked this pull request as ready for review January 20, 2024 23:05
@w-e-w w-e-w force-pushed the improve-get_crop_region branch 5 times, most recently from 3ed1149 to cc1e7e7 Compare January 20, 2024 23:30
@w-e-w
Copy link
Collaborator Author

w-e-w commented Jan 20, 2024

it can be compacted down even more at the cost of readability

def get_crop_region(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    box = mask_img.getbbox()
    return max((box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), max(box[3] + pad, mask_img.size[1])) if box else (min(mask_img.size[0] - pad, 0), min(mask_img.size[1] - pad, 0), max(0 + pad, mask_img.size[0]), max(0 + pad, mask_img.size[1])) 
def stupid(mask, pad=0):
    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
    return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := mask_img.getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
def shouldnt_ever_exist(mask, pad=0):
    return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
def just_why(mask, pad=0): return (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
lambda_version = lambda mask, pad=0: (max(box[0] - pad, 0), max(box[1] - pad, 0), min(box[2] + pad, mask_img.size[0]), min(box[3] + pad, mask_img.size[1])) if (box := (mask_img := mask if isinstance(mask, Image.Image) else Image.fromarray(mask)).getbbox()) else (max(mask_img.size[0] - pad, 0), max(mask_img.size[1] - pad, 0), min(pad, mask_img.size[0]), min(pad, mask_img.size[1]))
c = lambda m, p=0: (max(b[0] - p, 0), max(b[1] - p, 0), min(b[2] + p, i.size[0]), min(b[3] + p, i.size[1])) if (b := (i := m if isinstance(m, Image.Image) else Image.fromarray(m)).getbbox()) else (max(i.size[0] - p, 0), max(i.size[1] - p, 0), min(p, i.size[0]), min(p, i.size[1]))

@w-e-w w-e-w marked this pull request as draft January 20, 2024 23:57
@w-e-w w-e-w marked this pull request as ready for review January 21, 2024 00:25
@AUTOMATIC1111 AUTOMATIC1111 merged commit 8a6a4ad into dev Jan 21, 2024
6 checks passed
@AUTOMATIC1111 AUTOMATIC1111 deleted the improve-get_crop_region branch January 21, 2024 13:01
@w-e-w w-e-w mentioned this pull request Feb 17, 2024
@pawel665j pawel665j mentioned this pull request Apr 16, 2024
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants