diff --git a/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py index e3fbda9b2..639d5300c 100644 --- a/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py +++ b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py @@ -56,14 +56,14 @@ def matmul_rhs_group_quant( raise ValueError("Input shapes not supported.") -brevitas_lib = torch.library.Library("brevitas", "DEF") +brevitas_lib = torch.library.Library("quant", "DEF") brevitas_lib.define( "matmul_rhs_group_quant(Tensor lhs, Tensor rhs, Tensor rhs_scale, Tensor rhs_zero_point, int rhs_bit_width, int rhs_group_size) -> Tensor" ) brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant) -def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: +def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: if len(lhs) == 3 and len(rhs) == 2: return [lhs[0], lhs[1], rhs[0]] elif len(lhs) == 2 and len(rhs) == 2: @@ -72,20 +72,20 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh raise ValueError("Input shapes not supported.") -def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: +def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: # output dtype is the dtype of the lhs float input lhs_rank, lhs_dtype = lhs_rank_dtype return lhs_dtype -def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: +def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: return # yapf: enable brevitas_matmul_rhs_group_quant_library = [ - brevitas〇matmul_rhs_group_quant〡shape, - brevitas〇matmul_rhs_group_quant〡dtype, - brevitas〇matmul_rhs_group_quant〡has_value_semantics] + quant〇matmul_rhs_group_quant〡shape, + quant〇matmul_rhs_group_quant〡dtype, + quant〇matmul_rhs_group_quant〡has_value_semantics] if __name__ == '__main__': @@ -100,7 +100,7 @@ def forward( rhs: torch.Tensor, rhs_scale: torch.Tensor, rhs_zero_point: torch.Tensor): - return torch.ops.brevitas.matmul_rhs_group_quant( + return torch.ops.quant.matmul_rhs_group_quant( lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=8, rhs_group_size=128) mod = CustomOpExampleModule() @@ -109,6 +109,6 @@ def forward( module = torch_mlir.compile( mod, (torch.ones(3, 4), torch.ones(5, 4), torch.ones(1), torch.ones(1)), output_type="torch", - backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + backend_legal_ops=["quant.matmul_rhs_group_quant"], extra_library=brevitas_matmul_rhs_group_quant_library) print(module) diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py index 7f69c2029..a234c86d0 100644 --- a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py +++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py @@ -60,11 +60,11 @@ # Due a tracing issue this annotation needs to be # in the same module (== file) from which make_fx is called -# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant +# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant # and so we trace a placeholder first and then replace it post tracing @wrap(visible_to_make_fx=True) def matmul_rhs_group_quant_placeholder(*args, **kwargs): - return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs) + return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs) class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler): @@ -261,9 +261,7 @@ def transform_fx(fx_g): transform_fx(fx_g) replace_call_fn_target( - fx_g, - src=matmul_rhs_group_quant_placeholder, - target=torch.ops.brevitas.matmul_rhs_group_quant) + fx_g, src=matmul_rhs_group_quant_placeholder, target=torch.ops.quant.matmul_rhs_group_quant) fx_g.recompile() removed_none_indexes = _remove_nones(fx_g) @@ -319,7 +317,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir module = torch_mlir.compile( ts_g, (hidden_states_placeholder, inputs[1], inputs[2]), output_type="torch", - backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + backend_legal_ops=["quant.matmul_rhs_group_quant"], extra_library=brevitas_matmul_rhs_group_quant_library, use_tracing=False, verbose=False) @@ -342,7 +340,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir pkv0_placeholder, pkv1_placeholder), output_type="torch", - backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + backend_legal_ops=["quant.matmul_rhs_group_quant"], extra_library=brevitas_matmul_rhs_group_quant_library, use_tracing=False, verbose=False) diff --git a/src/brevitas_examples/llm/test_linear_mlir_export.py b/src/brevitas_examples/llm/test_linear_mlir_export.py index 417721a58..95b3dc127 100644 --- a/src/brevitas_examples/llm/test_linear_mlir_export.py +++ b/src/brevitas_examples/llm/test_linear_mlir_export.py @@ -18,11 +18,11 @@ # Due a tracing issue this annotation needs to be # in the same module (== file) from which make_fx is called -# We also can't directly annotate torch.ops.brevitas.matmul_rhs_group_quant +# We also can't directly annotate torch.ops.quant.matmul_rhs_group_quant # and so we trace a placeholder first and then replace it post tracing @wrap(visible_to_make_fx=True) def matmul_rhs_group_quant_placeholder(*args, **kwargs): - return torch.ops.brevitas.matmul_rhs_group_quant(*args, **kwargs) + return torch.ops.quant.matmul_rhs_group_quant(*args, **kwargs) class LinearWeightBlockQuantHandlerFwd(LinearWeightBlockQuantHandler): @@ -84,7 +84,7 @@ def quantize_and_export(args): replace_call_fn_target( traced_model, src=matmul_rhs_group_quant_placeholder, - target=torch.ops.brevitas.matmul_rhs_group_quant) + target=torch.ops.quant.matmul_rhs_group_quant) # print the output graph print(traced_model.graph) @@ -93,7 +93,7 @@ def quantize_and_export(args): traced_model, torch.randn(2, 128), output_type="torch", - backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + backend_legal_ops=["quant.matmul_rhs_group_quant"], extra_library=brevitas_matmul_rhs_group_quant_library, use_tracing=True, verbose=False)