
INT8 quantization support #45

Open
casper-hansen opened this issue Sep 11, 2023 · 3 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@casper-hansen (Owner) commented Sep 11, 2023

The motivation for INT8 is to preserve more accuracy while still gaining some inference speed. I experimented with implementing dequantization for INT8, but it ultimately needs more work before it will be usable.

Edit: Implement SmoothQuant instead. Here is a fork of SmoothQuant that supports LLaMA models; integrate it into AutoAWQ: https://github.com/AniZpZ/smoothquant/tree/llama-dev

__device__ uint2 dequantize_s8_to_fp16x2(uint32_t const& source)
{
    // Adapted from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L54
    // The four int8 values packed in `source` become four fp16 values,
    // returned as two fp16x2 registers (uint2).
    uint2 result;

    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

    // Casper: Original was 0x64646464 = {1124, 1124}
    // Optimize to 0x64806480 = {1152, 1152} because 1152 is divisible by 8, 16, 32, 64, 128
    // NOTE: Test out {1280, 1280} since it's also divisible by 256
    static constexpr uint32_t mask_for_elt_01     = 0x5250;
    static constexpr uint32_t mask_for_elt_23     = 0x5351;
    static constexpr uint32_t start_byte_for_fp16 = 0x64806480;
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

    // Lastly, subtract 1152 from each constructed half using fp16 math to recover the signed integer as fp16.
    // Casper: 0x64806480 = {1152, 1152}
    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

    return result;
}
@casper-hansen casper-hansen added enhancement New feature or request help wanted Extra attention is needed labels Sep 11, 2023
@casper-hansen casper-hansen mentioned this issue Sep 11, 2023
@yunfeng-scale commented:

How would you compare this with 8-bit bitsandbytes? I think bitsandbytes has minimal performance loss.

@casper-hansen (Owner, Author) commented:

> How would you compare this with 8-bit bitsandbytes? I think bitsandbytes has minimal performance loss.

It is not implemented yet, so I cannot speak to it.

@casper-hansen (Owner, Author) commented:

#71 is working on INT8 support; there are still things left to implement.
