diff --git a/comfy/ops.py b/comfy/ops.py index 8e06942320d..06be6b48b50 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -255,9 +255,10 @@ def fp8_linear(self, input): tensor_2d = True input = input.unsqueeze(1) - + input_shape = input.shape + input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) + w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) w = w.t() scale_weight = self.scale_weight @@ -269,23 +270,24 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype) else: scale_input = scale_input.to(input.device) - inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) + input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) if bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) if isinstance(o, tuple): o = o[0] if tensor_2d: - return o.reshape(input.shape[0], -1) + return o.reshape(input_shape[0], -1) - return o.reshape((-1, input.shape[1], self.weight.shape[0])) + return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None