Add fp4 support #3532
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py 2025-05-25 17:51:42.835275+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py 2025-05-25 17:52:07.703670+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+
def addmm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-05-25 17:51:42.834275+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-05-25 17:52:08.266101+00:00
@@ -272,17 +272,23 @@
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
self.compilation_settings.dla_global_dram_size,
)
- if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions:
+ if (
+ not self.compilation_settings.use_explicit_typing
+ and dtype.float16 in self.compilation_settings.enabled_precisions
+ ):
builder_config.set_flag(trt.BuilderFlag.FP16)
if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)
- if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions:
+ if (
+ not self.compilation_settings.use_explicit_typing
+ and dtype.fp8 in self.compilation_settings.enabled_precisions
+ ):
builder_config.set_flag(trt.BuilderFlag.FP8)
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py 2025-05-25 17:51:42.836275+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py 2025-05-25 17:52:08.286663+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os
+
def permute(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2025-05-25 17:51:42.863275+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2025-05-25 17:52:13.684130+00:00
@@ -13,10 +13,11 @@
from packaging.version import Version
assertions = unittest.TestCase()
import os
+
@pytest.mark.unit
def test_resnet18(ir):
model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -208,10 +209,11 @@
)
@pytest.mark.unit
def test_base_fp4(ir):
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode
+
dtype = torch.float16
class SimpleNetwork(torch.nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
@@ -227,21 +229,20 @@
"""Simple calibration function for testing."""
model(input_tensor)
input_tensor = torch.ones(128, 64, dtype=dtype).cuda()
-
model = SimpleNetwork().eval().cuda()
model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
print(f"lan added amax: {input_tensor.abs().amax()=}")
print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
expected_output = model(input_tensor)
- print(f"lan added model input: {input_tensor=}")
+ print(f"lan added model input: {input_tensor=}")
print(f"lan added model weight: {model.linear1.weight=}")
print(f"lan added model bias: {model.linear1.bias=}")
-
+
quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has qdq nodes at this point
with torch.no_grad():
with export_torch_mode():
@@ -269,15 +270,21 @@
print("lan added disable_gemm is set, compring result with weights")
expected_output = model.linear1.weight
else:
print("lan added disable_gemm is not set, compring result with pytorch")
- print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}")
- print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}")
+ print(
+ f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+ )
+ print(
+ f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+ )
abs_diff = torch.abs(expected_output - outputs_trt)
- print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+ print(
+ f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+ )
print(f"lan added abs_diff: {abs_diff=}")
assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)
@unittest.skipIf(
with unset_fake_temporarily():
    axis = -1
    global_scale = _calculate_global_scale(ctx, name, amax)
    if ".weight_quantizer" in name:
Is this really the way we determine which quantization scheme to use? What metadata is associated with these nodes?
Below is all the information I can get from the aten ops; I don't see a better way to tell whether this is the weight or the input:
@dynamo_tensorrt_converter(
torch.ops.tensorrt.dynamic_block_quantize_op.default,
supports_dynamic_shapes=True,
)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
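As a rough sketch of what the name-based dispatch could look like inside that converter body (the ".weight_quantizer" check mirrors the code quoted above; _quantize_fp4_weight and _quantize_fp4_activation are hypothetical helper names used for illustration, not existing torch_tensorrt functions):
    # Dispatch on the quantizer name that modelopt embeds in the FX node name.
    # Assumption: weight quantizers carry ".weight_quantizer" in their name and
    # everything else is an input/activation quantizer.
    if ".weight_quantizer" in name:
        # Weights are static, so the global scale and block scales can be
        # derived from amax at engine build time.
        return _quantize_fp4_weight(ctx, target, SourceIR.ATEN, name, *args)
    else:
        # Activations need the scales applied to the runtime tensor instead.
        return _quantize_fp4_activation(ctx, target, SourceIR.ATEN, name, *args)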
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add an end-to-end example for FP4 using modelopt and tensorrt.
There is a simple linear end-to-end FP4 example in this PR; a condensed sketch of the flow is below.
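For readers who do not want to dig through the test file, here is a condensed sketch adapted from test_base_fp4 in this PR. The compile arguments (in particular use_explicit_typing and the input handling) are assumptions about how FP4 is driven, not the definitive API usage:
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.float16

# Toy linear model and input, matching the shapes used in test_base_fp4.
model = torch.nn.Sequential(torch.nn.Linear(64, 32)).to(dtype).eval().cuda()
input_tensor = torch.ones(128, 64, dtype=dtype).cuda()

def calibrate_loop(m):
    # A single forward pass is enough calibration for this toy example.
    m(input_tensor)

# Insert Q/DQ nodes using modelopt's default NVFP4 recipe.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=calibrate_loop)

with torch.no_grad(), export_torch_mode():
    exp_program = torch.export.export(model, (input_tensor,))
    # Assumption: explicit typing lets the Q/DQ nodes drive FP4 kernel selection.
    trt_model = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=[input_tensor],
        use_explicit_typing=True,
        min_block_size=1,
    )
    outputs_trt = trt_model(input_tensor)
    print(torch.max(torch.abs(model(input_tensor) - outputs_trt)))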
Description
Add fp4 support
Fixes # (issue)