
Commit e3ef2c7

Vysarat authored and facebook-github-bot committed
Update Executorch ops registration for rms_norm
Summary: Allows use of `aten::rms_norm` in place of the `cadence::rms_norm` custom op, which is no longer needed.

Differential Revision: D72485973
1 parent 56c8dc2 commit e3ef2c7
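Since `aten::rms_norm` is the standard PyTorch op (exposed as `torch.nn.functional.rms_norm`), callers can drop the Cadence-specific custom op. A minimal migration sketch, assuming example shapes and an eps value that are not taken from this commit:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)   # example (batch, hidden) input
    weight = torch.ones(8)  # example learnable scale

    # Before: lowered graphs called the custom op, roughly
    #   y = torch.ops.cadence.rms_norm(x, 1e-6, weight)
    # After: the standard aten op is used instead
    y = F.rms_norm(x, normalized_shape=(8,), weight=weight, eps=1e-6)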

File tree

1 file changed: +9 -9 lines changed


backends/cadence/aot/ops_registrations.py (+9 -9)
@@ -293,6 +293,15 @@
     "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+# Custom ops in aten namespace. RMSNorm is usually decomposed, so having
+# an out-variant is non-standard
+
+lib_aten = Library("aten", "FRAGMENT")
+
+lib_aten.define(
+    "rms_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 
 @register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
@@ -619,15 +628,6 @@ def linalg_vector_norm_meta(
     return X.new_empty([], dtype=X.dtype)
 
 
-@register_fake("cadence::rms_norm")
-def rms_norm_meta(
-    X: torch.Tensor,
-    eps: float,
-    weight: torch.Tensor,
-) -> torch.Tensor:
-    return X.new_empty(X.shape, dtype=X.dtype)
-
-
 @register_fake("cadence::requantize")
 def requantize_meta(
     input: torch.Tensor,
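Note that the hunk above only defines the out-variant schema in the `aten` namespace; the actual Cadence kernel is registered elsewhere. As a reference for the semantics behind that schema, here is a minimal sketch of RMSNorm with an out parameter, following the standard PyTorch definition (illustrative only, not code from this commit):

    import torch

    def rms_norm_out_reference(input, normalized_shape, weight=None, eps=None, *, out):
        # Normalize over the trailing `normalized_shape` dimensions.
        dims = tuple(range(input.dim() - len(normalized_shape), input.dim()))
        if eps is None:
            # PyTorch's documented default for rms_norm
            eps = torch.finfo(input.dtype).eps
        rms = torch.sqrt(input.pow(2).mean(dim=dims, keepdim=True) + eps)
        result = input / rms
        if weight is not None:
            result = result * weight
        # Out-variants write the result into the caller-provided buffer.
        return out.copy_(result)

For example, `rms_norm_out_reference(torch.randn(2, 8), (8,), out=torch.empty(2, 8))` fills the preallocated buffer, which is the usual reason embedded backends such as Cadence prefer out-variants.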
