[issue tracker] make quantization compatible with dynamo dynamic shape #9234

Closed
youkaichao opened this issue Oct 10, 2024 · 5 comments · Fixed by #9299

@youkaichao (Member)

Anything you want to discuss about vllm.

Here is some simple demo code:

import torch
from torch.utils.cpp_extension import load_inline

# declare the schema for an op whose implementation is registered from C++
custom_library = torch.library.Library("custom", "DEF")
custom_library.define("add_cpp(Tensor x, int y) -> Tensor")

cpp_source = """
#include <torch/extension.h>

torch::Tensor custom_add(torch::Tensor x, int64_t y) {
    return x + y;
}

TORCH_LIBRARY_IMPL(custom, CPU, m) {
    m.impl("add_cpp", custom_add);
}
"""

custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

# fake (meta) implementation for the C++-registered op
@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

# the same op, registered entirely from the Python side
@torch.library.custom_op("custom::add_py", mutates_args=[])
def add_py(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y

@add_py.register_fake
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # return torch.ops.custom.add_py(x, x.shape[0]) # passes
    return torch.ops.custom.add_cpp(x, x.shape[0]) # errors with `Not all values of RelaxedUnspecConstraint(L['x'].size()[0]) are valid because L['x'].size()[0] was inferred to be a constant (2).`

x = torch.ones(2, 4)
torch._dynamo.mark_dynamic(x, 0)
print(f(x)[0])

When we register the custom op from the C++ side, the dynamic shape is specialized directly to a concrete integer and compilation fails.
When we register the custom op from the Python side, dynamic shapes work as expected.

We should change the way we register quantization custom ops, moving from the C++ side to the Python side.
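A rough sketch of what that could look like for a quantization kernel (the op name `my_quant_gemm` and its signature are hypothetical placeholders, not actual vLLM ops), assuming the C++ kernel stays as-is and only the registration moves to Python:

import torch

# hypothetical Python-side registration wrapping an existing C++ kernel,
# so dynamo traces a Python custom op instead of the raw C++ schema
@torch.library.custom_op("vllm::my_quant_gemm", mutates_args=[])
def my_quant_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                  size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    # dispatch to the underlying C++ op (placeholder name)
    return torch.ops._C.my_quant_gemm(a, b_q_weight, size_m, size_n, size_k)

@my_quant_gemm.register_fake
def _(a, b_q_weight, size_m, size_n, size_k):
    # shape- and dtype-only implementation used during tracing
    return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)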

There is also one complicated object, `ScalarType` (see `class scalar_types` in vLLM), that appears as a custom op parameter:

vllm/vllm/_custom_ops.py, lines 315 to 321 (at f3a507f):

@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

We can use strings to represent the type in the op schema, and look up the actual `ScalarType` object before passing it into the C++ function.
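A minimal sketch of that lookup, assuming a `scalar_types` registry like the one in vllm/scalar_type.py (the wrapper name, op namespace, and the exact mapping below are illustrative, not vLLM's actual code):

import torch
from vllm.scalar_type import ScalarType, scalar_types  # assumed import path

# hypothetical string -> ScalarType lookup table
SCALAR_TYPE_REGISTRY = {
    "uint4b8": scalar_types.uint4b8,
    "uint8b128": scalar_types.uint8b128,
}

@torch.library.custom_op("vllm::gptq_marlin_24_gemm", mutates_args=[])
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor,
                        b_q_type_str: str, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    # the schema only sees a plain string; the real ScalarType object is
    # looked up here and handed to the underlying C++ kernel
    b_q_type = SCALAR_TYPE_REGISTRY[b_q_type_str]
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)

@gptq_marlin_24_gemm.register_fake
def _(a, b_q_weight, b_meta, b_scales, workspace,
      b_q_type_str, size_m, size_n, size_k):
    return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)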

@youkaichao (Member, Author)

cc @bnellnm

@youkaichao (Member, Author)

There's a related issue in PyTorch, pytorch/pytorch#112883, and the comments there suggest that PyTorch will not fix it in the near future.

I tested with a PyTorch nightly (2.6.0.dev20241004), and it still has this problem.

@bnellnm (Contributor) commented Oct 10, 2024

I was able to work around the problem by modifying the schemas to take SymInts. I'll look into the scalar_type issue.

import torch
from torch.utils.cpp_extension import load_inline

custom_library = torch.library.Library("custom", "DEF")
custom_library.define("add_cpp(Tensor x, SymInt y) -> Tensor")

cpp_source = """                                                                                                                                    
#include <torch/extension.h>                                                                                                                        
                                                                                                                                                    
torch::Tensor custom_add(torch::Tensor x, int64_t y) {                                                                                              
    return x + y;                                                                                                                                   
}                                                                                                                                                   
                                                                                                                                                    
TORCH_LIBRARY_IMPL(custom, CPU, m) {                                                                                                                
    m.impl("add_cpp", custom_add);                                                                                                                  
}                                                                                                                                                   
"""

custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: torch.SymInt) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

import torch

@torch.library.custom_op("custom::add_py", mutates_args=[])
def add_py(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y

@add_py.register_fake
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # return torch.ops.custom.add_py(x, x.shape[0]) # passes
    return torch.ops.custom.add_cpp(x, x.shape[0]) # now also passes with the SymInt schema

x = torch.ones(2, 4)
torch._dynamo.mark_dynamic(x, 0)
print(f(x)[0])

@bnellnm (Contributor) commented Oct 10, 2024

This also works.

import torch
from torch.utils.cpp_extension import load_inline

custom_library = torch.library.Library("custom", "DEF")

cpp_source = """                                                                                                                                    
#include <torch/extension.h>                                                                                                                        
                                                                                                                                                    
torch::Tensor custom_add(torch::Tensor x, int64_t y) {                                                                                              
    return x + y;                                                                                                                                   
}                                                                                                                                                   
                                                                                                                                                    
TORCH_LIBRARY_FRAGMENT(custom, m)                                                                                                                   
{                                                                                                                                                   
    m.def("add_cpp(Tensor x, SymInt y) -> Tensor");                                                                                                 
    m.impl("add_cpp", torch::kCPU, custom_add);                                                                                                     
}                                                                                                                                                   
"""
            
custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: torch.SymInt) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

I think the ScalarType problem is orthogonal to the SymInt problem.

@youkaichao (Member, Author)

> I think the ScalarType problem is orthogonal to the SymInt problem.

Yes, they are two separate problems. For dynamo dynamic shapes to understand quantization ops, both problems need to be solved.
