Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
rotary_dim
argument to rope APIs for partial apply rope (#…
…599) This PR implements the final piece of #530 , so that we can partially apply rotary embedding to first head dimensions instead of entire head dimensions. We also add a simple benchmark for RoPE, below is the result on H100: ```python batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 23us, throughput: 0.876GB/s batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 26us, throughput: 0.801GB/s batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.735GB/s batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.639GB/s batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 31us, throughput: 672.889GB/s batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 32us, throughput: 662.972GB/s --- batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 14.559GB/s batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 14.435GB/s batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 37us, throughput: 1339.450GB/s batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 37us, throughput: 1340.399GB/s batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 148us, throughput: 2696.563GB/s batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 148us, throughput: 2689.104GB/s --- batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 74.186GB/s batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 74.452GB/s batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 110us, throughput: 2350.830GB/s batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 110us, throughput: 2359.814GB/s batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 717us, throughput: 2895.389GB/s batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 718us, throughput: 2891.385GB/s --- batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.449GB/s batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.646GB/s batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 130us, throughput: 2576.101GB/s batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 130us, throughput: 2582.447GB/s batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.154GB/s batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 925us, throughput: 2903.484GB/s ```
- Loading branch information