[CI] add a big GPU marker to run memory-intensive tests separately on CI #9691

Merged
38 commits merged on Oct 31, 2024

Commits (38)
32e23d8
add a marker for big gpu tests
sayakpaul Oct 16, 2024
da92ca0
update
sayakpaul Oct 16, 2024
219a3cc
trigger on PRs temporarily.
sayakpaul Oct 16, 2024
c679563
onnx
sayakpaul Oct 16, 2024
a0bae4b
fix
sayakpaul Oct 16, 2024
95f396e
total memory
sayakpaul Oct 16, 2024
02f0aa3
fixes
sayakpaul Oct 16, 2024
9441016
reduce memory threshold.
sayakpaul Oct 16, 2024
15d1127
bigger gpu
sayakpaul Oct 16, 2024
6c82fd4
Merge branch 'main' into big-model-marker
sayakpaul Oct 16, 2024
676b8a5
empty
sayakpaul Oct 16, 2024
3b50732
g6e
sayakpaul Oct 16, 2024
9ef5435
Apply suggestions from code review
sayakpaul Oct 16, 2024
4ff06b4
address comments.
sayakpaul Oct 17, 2024
46cab82
fix
sayakpaul Oct 17, 2024
2b25688
fix
sayakpaul Oct 17, 2024
b0568da
fix
sayakpaul Oct 17, 2024
928dd73
fix
sayakpaul Oct 17, 2024
9020d8f
fix
sayakpaul Oct 17, 2024
2732720
okay
sayakpaul Oct 17, 2024
f265f7d
further reduce.
sayakpaul Oct 17, 2024
1755305
updates
sayakpaul Oct 17, 2024
fcb57ae
remove
sayakpaul Oct 17, 2024
6f477ac
updates
sayakpaul Oct 17, 2024
ff47576
updates
sayakpaul Oct 17, 2024
1ad8c64
updates
sayakpaul Oct 17, 2024
605a21d
updates
sayakpaul Oct 17, 2024
9e1cacb
fixes
sayakpaul Oct 17, 2024
0704d9a
fixes
sayakpaul Oct 17, 2024
c9fd1ab
updates.
sayakpaul Oct 17, 2024
f8086f6
Merge branch 'main' into big-model-marker
sayakpaul Oct 17, 2024
e31b0bd
Merge branch 'main' into big-model-marker
sayakpaul Oct 18, 2024
cf280ba
fix
sayakpaul Oct 18, 2024
5b9c771
Merge branch 'main' into big-model-marker
a-r-r-o-w Oct 19, 2024
0e07597
Merge branch 'main' into big-model-marker
sayakpaul Oct 22, 2024
4fcd223
Merge branch 'main' into big-model-marker
sayakpaul Oct 31, 2024
1302ecd
Merge branch 'main' into big-model-marker
sayakpaul Oct 31, 2024
2084be0
workflow fixes.
sayakpaul Oct 31, 2024
Files changed
56 changes: 56 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -180,6 +180,62 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_big_gpu_torch_tests:
    name: Torch tests on big GPU
    strategy:
      fail-fast: false
      max-parallel: 2
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
      image: diffusers/diffusers-pytorch-cuda
      options: --shm-size "16gb" --ipc host --gpus 0
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: NVIDIA-SMI
        run: nvidia-smi
      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          python -m uv pip install peft@git+https://github.com/huggingface/peft.git
          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
          python -m uv pip install pytest-reportlog
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Selected Torch CUDA Test on big GPU
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
          CUBLAS_WORKSPACE_CONFIG: :16:8
          BIG_GPU_MEMORY: 40
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -m "big_gpu_with_torch_cuda" \
            --make-reports=tests_big_gpu_torch_cuda \
            --report-log=tests_big_gpu_torch_cuda.log \
            tests/
      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_big_gpu_torch_cuda_stats.txt
          cat reports/tests_big_gpu_torch_cuda_failures_short.txt
      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: torch_cuda_big_gpu_test_reports
          path: reports
      - name: Generate Report and Notify Channel
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

run_flax_tpu_tests:
name: Nightly Flax TPU Tests
runs-on: docker-tpu
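The workflow above selects tests with `-m "big_gpu_with_torch_cuda"`, so pytest has to know about that marker. A minimal sketch of how such a marker could be registered in a `tests/conftest.py` is shown below; whether diffusers registers it this way (rather than via `pyproject.toml` or `setup.cfg`) is an assumption on my part, not something this diff shows.

```python
# conftest.py -- hypothetical sketch, not part of this PR.
# Registering the custom marker keeps `pytest -m "big_gpu_with_torch_cuda"`
# from emitting unknown-marker warnings (and from failing under --strict-markers).
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "big_gpu_with_torch_cuda: marks tests that need a large-memory CUDA GPU (run by the nightly big-GPU job)",
    )
```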
21 changes: 21 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -57,6 +57,7 @@
) > version.parse("4.33")

USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))

if is_torch_available():
import torch
@@ -310,6 +311,26 @@ def require_torch_accelerator_with_fp64(test_case):
)


def require_big_gpu_with_torch_cuda(test_case):
    """
    Decorator marking a test that requires a big GPU for execution (at least ``BIG_GPU_MEMORY`` GB, 40 GB by
    default). Some example pipelines: Flux, SD3, Cog, etc.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    if not torch.cuda.is_available():
        return unittest.skip("test requires PyTorch CUDA")(test_case)

    device_properties = torch.cuda.get_device_properties(0)
    total_memory = device_properties.total_memory / (1024**3)
    return unittest.skipUnless(
        total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
    )(test_case)


def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless(
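For reference, the new decorator and the pytest marker are meant to be applied together, which is exactly the pattern the pipeline test changes below follow. A small usage sketch (the test class name and body here are hypothetical):

```python
import unittest

import pytest

from diffusers.utils.testing_utils import require_big_gpu_with_torch_cuda, slow


@slow
@require_big_gpu_with_torch_cuda  # skips unless the GPU has >= BIG_GPU_MEMORY GB of memory (default 40)
@pytest.mark.big_gpu_with_torch_cuda  # lets the nightly job select the test via `pytest -m`
class MyMemoryHungryPipelineSlowTests(unittest.TestCase):
    def test_inference(self):
        # The threshold is configurable per machine, e.g.
        #   BIG_GPU_MEMORY=30 pytest -m big_gpu_with_torch_cuda tests/
        ...
```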
38 changes: 28 additions & 10 deletions tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,9 @@
import unittest

import numpy as np
import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers import (
@@ -30,7 +32,8 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
slow,
torch_device,
)
@@ -180,7 +183,8 @@ def test_xformers_attention_forwardGenerator_pass(self):


@slow
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline

@@ -199,35 +203,49 @@ def test_canny(self):
"InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
controlnet=controlnet,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "A girl in city, 25 years old, cool, futuristic"
control_image = load_image(
"https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
).resize((512, 512))

prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
)
pooled_prompt_embeds = torch.load(
hf_hub_download(
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
)
)

output = pipe(
prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
control_image=control_image,
controlnet_conditioning_scale=0.6,
num_inference_steps=2,
guidance_scale=3.5,
max_sequence_length=256,
output_type="np",
height=512,
width=512,
generator=generator,
)

image = output.images[0]

assert image.shape == (1024, 1024, 3)
assert image.shape == (512, 512, 3)

original_image = image[-3:, -3:, -1].flatten()

expected_image = np.array(
[0.33007812, 0.33984375, 0.33984375, 0.328125, 0.34179688, 0.33984375, 0.30859375, 0.3203125, 0.3203125]
)
expected_image = np.array([0.2734, 0.2852, 0.2852, 0.2734, 0.2754, 0.2891, 0.2617, 0.2637, 0.2773])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
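The assertions in these slow tests now compare image slices with `numpy_cosine_similarity_distance` instead of a maximum absolute difference, which tolerates small per-pixel drift across hardware while still catching real regressions. Roughly, the helper computes one minus the cosine similarity of the two flattened arrays; the sketch below illustrates that idea and is not necessarily the exact implementation in `testing_utils.py`:

```python
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 means the two slices point in exactly the same direction;
    # the tests above accept anything below 1e-2.
    a, b = a.flatten().astype(np.float64), b.flatten().astype(np.float64)
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)
```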
71 changes: 0 additions & 71 deletions tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -1,4 +1,3 @@
import gc
import unittest

import numpy as np
@@ -13,9 +12,6 @@
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)

@@ -222,70 +218,3 @@ def test_fused_qkv_projections(self):
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."


@slow
@require_torch_gpu
class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetImg2ImgPipeline
sayakpaul (Member, Author) commented on this removed test class:

I don't think this test was correctly done as it doesn't pass the controlnet module to the pipeline and it also uses very dummy inputs which I think should be avoided for an integration test. LMK if you think otherwise.

repo_id = "black-forest-labs/FLUX.1-schnell"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)

image = torch.randn(1, 3, 64, 64).to(device)
control_image = torch.randn(1, 3, 64, 64).to(device)

return {
"prompt": "A photo of a cat",
"image": image,
"control_image": control_image,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"controlnet_conditioning_scale": 1.0,
"strength": 0.8,
"output_type": "np",
"generator": generator,
}

@unittest.skip("We cannot run inference on this model with the current CI hardware")
def test_flux_controlnet_img2img_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

inputs = self.get_inputs(torch_device)

image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.36132812, 0.30004883, 0.25830078],
[0.36669922, 0.31103516, 0.23754883],
[0.34814453, 0.29248047, 0.23583984],
[0.35791016, 0.30981445, 0.23999023],
[0.36328125, 0.31274414, 0.2607422],
[0.37304688, 0.32177734, 0.26171875],
[0.3671875, 0.31933594, 0.25756836],
[0.36035156, 0.31103516, 0.2578125],
[0.3857422, 0.33789062, 0.27563477],
[0.3701172, 0.31982422, 0.265625],
],
dtype=np.float32,
)

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

assert max_diff < 1e-4
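If this integration test is reintroduced later, the inputs would presumably look more like the Flux ControlNet canny test above (a loaded controlnet and a real control image) than `torch.randn` dummies. A hypothetical sketch of such a setup, reusing the checkpoint names and image URL that appear elsewhere in this PR; none of it is part of this diff, and the exact pipeline kwargs mirror the removed `get_inputs` above:

```python
import torch

from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel
from diffusers.utils import load_image

# Requires HF_TOKEN for the gated FLUX.1-dev checkpoint and a big GPU.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# Start from a real image rather than random noise tensors.
init_image = load_image(
    "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
).resize((512, 512))

image = pipe(
    prompt="A girl in city, 25 years old, cool, futuristic",
    image=init_image,
    control_image=init_image,
    strength=0.8,
    num_inference_steps=2,
    guidance_scale=3.5,
    controlnet_conditioning_scale=0.6,
    output_type="np",
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
```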
35 changes: 14 additions & 21 deletions tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -17,6 +17,7 @@
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -30,7 +31,8 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
slow,
torch_device,
)
@@ -195,7 +197,8 @@ def test_xformers_attention_forwardGenerator_pass(self):


@slow
@require_torch_gpu
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline

@@ -238,11 +241,9 @@ def test_canny(self):

original_image = image[-3:, -3:, -1].flatten()

expected_image = np.array(
[0.20947266, 0.1574707, 0.19897461, 0.15063477, 0.1418457, 0.17285156, 0.14160156, 0.13989258, 0.30810547]
)
expected_image = np.array([0.7314, 0.7075, 0.6611, 0.7539, 0.7563, 0.6650, 0.6123, 0.7275, 0.7222])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_pose(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16)
Expand Down Expand Up @@ -272,15 +273,12 @@ def test_pose(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.9048, 0.8740, 0.8936, 0.8516, 0.8799, 0.9360, 0.8379, 0.8408, 0.8652])

expected_image = np.array(
[0.8671875, 0.86621094, 0.91015625, 0.8491211, 0.87890625, 0.9140625, 0.8300781, 0.8334961, 0.8623047]
)

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_tile(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX//SD3-Controlnet-Tile", torch_dtype=torch.float16)
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Tile", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
@@ -307,12 +305,9 @@ def test_tile(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.6699, 0.6836, 0.6226, 0.6572, 0.7310, 0.6646, 0.6650, 0.6694, 0.6011])

expected_image = np.array(
[0.6982422, 0.7011719, 0.65771484, 0.6904297, 0.7416992, 0.6904297, 0.6977539, 0.7080078, 0.6386719]
)

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2

def test_multi_controlnet(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
@@ -344,8 +339,6 @@ def test_multi_controlnet(self):
assert image.shape == (1024, 1024, 3)

original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array(
[0.7451172, 0.7416992, 0.7158203, 0.7792969, 0.7607422, 0.7089844, 0.6855469, 0.71777344, 0.7314453]
)
expected_image = np.array([0.7207, 0.7041, 0.6543, 0.7500, 0.7490, 0.6592, 0.6001, 0.7168, 0.7231])

assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
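To reproduce these slow tests outside CI, the relevant knobs are the ones the nightly workflow sets: the `big_gpu_with_torch_cuda` marker, `BIG_GPU_MEMORY`, and `CUBLAS_WORKSPACE_CONFIG`. A minimal local-run sketch follows; invoking pytest from Python via `pytest.main()` is my own convenience choice here, not what the CI job does:

```python
# Sketch of a local reproduction of the nightly big-GPU run.
# HF_TOKEN must also be set for gated checkpoints such as FLUX.1-dev.
import os

import pytest

os.environ.setdefault("BIG_GPU_MEMORY", "40")              # skip threshold used by require_big_gpu_with_torch_cuda
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")  # deterministic cuBLAS, as in the workflow

# -n 1 assumes pytest-xdist is installed, as in the CI container image.
raise SystemExit(pytest.main(["-n", "1", "-m", "big_gpu_with_torch_cuda", "tests/"]))
```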