diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index b68ea49f9..984977b87 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -769,3 +769,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625, distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +batch_norm,liger,forward,speed,ms,N,hidden size,1024,0.13689599931240082,0.13616639375686646,0.13795199990272522,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2 +batch_norm,liger,forward,speed,ms,N,hidden size,2048,0.26447999477386475,0.26284798979759216,0.2656959891319275,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2 +batch_norm,liger,forward,speed,ms,N,hidden size,4096,0.525056004524231,0.5232831835746765,0.5266559720039368,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2 +batch_norm,liger,forward,speed,ms,N,hidden 
size,8192,1.05131196975708,1.0489856004714966,1.0533759593963623,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2 +batch_norm,liger,forward,speed,ms,N,hidden size,16384,2.13972806930542,2.1362624168395996,2.143014430999756,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2 +batch_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.041471999138593674,0.0398080013692379,0.042688000947237015,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2 +batch_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.06825599819421768,0.06672000139951706,0.0695360004901886,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2 +batch_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.1191679984331131,0.11868800222873688,0.11961600184440613,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2 +batch_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.21347199380397797,0.21296000480651855,0.21398399770259857,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2 +batch_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.4029119908809662,0.4023999869823456,0.40348801016807556,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2 +batch_norm,liger,full,speed,ms,N,hidden size,1024,0.3394879996776581,0.3375680148601532,0.3413119912147522,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2 +batch_norm,liger,full,speed,ms,N,hidden size,2048,0.6499840021133423,0.6464319825172424,0.6534016132354736,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2 
+batch_norm,liger,full,speed,ms,N,hidden size,4096,1.2944639921188354,1.291468858718872,1.297875165939331,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2 +batch_norm,liger,full,speed,ms,N,hidden size,8192,2.5837440490722656,2.579263925552368,2.5880000591278076,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2 +batch_norm,liger,full,speed,ms,N,hidden size,16384,5.309120178222656,5.301023960113525,5.314540863037109,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2 +batch_norm,huggingface,full,speed,ms,N,hidden size,1024,0.08718399703502655,0.08614400029182434,0.08816000074148178,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,speed,ms,N,hidden size,2048,0.14828799664974213,0.14732800424098969,0.14927999675273895,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,speed,ms,N,hidden size,4096,0.25726401805877686,0.25622400641441345,0.2583935856819153,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,speed,ms,N,hidden size,8192,0.4660159945487976,0.46483200788497925,0.4671808183193207,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,speed,ms,N,hidden size,16384,0.880128026008606,0.8787840008735657,0.8814719915390015,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,liger,full,memory,MB,N,hidden size,1024,80.04736328125,80.04736328125,80.04736328125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 
+batch_norm,liger,full,memory,MB,N,hidden size,2048,160.09423828125,160.09423828125,160.09423828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,liger,full,memory,MB,N,hidden size,4096,320.18798828125,320.18798828125,320.18798828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,liger,full,memory,MB,N,hidden size,8192,640.37548828125,640.37548828125,640.37548828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,liger,full,memory,MB,N,hidden size,16384,1280.75048828125,1280.75048828125,1280.75048828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,memory,MB,N,hidden size,1024,80.05517578125,80.05517578125,80.05517578125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,memory,MB,N,hidden size,2048,160.10986328125,160.10986328125,160.10986328125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,memory,MB,N,hidden size,4096,320.21923828125,320.21923828125,320.21923828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,memory,MB,N,hidden size,8192,640.43798828125,640.43798828125,640.43798828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 +batch_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.87548828125,1280.87548828125,1280.87548828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2 diff --git a/benchmark/scripts/benchmark_batch_norm.py 
b/benchmark/scripts/benchmark_batch_norm.py
new file mode 100644
index 000000000..51a820fa4
--- /dev/null
+++ b/benchmark/scripts/benchmark_batch_norm.py
@@ -0,0 +1,125 @@
+import torch
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.transformers.batch_norm import LigerBatchNorm
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+
+def bench_speed_batch_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    N = input.x
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+    extra_benchmark_config = input.extra_benchmark_config
+    M = extra_benchmark_config["M"]
+    eps = extra_benchmark_config["eps"]
+    dtype = extra_benchmark_config["dtype"]
+
+    x_shape = (M, N)
+    triton_bn = LigerBatchNorm(hidden_size=N).to(device)
+    torch_bn = torch.nn.BatchNorm1d(N, eps=eps).to(device)
+
+    x = torch.randn(x_shape, dtype=dtype, device=device)
+    dy = torch.randn_like(x)
+    x.requires_grad_(True)
+
+    def y_fwd():
+        if provider == "liger":
+            return triton_bn(x)
+        if provider == "huggingface":
+            return torch_bn(x)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
+    elif mode == "backward":
+        y = y_fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(dy, retain_graph=True),
+            quantiles=QUANTILES,
+            grad_to_none=[x],
+            rep=500,
+        )
+    elif mode == "full":
+
+        def full():
+            y = y_fwd()
+            y.backward(dy, retain_graph=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
+
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+def bench_memory_batch_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    N = input.x
+    provider = input.kernel_provider
+    dtype = input.extra_benchmark_config["dtype"]
+    M = input.extra_benchmark_config["M"]
+    eps = input.extra_benchmark_config["eps"]
+
+    x_shape = (M, N)
+
+    triton_bn = LigerBatchNorm(hidden_size=N).to(device)
+    torch_bn = torch.nn.BatchNorm1d(N, eps=eps).to(device)
+
+    x = torch.randn(x_shape, dtype=dtype, device=device)
+    dy = torch.randn_like(x)
+    x.requires_grad_(True)
+
+    def y_fwd():
+        if provider == "liger":
+            return triton_bn(x)
+        if provider == "huggingface":
+            return torch_bn(x)
+
+    def full():
+        y = y_fwd()
+        y.backward(dy, retain_graph=True)
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    common_configs = {
+        "kernel_name": "batch_norm",
+        "x_name": "N",
+        "x_label": "hidden size",
+        "x_values": [2**i for i in range(10, 15)],  # Range of hidden size values
+        "kernel_providers": ["liger", "huggingface"],
+        "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_batch_norm,
+        kernel_operation_modes=["forward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_batch_norm,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )
diff --git a/src/liger_kernel/ops/batch_norm.py b/src/liger_kernel/ops/batch_norm.py
new file mode 100644
index 000000000..876c9a37a
--- /dev/null
+++ b/src/liger_kernel/ops/batch_norm.py
@@ -0,0 +1,228 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+
+# Import rsqrt function based on Triton version
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _batch_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (N, C)
+    Y_row_stride,  # stride between rows in output (usually C)
+    X_ptr,  # pointer to input, shape (N, C)
+    X_row_stride,  # stride between rows in input
+    gamma_ptr,  # pointer to scale, shape (C,)
+    beta_ptr,  # pointer to bias, shape (C,)
+    mean_ptr,  # pointer to mean, shape (C,)
+    rstd_ptr,  # pointer to rstd, shape (C,)
+    n_rows: tl.constexpr,  # batch size N
+    n_channels: tl.constexpr,  # feature dim C
+    eps,  # small constant
+    BLOCK_SIZE: tl.constexpr,  # the number of rows processed in each block (for vectorization)
+):
+    """
+    BatchNorm Forward kernel
+    Each program instance handles one channel (i.e., feature c),
+    and performs reduction on all batch elements (rows) of that channel to compute mean, variance, and produce normalized output.
+    """
+    # Each program instance processes one channel
+    channel_idx = tl.program_id(0)
+    if channel_idx >= n_channels:
+        return
+
+    # --- First pass: compute mean and variance ---
+    sum_val = tl.zeros([], dtype=tl.float32)
+    sum_sq_val = tl.zeros([], dtype=tl.float32)
+    for row_offset in range(0, n_rows, BLOCK_SIZE):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < (n_rows - row_offset)
+        # Element address: X_ptr + (row_offset + offset)*X_row_stride + channel_idx
+        base_ptr = X_ptr + (row_offset * X_row_stride) + channel_idx
+        x_block = tl.load(base_ptr + offsets * X_row_stride, mask=mask, other=0.0)
+        sum_val += tl.sum(x_block, axis=0)
+        sum_sq_val += tl.sum(x_block * x_block, axis=0)
+    mean = sum_val / n_rows
+    var = sum_sq_val / n_rows - mean * mean
+    rstd = rsqrt(var + eps)
+    tl.store(mean_ptr + channel_idx, mean)
+    tl.store(rstd_ptr + channel_idx, rstd)
+
+    # --- Second pass: compute normalized output ---
+    gamma_val = tl.load(gamma_ptr + channel_idx)
+    beta_val = tl.load(beta_ptr + channel_idx)
+    for row_offset in range(0, n_rows, BLOCK_SIZE):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < (n_rows - row_offset)
+        base_ptr_in = X_ptr + (row_offset * X_row_stride) + channel_idx
+        x_block = tl.load(base_ptr_in + offsets * X_row_stride, mask=mask, other=0.0)
+        # Normalization: xhat = (x - mean)*rstd
+        y_val = gamma_val * ((x_block - mean) * rstd) + beta_val
+        base_ptr_out = Y_ptr + (row_offset * Y_row_stride) + channel_idx
+        tl.store(base_ptr_out + offsets * Y_row_stride, y_val, mask=mask)
+
+
+@triton.jit
+def _batch_norm_backward_kernel(
+    X_ptr,  # pointer to input X, shape (N, C)
+    dY_ptr,  # pointer to upstream gradient dY, shape (N, C)
+    DX_ptr,  # pointer to output gradient dX, shape (N, C)
+    gamma_ptr,  # pointer to scale, shape (C,)
+    mean_ptr,  # pointer to mean, shape (C,)
+    rstd_ptr,  # pointer to rstd, shape (C,)
+    dgamma_ptr,  # pointer to dgamma, shape (C,)
+    dbeta_ptr,  # pointer to dbeta, shape (C,)
+    n_rows: tl.constexpr,  # batch size
+    n_channels: tl.constexpr,  # feature dim C
+    BLOCK_SIZE: tl.constexpr,  # the number of rows processed in each block
+    stride_x,  # stride between rows in X (usually C)
+    stride_dy,  # stride between rows in dY
+    stride_dx,  # stride between rows in dX
+):
+    """
+    BatchNorm Backward kernel
+    Each program instance processes one channel, performing two passes over the batch:
+    The first pass computes dgamma and dbeta;
+    The second pass computes dX.
+    """
+    channel_idx = tl.program_id(0)
+    if channel_idx >= n_channels:
+        return
+
+    gamma_val = tl.load(gamma_ptr + channel_idx)
+    mean_val = tl.load(mean_ptr + channel_idx)
+    rstd_val = tl.load(rstd_ptr + channel_idx)
+
+    # --- First pass: compute dgamma and dbeta ---
+    dgamma_acc = tl.zeros([], dtype=tl.float32)
+    dbeta_acc = tl.zeros([], dtype=tl.float32)
+    for row_offset in range(0, n_rows, BLOCK_SIZE):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < (n_rows - row_offset)
+        base_ptr_dy = dY_ptr + (row_offset * stride_dy) + channel_idx
+        dy_block = tl.load(base_ptr_dy + offsets * stride_dy, mask=mask, other=0.0)
+        base_ptr_x = X_ptr + (row_offset * stride_x) + channel_idx
+        x_block = tl.load(base_ptr_x + offsets * stride_x, mask=mask, other=0.0)
+        # Compute xhat = (x - mean)*rstd
+        xhat_block = (x_block - mean_val) * rstd_val
+        dgamma_acc += tl.sum(dy_block * xhat_block, axis=0)
+        dbeta_acc += tl.sum(dy_block, axis=0)
+    tl.store(dgamma_ptr + channel_idx, dgamma_acc)
+    tl.store(dbeta_ptr + channel_idx, dbeta_acc)
+
+    # --- Second pass: compute dX ---
+    # Note: since n_rows is constexpr, we can convert it to float for division
+    N_float = float(n_rows)
+    for row_offset in range(0, n_rows, BLOCK_SIZE):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < (n_rows - row_offset)
+        base_ptr_dy = dY_ptr + (row_offset * stride_dy) + channel_idx
+        dy_block = tl.load(base_ptr_dy + offsets * stride_dy, mask=mask, other=0.0)
+        base_ptr_x = X_ptr + (row_offset * stride_x) + channel_idx
+        x_block = tl.load(base_ptr_x + offsets * stride_x, mask=mask, other=0.0)
+        xhat_block = (x_block - mean_val) * rstd_val
+        # dx = gamma * rstd * [dy - dbeta/N - xhat*(dgamma/N)]
+        dx_block = gamma_val * rstd_val * (dy_block - (dbeta_acc / N_float) - xhat_block * (dgamma_acc / N_float))
+        base_ptr_dx = DX_ptr + (row_offset * stride_dx) + channel_idx
+        tl.store(base_ptr_dx + offsets * stride_dx, dx_block, mask=mask)
+
+
+def batch_norm_forward(X, gamma, beta, eps):
+    """
+    Forward pass:
+    X: shape (N, C)
+    gamma, beta: shape (C,)
+    Returns:
+    Y, as well as intermediate variables X, Mean, RSTD for backward pass
+    """
+    shape = X.shape
+    assert len(shape) == 2, "Currently, BatchNorm only supports 2D input (N, C)"
+    n_rows, n_channels = shape
+    # Choose BLOCK_SIZE based on the dimension of the batch
+    BLOCK_SIZE, num_warps = calculate_settings(n_rows)
+    Y = torch.empty((n_rows, n_channels), dtype=X.dtype, device=X.device)
+    # Mean and rstd saved per channel
+    Mean = torch.empty(n_channels, dtype=X.dtype, device=X.device)
+    RSTD = torch.empty(n_channels, dtype=X.dtype, device=X.device)
+    # Check gamma shape
+    assert gamma.shape[0] == n_channels, "gamma dimension should match input feature dimension"
+    grid = (n_channels,)
+    _batch_norm_forward_kernel[grid](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        gamma,
+        beta,
+        Mean,
+        RSTD,
+        n_rows,
+        n_channels,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return Y, X, Mean, RSTD, BLOCK_SIZE, num_warps
+
+
+def batch_norm_backward(dY, X, gamma, beta, Mean, RSTD):
+    """
+    Backward pass:
+    dY: upstream gradient, shape (N, C)
+    Returns:
+    dX, dgamma, dbeta
+    """
+    shape = dY.shape
+    assert len(shape) == 2, "Currently, BatchNorm only supports 2D input (N, C)"
+    n_rows, n_channels = shape
+    DX = torch.empty((n_rows, n_channels), dtype=X.dtype, device=X.device)
+    # dgamma, dbeta are both (C,)
+    dgamma = torch.empty(n_channels, dtype=gamma.dtype, device=gamma.device)
+    dbeta = torch.empty(n_channels, dtype=gamma.dtype, device=gamma.device)
+    BLOCK_SIZE, num_warps = calculate_settings(n_rows)
+    grid = (n_channels,)
+    _batch_norm_backward_kernel[grid](
+        X,
+        dY,
+        DX,
+        gamma,
+        Mean,
+        RSTD,
+        dgamma,
+        dbeta,
+        n_rows,
+        n_channels,
+        BLOCK_SIZE=BLOCK_SIZE,
+        stride_x=X.stride(0),
+        stride_dy=dY.stride(0),
+        stride_dx=DX.stride(0),
+    )
+    return DX, dgamma, dbeta
+
+
+class LigerBatchNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, gamma, beta, eps):
+        Y, X_saved, Mean, RSTD, BLOCK_SIZE, num_warps = batch_norm_forward(X, gamma, beta, eps)
+        ctx.save_for_backward(X_saved, gamma, beta, Mean, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, gamma, beta, Mean, RSTD = ctx.saved_tensors
+        DX, dgamma, dbeta = batch_norm_backward(dY, X, gamma, beta, Mean, RSTD)
+        return DX, dgamma, dbeta, None
+
diff --git a/src/liger_kernel/transformers/batch_norm.py b/src/liger_kernel/transformers/batch_norm.py
new file mode 100644
index 000000000..2d03da11d
--- /dev/null
+++ b/src/liger_kernel/transformers/batch_norm.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.batch_norm import LigerBatchNormFunction
+
+
+class LigerBatchNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6, bias=True, init_fn="ones"):
+        """
+        Initialize the LigerBatchNorm class.
+
+        Arguments:
+            hidden_size (int): The size of the input features (i.e., the C dimension).
+            eps (float): Small constant to prevent division by zero.
+            bias (bool): Whether to use the bias term.
+            init_fn (str): Initialization method for the weight, either "ones" or "zeros".
+        """
+        super().__init__()
+
+        # Ensure init_fn parameter is valid
+        assert init_fn in ["ones", "zeros"], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+
+        self.hidden_size = hidden_size
+        self.eps = eps
+
+        # Initialize weight and bias parameters
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.bias = nn.Parameter(torch.zeros(hidden_size))  # zero-init to match nn.BatchNorm1d (was randn when bias=True)
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        """
+        Forward pass.
+
+        Arguments:
+            hidden_states (torch.Tensor): The input tensor, shape (N, C), where N is the batch size and C is the feature dimension.
+
+        Returns:
+            torch.Tensor: The normalized output tensor.
+        """
+        return LigerBatchNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
+
+    def extra_repr(self):
+        """
+        Returns additional information about the class, typically used to print more details when displaying the model.
+        """
+        return f"{self.hidden_size}, eps={self.eps}"
diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py
index c2f51e952..69a5e58c2 100644
--- a/src/liger_kernel/transformers/functional.py
+++ b/src/liger_kernel/transformers/functional.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from liger_kernel.ops.batch_norm import LigerBatchNormFunction
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
@@ -175,3 +176,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_batch_norm(X, gamma, beta, eps):
+    return LigerBatchNormFunction.apply(X, gamma, beta, eps)
diff --git a/test/transformers/test_batch_norm.py b/test/transformers/test_batch_norm.py
new file mode 100644
index 000000000..f0fb295f6
--- /dev/null
+++ b/test/transformers/test_batch_norm.py
@@ -0,0 +1,103 @@
+import pytest
+import torch
+
+from liger_kernel.ops.batch_norm import LigerBatchNormFunction
+from liger_kernel.transformers.batch_norm import LigerBatchNorm
+from liger_kernel.transformers.functional import liger_batch_norm
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+
+# Test for LigerBatchNorm
+@pytest.mark.parametrize(
+    "batch_size, hidden_size",
+    [
+        (3, 96),
+        (4, 128),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-1, 1e-1),
+    ],
+)
+def test_liger_batch_norm(batch_size, hidden_size, dtype, atol, rtol):
+    torch.manual_seed(0)
+
+    # Modify the input shape to (N, C)
+    x = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+
+    liger_x = x.clone().requires_grad_(True)
+    torch_x = x.clone().requires_grad_(True)
+
+    liger_bn = LigerBatchNorm(hidden_size, eps=1e-6).to(dtype).to(device)
+    torch_bn = torch.nn.BatchNorm1d(hidden_size, eps=1e-6).to(dtype).to(device)
+
+    with torch.no_grad():
+        torch_bn.weight.copy_(liger_bn.weight)
+        torch_bn.bias.copy_(liger_bn.bias)
+
+    liger_output = liger_bn(liger_x)
+    torch_output = torch_bn(torch_x)
+
+    assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)
+
+    grad_output = torch.randn_like(x)
+    liger_output.backward(grad_output, retain_graph=True)
+    torch_output.backward(grad_output, retain_graph=True)
+
+    assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)
+    assert torch.allclose(liger_bn.weight.grad, torch_bn.weight.grad, atol=atol, rtol=rtol)
+    assert torch.allclose(liger_bn.bias.grad, torch_bn.bias.grad, atol=atol, rtol=rtol)
+
+
+# Test for LigerBatchNormFunction
+@pytest.mark.parametrize(
+    "batch_size, hidden_size",
+    [
+        (3, 96),
+        (4, 128),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-5, 1e-5),
+    ],
+)
+def test_liger_batch_norm_functional(hidden_size, batch_size, dtype, atol, rtol):
+    torch.manual_seed(0)
+
+    # Modify the input shape to (N, C)
+    input = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+
+    x1 = input.clone().requires_grad_(True)
+    x2 = input.clone().requires_grad_(True)
+
+    w = torch.randn(hidden_size, device=device, dtype=dtype)
+
+    w1 = w.clone().requires_grad_(True)
+    w2 = w.clone().requires_grad_(True)
+
+    b = torch.randn(hidden_size, device=device, dtype=dtype)
+
+    b1 = b.clone().requires_grad_(True)
+    b2 = b.clone().requires_grad_(True)
+
+    # Using LigerBatchNorm function
+    y1 = liger_batch_norm(X=x1, gamma=w1, beta=b1, eps=1e-6)
+    # Using LigerBatchNormFunction directly
+    y2 = LigerBatchNormFunction.apply(x2, w2, b2, 1e-6)
+
+    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
+
+    grad_output = torch.randn_like(y2)
+
+    y1.backward(grad_output, retain_graph=True)
+    y2.backward(grad_output, retain_graph=True)
+
+    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
+    assert torch.allclose(w1.grad, w2.grad, atol=atol, rtol=rtol)
+    assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)