torch custom_op support: norm #552

Merged · 1 commit into flashinfer-ai:main · Oct 24, 2024

Conversation

abcdabcd987 (Member)

Add torch custom_op (a.k.a. torch library, for torch.compile) support to norm.py. It should be a no-op for PyTorch < 2.4.
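A minimal sketch of the registration pattern this enables (the helper name, op name, and the reference computation below are illustrative assumptions, not the actual flashinfer code):

```python
import torch

# On PyTorch >= 2.4, torch.library.custom_op marks the kernel as an opaque op
# for torch.compile; on older versions, fall back to a decorator that does nothing.
if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 4):
    def register_custom_op(name, mutates_args=()):
        return torch.library.custom_op(name, mutates_args=mutates_args)
else:
    def register_custom_op(name, mutates_args=()):
        def decorator(fn):
            return fn  # no-op for PyTorch < 2.4
        return decorator

@register_custom_op("flashinfer::rmsnorm_example", mutates_args=())
def rmsnorm_example(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The real op dispatches to a CUDA kernel; a reference computation stands in
    # here purely so the sketch runs. A real registration would typically also
    # provide a fake (meta) implementation so torch.compile can infer shapes.
    x_f = x.float()
    rms = torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps)
    return (x_f * rms * weight.float()).to(x.dtype)
```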

Testing is done via torch.compile, since we expect the custom_op annotations to isolate our kernels from torch.compile's tracing. To avoid changing the tests, I introduced some magic that replaces the kernels with torch.compile-wrapped versions. For example, to run with or without torch.compile:

# With torch.compile
FLASHINFER_TEST_TORCH_COMPILE=1 pytest -svx tests/test_norm.py

# Without torch.compile
pytest -svx tests/test_norm.py
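The exact test hook isn't shown here, but the idea is roughly the following (a sketch under assumptions; the actual helper in the test suite may differ):

```python
import os
import torch
import flashinfer.norm

# If the env var is set, swap the module-level kernel entry points for
# torch.compile-wrapped versions before the tests call them. fullgraph=True
# makes a test fail loudly if torch.compile cannot treat the kernel as an
# opaque custom op and instead tries to trace into it.
if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE") == "1":
    for name in ("rmsnorm", "fused_add_rmsnorm"):
        fn = getattr(flashinfer.norm, name)
        setattr(flashinfer.norm, name, torch.compile(fn, fullgraph=True))
```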

If this PR looks good, I'll add it to more kernels.

@abcdabcd987 abcdabcd987 requested a review from yzh119 October 23, 2024 20:50
@yzh119 yzh119 (Collaborator) left a comment

LGTM, thank you @abcdabcd987!

@yzh119 yzh119 merged commit f6e0010 into flashinfer-ai:main Oct 24, 2024
yzh119 pushed a commit that referenced this pull request Oct 25, 2024
Follow-up of #552. This PR adds torch library annotations to all
FlashInfer kernels so that torch.compile can recognize the kernels. Most
of the changes are mechanical.

I manually ran subsets of the pytest test cases while making these changes,
but since there are too many of them, and some of them did not pass even
before this change, I cannot guarantee everything works. To run the tests
with torch.compile, set the `FLASHINFER_TEST_TORCH_COMPILE=1` environment
variable:

```bash
# With torch.compile
FLASHINFER_TEST_TORCH_COMPILE=1 pytest -svx tests/test_norm.py

# Without torch.compile
pytest -svx tests/test_norm.py
```

Notable changes:
* For the prefill and decode pybind interfaces, the op used to return
`Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` depending on
`return_lse`. This caused trouble for `torch.compile`. I changed the
pybind interface to accept a `maybe_lse: Optional[torch.Tensor]` and
return only one tensor; the allocation of the lse tensor is moved to the
Python side (see the sketch after this list). The Python API does not
change.
* `chain_speculative_sampling` pybind: move the allocation of `accepted`
and `emitted` from C++ to Python, because `torch.compile` does not like
returning an input tensor as an output tensor. The Python API does not
change.
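A rough sketch of the wrapper pattern described in the first bullet (the names `run` and `_run_kernel` are placeholders, not the actual flashinfer signatures): the Python layer allocates `lse` when requested, and the underlying op always returns exactly one tensor, which keeps it torch.compile friendly.

```python
from typing import Optional, Tuple, Union
import torch

def _run_kernel(q: torch.Tensor, maybe_lse: Optional[torch.Tensor]) -> torch.Tensor:
    # Stand-in for the pybind op: it writes lse in place when provided and
    # always returns a single output tensor.
    if maybe_lse is not None:
        maybe_lse.zero_()
    return torch.empty_like(q)

def run(
    q: torch.Tensor, *, return_lse: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # The lse tensor is allocated on the Python side and passed down as an
    # optional argument, so the op's return type no longer depends on a flag.
    maybe_lse = (
        torch.empty(q.shape[:-1], dtype=torch.float32, device=q.device)
        if return_lse
        else None
    )
    out = _run_kernel(q, maybe_lse)
    return (out, maybe_lse) if return_lse else out
```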

Piggyback changes:
* `BatchPrefillWithRaggedKVCacheWrapper.plan`: fix a bug where `qo_indptr`
was not on the CPU
* `merge_state`: Fix typo in docs
* Change `run_return_lse(...)` to `run(..., return_lse=True)` because
torch.compile does not recognize `functools.partial`.
* In tests, change `flashinfer.xxx()` to `flashinfer.<module>.xxx()` so
that the monkeypatch works.
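The last point matters because the top-level re-exports are separate bindings: patching `flashinfer.norm.rmsnorm` does not update `flashinfer.rmsnorm`, which was bound at import time. A small illustrative test shape (the tensor setup here is made up for the example):

```python
import torch
import flashinfer

def test_rmsnorm_goes_through_module_path():
    x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
    w = torch.randn(128, device="cuda", dtype=torch.float16)
    # flashinfer.rmsnorm(x, w) would still call the original function even after
    # flashinfer.norm.rmsnorm has been monkeypatched with a torch.compile wrapper;
    # calling through the module picks up the replacement.
    y = flashinfer.norm.rmsnorm(x, w)
    assert y.shape == x.shape
```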

Unsupported for torch.compile:
* `flashinfer.quantization.segment_packbits`: because it is data-dependent.

Untouched:
* `sparse.py`: tests didn't pass beforehand, so I skipped it. It also does
not seem to need custom_op annotations, as it has no CUDA kernels of its
own.

Failed test cases:
* batch_decode with non-contiguous kv:
`test_batch_decode_with_paged_kv_cache[False-kv_dtype0-q_dtype0-True-0.0-NONE-NHD-128-4-4-1-54-12]`
@abcdabcd987 abcdabcd987 mentioned this pull request Oct 30, 2024
yzh119 pushed a commit that referenced this pull request Oct 31, 2024
Here's why the docs fail to build after #552: as specified in `conf.py`,
Sphinx mocks `torch`, and the mock makes the predicate
`TorchVersion(torch_version) < TorchVersion("2.4")` misbehave.

The fix is to explicitly pass in an env var indicating that docs are being
built.

This commit also changes the way `prefill.py` imports the compiled
`_kernels` module so that it is consistent with the other files.
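A hedged sketch of that kind of guard (the env var and function names here are hypothetical; the actual names used by flashinfer may differ):

```python
import os
import torch
from torch.torch_version import TorchVersion

def _torch_supports_custom_op() -> bool:
    # When Sphinx builds the docs, `torch` is mocked and the version comparison
    # below would misbehave, so an env var short-circuits the check.
    if os.environ.get("FLASHINFER_BUILDING_DOCS") == "1":  # hypothetical name
        return False
    return not (TorchVersion(torch.__version__) < TorchVersion("2.4"))
```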
yzh119 added a commit that referenced this pull request Dec 17, 2024
🤖 I have created a release *beep* *boop*
---


## [0.2.0](v0.1.6...v0.2.0) (2024-12-17)

[Release
Blog](https://flashinfer.ai/2024/12/16/flashinfer-v02-release.html).

### Features

* add `rotary_dim` argument to rope APIs for partial apply rope
([#599](#599))
([eb9bc71](eb9bc71))
* add a `use_softmax` field in variant class
([#533](#533))
([d81af97](d81af97))
* add an option `non_blocking` to plan function
([#622](#622))
([560af6f](560af6f))
* add gemma_rmsnorm and gemma_fused_add_rmsnorm
([#477](#477))
([1a6b17e](1a6b17e))
* add group size 3 to GQA decode dispatch
([#558](#558))
([6227562](6227562))
* add JIT compilation support for FA3 templates
([#672](#672))
([d4e8d79](d4e8d79))
* allow the cascade kernels to be executed using varying sequence
lengths ([#627](#627))
([92ac440](92ac440))
* CUDAGraph compatibility of multi-level cascade inference APIs
([#586](#586))
([2332e8a](2332e8a))
* fix the maximal grid dimension in prefill planning with CUDA graphs
([#639](#639))
([86ca89a](86ca89a))
* improve the precision of the FusedAddRMSNormKernel function
([#587](#587))
([c7dc921](c7dc921))
* JIT compilation
([#507](#507))
([3613a5b](3613a5b))
* modify group-gemm stage number
([#497](#497))
([52dab1d](52dab1d))
* non-contiguous query with paged kv cache
([#553](#553))
([89f2c4a](89f2c4a))
* pass a dynamic token count to the cascade kernels
([#635](#635))
([5fe9f7d](5fe9f7d))
* simplify prefill JIT compilation
([#605](#605))
([fe4f898](fe4f898))
* specify gemm backend
([#648](#648))
([0cc1a51](0cc1a51))
* support cached cos/sin in rope APIs
([#585](#585))
([83e541d](83e541d))
* support huggingface transformer style rope interface
([#568](#568))
([4f40420](4f40420))
* support sm90 cutlass group gemm
([#509](#509))
([794bdda](794bdda))
* torch custom_op fix for rope
([#569](#569))
([3e104bc](3e104bc))
* torch custom_op support: norm
([#552](#552))
([f6e0010](f6e0010))
* torch.compile and custom_op support
([#554](#554))
([9bf916f](9bf916f))
* warmup for jit kernel tests
([#629](#629))
([8f5f349](8f5f349))


### Bug Fixes

* AOT compiler flags on non-sm90
([#522](#522))
([0aa4726](0aa4726))
* batch decode kernel redundant store output to gmem
([#505](#505))
([90e42a7](90e42a7))
* compatible with torch 2.2
([#478](#478))
([ac41d1b](ac41d1b))
* #452
([b53a46f](b53a46f))
* remove redundant load
([#495](#495))
([2de16b0](2de16b0))
* update bmm fp8 test
([#487](#487))
([45eac04](45eac04))


### Performance Improvements

* accelerate JIT compilation speed
([#618](#618))
([eaf73fd](eaf73fd))
* Dense and sparse customizable flashattention-3 template
([#667](#667))
([51236c9](51236c9))
* fix prefill kernel performance degradation (step 1)
([#602](#602))
([595cf60](595cf60))
* fix the performance issue of `append_paged_kv_cache`
([#588](#588))
([e15f7c9](e15f7c9))
* improve parallelism in RoPE with pos_ids
([#609](#609))
([ff05155](ff05155))
* improve plan performance by using non-blocking memcpy
([#547](#547))
([41ebe6d](41ebe6d))
* reduce the read and write of shared memory in the
FusedAddRMSNormKernel
([#592](#592))
([2043ca2](2043ca2))
* reduce total_num_tiles_q by one
([#644](#644))
([553ace5](553ace5))
* remove unnecessary contiguous operation in block sparse attention
([#561](#561))
([7a7ad46](7a7ad46))
* speedup jit compilation of prefill attention kernels
([#632](#632))
([a059586](a059586))
* use cuda-core implementation for io-bound block-sparse attention
([#560](#560))
([3fbf028](3fbf028))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>