diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index dec6feb1b8..368f499aa8 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -293,6 +293,15 @@
     "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+# Custom ops in aten namespace. RMSNorm is usually decomposed, so having
+# an out-variant is non-standard
+
+lib_aten = Library("aten", "FRAGMENT")
+
+lib_aten.define(
+    "rms_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 
 @register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
@@ -619,15 +628,6 @@ def linalg_vector_norm_meta(
     return X.new_empty([], dtype=X.dtype)
 
 
-@register_fake("cadence::rms_norm")
-def rms_norm_meta(
-    X: torch.Tensor,
-    eps: float,
-    weight: torch.Tensor,
-) -> torch.Tensor:
-    return X.new_empty(X.shape, dtype=X.dtype)
-
-
 @register_fake("cadence::requantize")
 def requantize_meta(
     input: torch.Tensor,
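
A minimal sketch (not part of this diff, and not the Cadence kernel) of how the newly defined `aten::rms_norm.out` overload could be exercised in eager mode. It assumes a PyTorch build that ships the functional `aten::rms_norm` (2.4+) but does not already register an `out` overload; the `rms_norm_out` reference implementation and the `CompositeExplicitAutograd` registration below are illustrative only.

```python
# Hypothetical sketch: define the out-variant schema in an "aten" FRAGMENT
# library (mirroring the diff) and back it with a simple reference kernel
# that delegates to the functional rms_norm and copies into the out buffer.
import torch
import torch.nn.functional as F
from torch.library import Library

lib_aten = Library("aten", "FRAGMENT")
lib_aten.define(
    "rms_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, "
    "float? eps=None, *, Tensor(a!) out) -> Tensor(a!)"
)


def rms_norm_out(input, normalized_shape, weight=None, eps=None, *, out):
    # Compute the functional result, then write it into the caller-provided buffer.
    result = F.rms_norm(input, normalized_shape, weight, eps)
    return out.copy_(result)


# Register the reference kernel for all backends via CompositeExplicitAutograd
# (illustrative only; the real backend supplies its own out-variant kernel).
lib_aten.impl("rms_norm.out", rms_norm_out, "CompositeExplicitAutograd")

x = torch.randn(2, 8)
buf = torch.empty_like(x)
torch.ops.aten.rms_norm.out(x, [8], out=buf)
```

Delegating to `F.rms_norm` keeps the sketch self-contained; it also shows why no separate `register_fake` is needed here, since a composite kernel decomposes under fake tensors (which is consistent with removing the old `cadence::rms_norm` meta function).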