⚡️ Speed up method Kandinsky3ConditionalGroupNorm.forward by 7% #11667

Open
wants to merge 4 commits into
base: main
from

misrasaurabh1
Contributor

📄 7% (0.07x) speedup for Kandinsky3ConditionalGroupNorm.forward in src/diffusers/models/unets/unet_kandinsky3.py

⏱️ Runtime: 2.16 milliseconds → 2.02 milliseconds (best of 332 runs)
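
As a quick sanity check on the headline figure, the two runtimes can be turned back into the reported percentage. The formula here is an assumption (speedup taken as original runtime divided by optimized runtime, minus one); the report does not state how the number was derived, but this definition reproduces both the 7% and the 0.07x values.

```python
# Runtimes taken from the report above, in milliseconds.
original_ms = 2.16
optimized_ms = 2.02

# Assumed definition: speedup = original / optimized - 1.
speedup = original_ms / optimized_ms - 1

speedup_percent = round(speedup * 100)  # whole-number percentage
speedup_factor = round(speedup, 2)      # the "0.07x" style figure

print(f"{speedup_percent}% ({speedup_factor}x) speedup")
```

Measuring the saving against the original runtime instead, (2.16 - 2.02) / 2.16, gives roughly 6.5%, so the reported 7% is consistent with the first definition.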

📝 Explanation and details

The most important optimizations for this method, based on the line profiling results:

  • The main bottleneck is self.norm(x) * (scale + 1.0) + shift and the self.context_mlp(context) call.
  • The loop that repeatedly applies unsqueeze to the context tensor is inefficient.
  • You can vectorize context expansion using .view or .reshape to match the desired broadcastable shape all at once, rather than unsqueezing in a loop.

The improved code removes the loop and performs the shape expansion in a single call. Both unsqueeze and view only change tensor metadata, so the saving is a fixed per-call overhead and does not grow with batch, channel, or image size.
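
The optimized code itself is not reproduced in this description, so the following is only a minimal sketch of the kind of change the bullets describe. `ConditionalGroupNormSketch`, `forward_loop`, and `forward_view` are hypothetical names, and a single `nn.Linear` stands in for the real context MLP; none of this is the actual diffusers implementation.

```python
import torch
from torch import nn


class ConditionalGroupNormSketch(nn.Module):
    """Hypothetical stand-in for the module discussed here, not the diffusers class."""

    def __init__(self, groups: int, channels: int, context_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels, affine=False)
        # Assumption: one linear layer stands in for the real context MLP.
        # It is left randomly initialized so scale and shift are non-zero.
        self.context_mlp = nn.Linear(context_dim, 2 * channels)

    def forward_loop(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        context = self.context_mlp(context)
        # One unsqueeze per spatial dimension of x.
        for _ in range(x.dim() - 2):
            context = context.unsqueeze(-1)
        scale, shift = context.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.0) + shift

    def forward_view(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        context = self.context_mlp(context)
        # Append all trailing singleton dimensions in a single view call.
        context = context.view(*context.shape, *([1] * (x.dim() - 2)))
        scale, shift = context.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.0) + shift


torch.manual_seed(0)
model = ConditionalGroupNormSketch(groups=2, channels=4, context_dim=8)

# Compare both variants on 1D, 2D, and 3D spatial inputs.
results = {}
for shape in [(2, 4, 16), (2, 4, 8, 8), (2, 4, 2, 4, 4)]:
    x = torch.randn(*shape)
    context = torch.randn(shape[0], 8)
    with torch.no_grad():
        results[shape] = torch.allclose(
            model.forward_loop(x, context), model.forward_view(x, context)
        )
```

`view` requires a layout-compatible tensor, which holds for the output of a linear layer; `reshape` would be the more forgiving choice if the context tensor could be non-contiguous.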

Summary of Optimizations.

  • Replaced the for-loop with a single tensor view: The repeated unsqueeze calls are replaced with one view call that produces the broadcastable shape directly, which avoids the per-iteration Python and dispatch overhead.
  • Compute the shape once: Uses x.dim() to build the required broadcast shape in one step, with no per-dimension Python looping.
  • All existing semantics and output shapes are preserved.
  • No extra temporary tensors, and the autograd graph records one view op instead of one unsqueeze op per spatial dimension.
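
On the allocation point, a small check (a sketch with a made-up `context` tensor, not code from this PR) suggests that neither approach copies element data: chained `unsqueeze` calls and a single `view` both return tensors that alias the original storage. The difference between them is the number of Python-level calls, not memory traffic.

```python
import torch

# Stand-in for the MLP output: (batch, 2 * channels).
context = torch.randn(2, 8)

# Loop-style expansion for a 4-D input: two unsqueeze calls.
looped = context.unsqueeze(-1).unsqueeze(-1)

# Single-call expansion to the same (2, 8, 1, 1) shape.
viewed = context.view(2, 8, 1, 1)

# Both results point at the same underlying memory as the original tensor.
shares_storage = (
    looped.data_ptr() == context.data_ptr()
    and viewed.data_ptr() == context.data_ptr()
)
print(looped.shape, viewed.shape, shares_storage)
```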

This rewrite keeps the function signature and logic unchanged. The measured improvement is about 7% (2.16 ms to 2.02 ms in the run reported above).

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 31 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests Details
import pytest  # used for our unit tests
import torch  # used for tensor operations
from src.diffusers.models.unets.unet_kandinsky3 import \
    Kandinsky3ConditionalGroupNorm
from torch import nn

# unit tests

# ---- Basic Test Cases ----

def test_forward_basic_2d_batch():
    # Test with 2D spatial input, batch size 2, 4 channels, 2 groups, context_dim 8
    batch, channels, height, width = 2, 4, 8, 8
    groups = 2
    context_dim = 8
    x = torch.randn(batch, channels, height, width)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output
    # Output should be differentiable
    out.sum().backward()

def test_forward_basic_1d():
    # Test with 1D input (e.g., sequence), batch size 3, 6 channels, 3 groups, context_dim 4
    batch, channels, length = 3, 6, 16
    groups = 3
    context_dim = 4
    x = torch.randn(batch, channels, length)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_basic_3d():
    # Test with 3D input (e.g., video), batch size 1, 8 channels, 2 groups, context_dim 10
    batch, channels, d, h, w = 1, 8, 2, 4, 4
    groups = 2
    context_dim = 10
    x = torch.randn(batch, channels, d, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_context_zero_affine():
    # Test that with zero-initialized context_mlp, output equals GroupNorm(x)
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 7
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    # Since context_mlp is zero-initialized, scale=0, shift=0, so output=GroupNorm(x)
    codeflash_output = model.forward(x, context); out = codeflash_output
    baseline = model.norm(x)
    assert torch.allclose(out, baseline)

# ---- Edge Test Cases ----

def test_forward_single_element_batch():
    # Test with batch size 1
    batch, channels, h, w = 1, 4, 5, 5
    groups = 2
    context_dim = 3
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_single_channel():
    # Test with single channel (groups=1)
    batch, channels, h, w = 2, 1, 4, 4
    groups = 1
    context_dim = 5
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_single_spatial():
    # Test with single spatial dimension (e.g., length=1)
    batch, channels, length = 2, 4, 1
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, length)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_mismatched_context_dim():
    # Test with wrong context_dim (should raise an error)
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 7
    x = torch.randn(batch, channels, h, w)
    wrong_context = torch.randn(batch, context_dim + 1)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    with pytest.raises(RuntimeError):
        model.forward(x, wrong_context)

def test_forward_mismatched_batch_size():
    # Test with mismatched batch size between x and context (should raise error)
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 5
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch + 1, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    with pytest.raises(RuntimeError):
        model.forward(x, context)

def test_forward_invalid_groups():
    # Test with groups not dividing channels evenly (should raise error)
    batch, channels, h, w = 2, 5, 8, 8
    groups = 3  # 5 not divisible by 3
    context_dim = 4
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    with pytest.raises(ValueError):
        model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
        model.forward(x, context)

def test_forward_empty_input():
    # Test with empty input tensor (should raise error)
    batch, channels, h, w = 0, 4, 8, 8
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    with pytest.raises(Exception):
        model.forward(x, context)

def test_forward_nan_inf_input():
    # Test with NaN and Inf values in x
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, h, w)
    x[0, 0, 0, 0] = float('nan')
    x[1, 1, 1, 1] = float('inf')
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_nan_inf_context():
    # Test with NaN and Inf values in context
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    context[0, 0] = float('nan')
    context[1, 1] = float('inf')
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

# ---- Large Scale Test Cases ----

def test_forward_large_batch():
    # Test with large batch size
    batch, channels, h, w = 128, 4, 8, 8
    groups = 2
    context_dim = 8
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_channels():
    # Test with large number of channels (but less than 100MB)
    batch, channels, h, w = 2, 256, 8, 8  # 2*256*8*8*4B = 131072B = 128KB
    groups = 16
    context_dim = 32
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_spatial():
    # Test with large spatial dimensions (but less than 100MB)
    batch, channels, h, w = 2, 8, 64, 64  # 2*8*64*64*4B = 262144B = 256KB
    groups = 4
    context_dim = 8
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_3d():
    # Test with large 3D input (e.g., volumetric data)
    batch, channels, d, h, w = 1, 16, 16, 8, 8  # 1*16*16*8*8*4B = 65536B = 64KB
    groups = 4
    context_dim = 16
    x = torch.randn(batch, channels, d, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_context_dim():
    # Test with large context dimension
    batch, channels, h, w = 2, 8, 8, 8
    groups = 4
    context_dim = 512
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_performance():
    # Test that forward pass runs in reasonable time for large input
    import time
    batch, channels, h, w = 16, 32, 32, 32  # 16*32*32*32*4B = 2MB
    groups = 8
    context_dim = 32
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    start = time.perf_counter()  # monotonic, high-resolution timer
    codeflash_output = model.forward(x, context); out = codeflash_output
    elapsed = time.perf_counter() - start
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
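
The size comments in the large-scale tests can be checked with a few lines of arithmetic. This is a sketch: `tensor_nbytes` is a hypothetical helper, and the 4-byte item size assumes the float32 default of `torch.randn`.

```python
from math import prod


def tensor_nbytes(shape, itemsize=4):
    """Bytes needed for a dense tensor; itemsize=4 assumes float32."""
    return prod(shape) * itemsize


# Input shapes used by the large-scale tests above.
sizes = {
    "large_channels": tensor_nbytes((2, 256, 8, 8)),
    "large_spatial": tensor_nbytes((2, 8, 64, 64)),
    "large_3d": tensor_nbytes((1, 16, 16, 8, 8)),
    "performance": tensor_nbytes((16, 32, 32, 32)),
}

for name, nbytes in sizes.items():
    print(f"{name}: {nbytes} B = {nbytes / 1024:.0f} KiB")
```

All four inputs stay far below the 100MB ceiling mentioned in the test comments.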

import pytest  # used for our unit tests
import torch  # for tensor operations
from src.diffusers.models.unets.unet_kandinsky3 import \
    Kandinsky3ConditionalGroupNorm
from torch import nn

# unit tests

# --------- BASIC TEST CASES ---------

def test_forward_basic_2d():
    # Simple 2D input (batch, channels, height, width)
    batch, channels, height, width = 2, 4, 8, 8
    groups = 2
    context_dim = 5
    x = torch.randn(batch, channels, height, width)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_basic_1d():
    # 1D input (batch, channels, length)
    batch, channels, length = 3, 6, 10
    groups = 3
    context_dim = 7
    x = torch.randn(batch, channels, length)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_basic_3d():
    # 3D input (batch, channels, depth, height, width)
    batch, channels, d, h, w = 1, 8, 4, 4, 4
    groups = 4
    context_dim = 3
    x = torch.randn(batch, channels, d, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_context_broadcasting():
    # Check that context is broadcast correctly for different spatial shapes
    batch, channels, h, w = 2, 4, 12, 7
    groups = 2
    context_dim = 6
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output
    # context_mlp is zero-initialized, so scale=0 and shift=0: the output should
    # match plain GroupNorm(x) for any spatial shape
    gn = nn.GroupNorm(groups, channels, affine=False)
    normed = gn(x)
    assert torch.allclose(out, normed)

# --------- EDGE TEST CASES ---------

def test_forward_single_batch():
    # Single batch
    batch, channels, h, w = 1, 4, 5, 5
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_single_channel():
    # Single channel with groups=1 (valid: one group containing one channel)
    batch, channels, h, w = 2, 1, 8, 8
    groups = 1
    context_dim = 3
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_minimal_spatial():
    # Minimal spatial dimensions (1x1)
    batch, channels, h, w = 2, 2, 1, 1
    groups = 2
    context_dim = 2
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_context_wrong_shape():
    # Context batch size mismatch
    batch, channels, h, w = 2, 4, 4, 4
    groups = 2
    context_dim = 5
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch+1, context_dim)  # Wrong batch size
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    with pytest.raises(RuntimeError):
        codeflash_output = model.forward(x, context); _ = codeflash_output

def test_forward_context_dim_mismatch():
    # Context feature dimension mismatch
    batch, channels, h, w = 2, 4, 4, 4
    groups = 2
    context_dim = 5
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim+1)  # Wrong context_dim
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    with pytest.raises(RuntimeError):
        codeflash_output = model.forward(x, context); _ = codeflash_output

def test_forward_invalid_groups():
    # Invalid group number (not dividing channels)
    batch, channels, h, w = 2, 5, 4, 4
    groups = 2  # 5 not divisible by 2
    context_dim = 3
    with pytest.raises(ValueError):
        _ = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)



def test_forward_empty_tensor():
    # Empty input tensor
    batch, channels, h, w = 0, 4, 4, 4
    groups = 2
    context_dim = 3
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

# --------- LARGE SCALE TEST CASES ---------

def test_forward_large_batch():
    # Large batch size, but under 100MB
    batch, channels, h, w = 128, 8, 16, 16
    groups = 4
    context_dim = 16
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_channels():
    # Large number of channels, but under 100MB
    batch, channels, h, w = 4, 512, 8, 8
    groups = 8
    context_dim = 32
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_spatial():
    # Large spatial dimensions, but under 100MB
    batch, channels, h, w = 2, 16, 128, 128
    groups = 4
    context_dim = 8
    x = torch.randn(batch, channels, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_large_3d():
    # Large 3D input, but under 100MB
    batch, channels, d, h, w = 1, 16, 16, 16, 16
    groups = 4
    context_dim = 8
    x = torch.randn(batch, channels, d, h, w)
    context = torch.randn(batch, context_dim)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output

def test_forward_gradient_flow():
    # Check that gradients flow through both x and context
    batch, channels, h, w = 2, 4, 8, 8
    groups = 2
    context_dim = 4
    x = torch.randn(batch, channels, h, w, requires_grad=True)
    context = torch.randn(batch, context_dim, requires_grad=True)
    model = Kandinsky3ConditionalGroupNorm(groups, channels, context_dim)
    codeflash_output = model.forward(x, context); out = codeflash_output
    loss = out.sum()
    loss.backward()
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-Kandinsky3ConditionalGroupNorm.forward-mb5lqa87 and push.

codeflash-ai bot and others added 4 commits May 26, 2025 21:30