Naive Run Compressed Pt. 2 #62

Merged (14 commits merged into main, Aug 30, 2024)

Conversation


@Satrat Satrat commented Aug 6, 2024

SUMMARY:
Follow-up PR to neuralmagic/compressed-tensors#109; enables loading compressed models into SparseAutoModel. Each quantized layer is decompressed on the forward pass (see the sketch after the list below).

  • Adds a run_compressed argument to SparseAutoModel
  • Removes structure initialization in QuantizationModifier; it is no longer needed since this now happens on model load
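
The decompress-on-forward behavior can be pictured with a short sketch. The PackedWeight container and DecompressOnForwardLinear wrapper below are hypothetical stand-ins for whatever compressed-tensors stores for a quantized layer; this is not the real API, only an illustration of the pattern.

from dataclasses import dataclass
import torch

@dataclass
class PackedWeight:
    values: torch.Tensor   # quantized values, e.g. int8 / fp8
    scale: torch.Tensor    # per-channel or per-tensor quantization scale

class DecompressOnForwardLinear(torch.nn.Module):
    """Keeps the weight in compressed form and only materializes a dense
    fp16 weight for the duration of each forward call."""

    def __init__(self, packed: PackedWeight, bias: torch.Tensor | None = None):
        super().__init__()
        self.packed = packed
        self.bias = bias

    def decompress(self) -> torch.Tensor:
        # Reconstruct a dense fp16 weight from the packed values and scales
        return self.packed.values.to(torch.float16) * self.packed.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.decompress()  # dense copy lives only for this call
        return torch.nn.functional.linear(x, weight, self.bias)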

TEST PLAN:
Manual example below; integration tests will follow once the compressed-tensors branch merges.

from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM
import torch

model_dir = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-compressed"
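# run_compressed=True keeps the quantized weights in their compressed format;
# each layer is decompressed on the forward pass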
model = SparseAutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="auto", run_compressed=True)

tokenizer = AutoTokenizer.from_pretrained(model_dir)
sample_input = ["I love 8 bit quantization because"]
inputs = tokenizer(sample_input, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_length=50)
outputs = tokenizer.batch_decode(generated_ids)
print(outputs)

Output: ["<|begin_of_text|>I love 8 bit quantization because it's a great way to reduce the precision of floating point numbers and make them more compact. It's also a great way to make them more robust against noise and quantization errors.\n\nHere's a simple example of how you can implement 8 bit quantization in Python:\n```\nimport numpy as np\n\ndef quantize(x, bits=8):\n x_min = np.min(x)\n x_max = np.max(x)\n scale = "]

Runtime: ~26 sec on two A4000s, vs. ~65 sec for the non-compressed version

@Satrat Satrat changed the title [Don't Merge Yet] Naive Run Compressed Pt. 2 Naive Run Compressed Pt. 2 Aug 7, 2024
@Satrat Satrat marked this pull request as ready for review August 12, 2024 15:53
bfineran previously approved these changes Aug 19, 2024
@Satrat Satrat merged commit 8187914 into main Aug 30, 2024
7 checks passed
markmc pushed a commit to markmc/llm-compressor that referenced this pull request Nov 13, 2024
* small fixes

* initial commit

* bug fixes

* cleanup

* clarity comments

* clean up compression classes

* fixing zero point issues

* comment for hack

* update quant check

* cleanup fp8 dtypes

* cleanup

* clean up observer

* dtype fix

* docstrings

* fixes after rebase

* test fixes

* style

* get rid of broken segment

* fix broken code