commit act quant for conditional ffn #156

Draft · wants to merge 3 commits into base: mlperf-mixtral

Conversation

@qihqi (Collaborator) commented Jul 23, 2024

No description provided.

wang2yn84 and others added 3 commits July 23, 2024 00:23
* Almost working except the mask; need to rebase to main to pick up the ring buffer support, then fix the mask. Int8 updates are also included but not tested.

* Fixed the test_model_impl for llama, but test_llama_e2e is still failing.

* Adds lazy_cache_update and restructures the cache flags.

* Disable all the prints. Fix create engine.

* Fix typos and minor errors.

* Fixes create engine.

* Adds new_cache_stacked and fixes cache update.

* Fix cache update when new_cache_stacked is False.

* Fix the cache manager and make unit tests pass except for 1.

* Updates the exportable model to return cache.

* Removed the fori loop in cache finalize. Moves the cache.finalize() to the end of existing cache attention.

* Try to use shard_map for cache update.

* Fix the single cache line update in cache.finalize().

* Adds int8 support.

* Int8 left aligned lazy cache update working, performance still not good enough.

* Fix the stacked cache introduced in the previous couple of commits.

* Put original ragged attention back.

* Add the original ragged attention kernel.

* Fixes the bf16/int8 cache stack.

* Fix int8 stacked cache insertion in engine and finalization.

* Fixes int8 with lazy cache update.

* Updates the int8 test.

* Fix the int8 ragged attention output sharding.

* Fix group query attention broadcasting issue.

* Fix the shard_map input issue: variables not listed as explicit inputs are frozen into the jit function as compile-time constants (see the shard_map sketch after the commit list).

* Fix the flash attention mask shape; fix the quantized version of the single cache line update.

* Adds the kv cache test.

* Replace quantized cache "pos" with "input_pos" to align with bf16 cache. Fix the kv cache quantization test.

* Fix the prefill cache insertion issue for the stacked cache; change the quantization reduce dims from (1, 3) to (-3, -1) to make them more robust (see the quantization sketch after the commit list).

* Adds lazy cache update with the generate cache stacked and the new cache unstacked, for performance validation.

* Fix the shard map sharding for stacked generate cache and unstacked new cache.

* Use the JAX slicing API instead of PyTorch index slicing (see the slicing sketch after the commit list).

* Adds stacked cache support in ragged attention reference kernel.

* Adds stacked cache support for the modified ragged kernel.

* Llama2 70b int8 optimization done. Output not correct yet.

* Remove testing temp output files.

* Fix the Llama 70B output accuracy issue resulting from GQA.

* Fixes the attention output slicing issue when not using flash attention. Refactor to use only one flash attention kernel. Update the modified ring buffer ragged attention kernel to handle quantization, the layer dimension, etc.

* Fix the pallas kernel OOB issue

* Fix tests; Fix lint issues;

* Fix the interactive script.

* Fix lint errors.

* Fix errors.

* Fix the comments.

* Fix based on comments; Fix all the unit tests.

* Fix the remaining pylint errors.

* Default the ring buffer back to true so that test_run_server and run_interactive can work in CPU mode. When we default the ring buffer to false, we should add additional flags to the run_interactive CI to set test mode to true so that the Pallas kernel can run.

* Fix all the lint errors.

* Remove the deps/JetStream changes.

* Fix merge errors, fix lint errors.
init params

add other scripts

debug accuracy
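The commit list above mentions a shard_map input issue where variables not passed explicitly get frozen into the jit function. A minimal, hypothetical sketch of that behavior (the function names and shapes are illustrative, not from this PR): an array closed over by the traced function is baked in as a compile-time constant, while an explicit argument remains a real, shardable input.

```python
import jax
import jax.numpy as jnp

cache = jnp.zeros((8, 128))

@jax.jit
def update_closed_over(new_line):
    # `cache` is captured from the enclosing scope, so it is traced once and
    # frozen into the compiled function as a constant; later updates to the
    # Python variable `cache` are invisible here.
    return cache.at[0].set(new_line)

@jax.jit
def update_explicit(cache, new_line):
    # Passing the cache explicitly makes it a real input that can be sharded
    # and threaded through between calls.
    return cache.at[0].set(new_line)
```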
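The reduce-dim change for quantization (from (1, 3) to (-3, -1)) can be illustrated with a small sketch. The helper and shapes below are hypothetical, assuming a cache laid out as (batch, heads, seq, head_dim) with an optional leading layer axis when stacked; with negative axes the same code selects the heads and head_dim axes in both layouts.

```python
import jax.numpy as jnp

def quantize_int8(x, reduce_axis=(-3, -1)):
    """Symmetric int8 quantization with scales computed over `reduce_axis`."""
    amax = jnp.max(jnp.abs(x), axis=reduce_axis, keepdims=True)
    scale = jnp.where(amax == 0, 1.0, amax / 127.0)  # avoid div-by-zero on all-zero slices
    q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return q, scale

# Unstacked cache: (batch, heads, seq, head_dim); here (-3, -1) == (1, 3).
k_unstacked = jnp.ones((2, 8, 16, 128))
# Stacked cache: (layer, batch, heads, seq, head_dim); (-3, -1) still selects
# the heads and head_dim axes, whereas (1, 3) would pick batch and seq.
k_stacked = jnp.ones((4, 2, 8, 16, 128))

q1, s1 = quantize_int8(k_unstacked)
q2, s2 = quantize_int8(k_stacked)
```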
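For the switch from PyTorch index slicing to the JAX slicing API, here is a minimal sketch (illustrative shapes and function name) of writing one cache line at a given position with `jax.lax.dynamic_update_slice_in_dim`, which stays jit-friendly because the position can be a traced value.

```python
import jax
import jax.numpy as jnp

def insert_cache_line(cache, new_entry, pos):
    # cache: (batch, heads, seq, head_dim); new_entry: (batch, heads, 1, head_dim).
    return jax.lax.dynamic_update_slice_in_dim(cache, new_entry, pos, axis=2)

cache = jnp.zeros((2, 8, 16, 128))
new_entry = jnp.ones((2, 8, 1, 128))
cache = jax.jit(insert_cache_line)(cache, new_entry, 5)  # `pos` may be traced
```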