Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Feb 17, 2025
1 parent 8effb91 commit c2ce908
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions tests/context_parallel/test_diffusers_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,34 +67,29 @@ def _test_benchmark_pipe(self, dtype, device, parallelize, compile, max_batch_di

output_image, warmup_time, inference_time = None, None, None

try:
for i in range(2):
torch.manual_seed(0)

call_kwargs = {}
if i == 0:
call_kwargs["num_inference_steps"] = 1
begin = time.time()
output = self.call_pipe(pipe, **call_kwargs)
end = time.time()
if i == 0:
warmup_time = end - begin
msg = f"Warm-up time taken: {warmup_time:.3f} seconds"
else:
inference_time = end - begin
msg = f"Inference time taken: {inference_time:.3f} seconds"
if self.rank == 0:
print(msg)
if i != 0:
if hasattr(output, "images"):
output_image = output.images[0]
elif hasattr(output, "frames"):
video = output.frames[0]
output_image = video[0]
except Exception as e:
if "is not divisible by world_size" in str(e):
pytest.skip(str(e))
raise
for i in range(2):
torch.manual_seed(0)

call_kwargs = {}
if i == 0:
call_kwargs["num_inference_steps"] = 1
begin = time.time()
output = self.call_pipe(pipe, **call_kwargs)
end = time.time()
if i == 0:
warmup_time = end - begin
msg = f"Warm-up time taken: {warmup_time:.3f} seconds"
else:
inference_time = end - begin
msg = f"Inference time taken: {inference_time:.3f} seconds"
if self.rank == 0:
print(msg)
if i != 0:
if hasattr(output, "images"):
output_image = output.images[0]
elif hasattr(output, "frames"):
video = output.frames[0]
output_image = video[0]

return output_image, warmup_time, inference_time

Expand All @@ -108,9 +103,16 @@ class Runner(DiffusionPipelineRunner):

def test_benchmark_pipe(self, extras, dtype, device, parallelize, compile, max_batch_dim_size, max_ring_dim_size):
with self.Runner().start() as runner:
output_image, warmup_time, inference_time = runner(
(dtype, device, parallelize, compile, max_batch_dim_size, max_ring_dim_size),
)
try:
output_image, warmup_time, inference_time = runner(
(dtype, device, parallelize, compile, max_batch_dim_size, max_ring_dim_size),
)
except Exception as e:
lines = str(e).split("\n")
for line in lines:
if "is not divisible by world_size" in line:
pytest.skip(line)
raise

extras.append(pytest_html.extras.html(f"<div><p>Warm-up time taken: {warmup_time:.3f} seconds</p></div>"))
extras.append(pytest_html.extras.html(f"<div><p>Inference time taken: {inference_time:.3f} seconds</p></div>"))
Expand Down

0 comments on commit c2ce908

Please # to comment.