Add initial f16 and f128 support to the x64 backend #9045

Merged · 1 commit · Jul 30, 2024
30 changes: 28 additions & 2 deletions cranelift/codegen/src/isa/x64/abi.rs
@@ -123,9 +123,12 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// extension annotations. Additionally, handling extension attributes this way allows clif
// functions that use them with the Winch calling convention to interact successfully with
// testing infrastructure.
+ // The results are also not packed if any of the types are `f16`. This is to simplify the
+ // implementation of `Inst::load`/`Inst::store` (which would otherwise require multiple
+ // instructions), and doesn't affect Winch itself as Winch doesn't support `f16` at all.
let uses_extension = params
.iter()
- .any(|p| p.extension != ir::ArgumentExtension::None);
+ .any(|p| p.extension != ir::ArgumentExtension::None || p.value_type == types::F16);

for (ix, param) in params.iter().enumerate() {
let last_param = ix == params.len() - 1;
@@ -169,13 +172,23 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// https://godbolt.org/z/PhG3ob

if param.value_type.bits() > 64
- && !param.value_type.is_vector()
+ && !(param.value_type.is_vector() || param.value_type.is_float())
&& !flags.enable_llvm_abi_extensions()
{
panic!(
"i128 args/return values not supported unless LLVM ABI extensions are enabled"
);
}
+ // As MSVC doesn't support f16/f128, there is no standard way to pass/return them with
+ // the Windows ABI. LLVM passes/returns them in XMM registers.
+ if matches!(param.value_type, types::F16 | types::F128)
+ && is_fastcall
+ && !flags.enable_llvm_abi_extensions()
+ {
+ panic!(
+ "f16/f128 args/return values not supported for windows_fastcall unless LLVM ABI extensions are enabled"
+ );
+ }

// Windows fastcall dictates that `__m128i` parameters to a function
// are passed indirectly as pointers, so handle that as a special
@@ -410,12 +423,20 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// bits as well -- see `Inst::store()`).
let ty = match ty {
types::I8 | types::I16 | types::I32 => types::I64,
+ // Stack slots are always at least 8 bytes, so it's fine to load 4 bytes instead of only
+ // two.
+ types::F16 => types::F32,
_ => ty,
};
Inst::load(ty, mem, into_reg, ExtKind::None)
}

fn gen_store_stack(mem: StackAMode, from_reg: Reg, ty: Type) -> Self::I {
+ let ty = match ty {
+ // See `gen_load_stack`.
+ types::F16 => types::F32,
+ _ => ty,
+ };
Inst::store(ty, from_reg, mem)
}
@@ -502,6 +523,11 @@ impl ABIMachineSpec for X64ABIMachineSpec {
}

fn gen_store_base_offset(base: Reg, offset: i32, from_reg: Reg, ty: Type) -> Self::I {
+ let ty = match ty {
+ // See `gen_load_stack`.
+ types::F16 => types::F32,
+ _ => ty,
+ };
let mem = Amode::imm_reg(offset, base);
Inst::store(ty, from_reg, mem)
}
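The stack-access widening above relies on a size invariant, not on any numeric conversion. Below is a minimal standalone sketch in plain Rust (illustrative only, not Cranelift code) of why loading 4 bytes from a slot holding an `f16` is a pure bit copy: stack slots are at least 8 bytes, so the wider read stays in bounds, and the payload sits entirely in the low 16 bits.

```rust
// Illustrative sketch: models the widened f16 stack access as byte copies.
fn main() {
    // An 8-byte stack slot whose low two bytes hold an f16 payload
    // (0x3E00 is 1.5 in IEEE 754 binary16).
    let mut slot = [0u8; 8];
    let f16_bits: u16 = 0x3E00;
    slot[..2].copy_from_slice(&f16_bits.to_le_bytes());

    // The "f32-sized" load reads 4 bytes from the slot...
    let widened = u32::from_le_bytes(slot[..4].try_into().unwrap());

    // ...and truncating back to 16 bits recovers the payload bit-for-bit;
    // no f16 -> f32 numeric conversion happens anywhere.
    assert_eq!(widened as u16, f16_bits);
    println!("widened: {widened:#010x}, f16 bits: {:#06x}", widened as u16);
}
```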
38 changes: 31 additions & 7 deletions cranelift/codegen/src/isa/x64/inst.isle
@@ -1644,7 +1644,7 @@
(rule (put_in_gpr val)
(if-let (value_type ty) val)
(if-let (type_register_class (RegisterClass.Xmm)) ty)
- (bitcast_xmm_to_gpr ty (xmm_new (put_in_reg val))))
+ (bitcast_xmm_to_gpr (ty_bits ty) (xmm_new (put_in_reg val))))

;; Put a value into a `GprMem`.
;;
@@ -2252,8 +2252,10 @@

;; Performs an xor operation of the two operands specified.
(decl x64_xor_vector (Type Xmm XmmMem) Xmm)
+ (rule 1 (x64_xor_vector $F16 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F32 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F64 x y) (x64_xorpd x y))
+ (rule 1 (x64_xor_vector $F128 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F32X4 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F64X2 x y) (x64_xorpd x y))
(rule 0 (x64_xor_vector (multi_lane _ _) x y) (x64_pxor x y))
@@ -2304,6 +2306,9 @@
(rule 2 (x64_load $F64 addr _ext_kind)
(x64_movsd_load addr))

+ (rule 2 (x64_load $F128 addr _ext_kind)
+ (x64_movdqu_load addr))

(rule 2 (x64_load $F32X4 addr _ext_kind)
(x64_movups_load addr))
@@ -2719,6 +2724,10 @@
(_ Unit (emit (MInst.Imm size simm64 dst))))
dst))

+ ;; `f16` immediates.
+ (rule 2 (imm $F16 (u64_nonzero bits))
+ (bitcast_gpr_to_xmm 16 (imm $I16 bits)))

;; `f32` immediates.
(rule 2 (imm $F32 (u64_nonzero bits))
(x64_movd_to_xmm (imm $I32 bits)))
@@ -2746,6 +2755,9 @@
(rule 0 (imm ty @ (multi_lane _bits _lanes) 0)
(xmm_to_reg (xmm_zero ty)))

+ ;; Special case for `f16` zero immediates
+ (rule 2 (imm ty @ $F16 (u64_zero)) (xmm_zero ty))

;; Special case for `f32` zero immediates
(rule 2 (imm ty @ $F32 (u64_zero)) (xmm_zero ty))
@@ -5022,18 +5034,30 @@

;;;; Casting ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

- (decl bitcast_xmm_to_gpr (Type Xmm) Gpr)
- (rule (bitcast_xmm_to_gpr $F32 src)
+ (decl bitcast_xmm_to_gpr (u8 Xmm) Gpr)
+ (rule (bitcast_xmm_to_gpr 16 src)
+ (x64_pextrw src 0))
+ (rule (bitcast_xmm_to_gpr 32 src)
(x64_movd_to_gpr src))
- (rule (bitcast_xmm_to_gpr $F64 src)
+ (rule (bitcast_xmm_to_gpr 64 src)
(x64_movq_to_gpr src))

- (decl bitcast_gpr_to_xmm (Type Gpr) Xmm)
- (rule (bitcast_gpr_to_xmm $I32 src)
+ (decl bitcast_xmm_to_gprs (Xmm) ValueRegs)
+ (rule (bitcast_xmm_to_gprs src)
+ (value_regs (x64_movq_to_gpr src) (x64_movq_to_gpr (x64_pshufd src 0b11101110))))

+ (decl bitcast_gpr_to_xmm (u8 Gpr) Xmm)
+ (rule (bitcast_gpr_to_xmm 16 src)
+ (x64_pinsrw (xmm_uninit_value) src 0))
+ (rule (bitcast_gpr_to_xmm 32 src)
(x64_movd_to_xmm src))
- (rule (bitcast_gpr_to_xmm $I64 src)
+ (rule (bitcast_gpr_to_xmm 64 src)
(x64_movq_to_xmm src))

+ (decl bitcast_gprs_to_xmm (ValueRegs) Xmm)
+ (rule (bitcast_gprs_to_xmm src)
+ (x64_punpcklqdq (x64_movq_to_xmm (value_regs_get_gpr src 0)) (x64_movq_to_xmm (value_regs_get_gpr src 1))))

;;;; Stack Addresses ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(decl stack_addr_impl (StackSlot Offset32) Gpr)
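The new 128-bit helpers at the end of this hunk carry most of the `f128` support: `bitcast_xmm_to_gprs` splits an XMM value into two 64-bit GPRs (`movq` for the low half, then `pshufd` with immediate 0b11101110 to shuffle the upper quadword down followed by another `movq`), and `bitcast_gprs_to_xmm` rebuilds the value with `punpcklqdq`. A rough model of the bit-level effect in plain Rust (illustrative only; the real lowering operates on registers, not integers):

```rust
// Illustrative model of the f128 <-> two-GPR bitcast helpers above.
fn xmm_to_gprs(xmm: u128) -> (u64, u64) {
    let lo = xmm as u64;         // movq: moves the low 64 bits to a GPR
    let hi = (xmm >> 64) as u64; // pshufd 0b11101110 brings the upper
                                 // quadword down, then movq moves it out
    (lo, hi)
}

fn gprs_to_xmm(lo: u64, hi: u64) -> u128 {
    // punpcklqdq interleaves the two low quadwords: the result is hi:lo.
    (lo as u128) | ((hi as u128) << 64)
}

fn main() {
    let x: u128 = 0x0123_4567_89AB_CDEF_0011_2233_4455_6677;
    let (lo, hi) = xmm_to_gprs(x);
    assert_eq!(gprs_to_xmm(lo, hi), x);
}
```

The 16-bit paths take a different route for the same reason `bitcast_xmm_to_gpr` now keys on bit width rather than type: there is no 16-bit `movd`, so `pextrw`/`pinsrw` move an `f16` payload through the low word of an XMM register.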
3 changes: 2 additions & 1 deletion cranelift/codegen/src/isa/x64/inst/emit.rs
@@ -1428,10 +1428,11 @@ pub(crate) fn emit(
let op = match *ty {
types::F64 => SseOpcode::Movsd,
types::F32 => SseOpcode::Movsd,
+ types::F16 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movaps,
types::F64X2 => SseOpcode::Movapd,
ty => {
- debug_assert!(ty.is_vector() && ty.bytes() == 16);
+ debug_assert!((ty.is_float() || ty.is_vector()) && ty.bytes() == 16);
SseOpcode::Movdqa
}
};
13 changes: 9 additions & 4 deletions cranelift/codegen/src/isa/x64/inst/mod.rs
@@ -630,11 +630,12 @@ impl Inst {
}
RegClass::Float => {
let opcode = match ty {
+ types::F16 => panic!("loading an f16 requires multiple instructions"),
types::F32 => SseOpcode::Movss,
types::F64 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movups,
types::F64X2 => SseOpcode::Movupd,
- _ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
+ _ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
_ => unimplemented!("unable to load type: {}", ty),
};
Inst::xmm_unary_rm_r(opcode, RegMem::mem(from_addr), to_reg)
@@ -650,11 +651,12 @@
RegClass::Int => Inst::mov_r_m(OperandSize::from_ty(ty), from_reg, to_addr),
RegClass::Float => {
let opcode = match ty {
+ types::F16 => panic!("storing an f16 requires multiple instructions"),
types::F32 => SseOpcode::Movss,
types::F64 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movups,
types::F64X2 => SseOpcode::Movupd,
- _ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
+ _ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
_ => unimplemented!("unable to store type: {}", ty),
};
Inst::xmm_mov_r_m(opcode, from_reg, to_addr)
@@ -1621,6 +1623,7 @@ impl PrettyPrint for Inst {
let suffix = match *ty {
types::F64 => "sd",
types::F32 => "ss",
+ types::F16 => "ss",
types::F32X4 => "aps",
types::F64X2 => "apd",
_ => "dqa",
@@ -2605,9 +2608,9 @@ impl MachInst for Inst {
// those, which may write more lanes than we need, but are specified to have
// zero latency.
let opcode = match ty {
- types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
+ types::F16 | types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
types::F64X2 => SseOpcode::Movapd,
- _ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqa,
+ _ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqa,
_ => unimplemented!("unable to move type: {}", ty),
};
Inst::xmm_unary_rm_r(opcode, RegMem::reg(src_reg), dst_reg)
@@ -2628,8 +2631,10 @@
types::I64 => Ok((&[RegClass::Int], &[types::I64])),
types::R32 => panic!("32-bit reftype pointer should never be seen on x86-64"),
types::R64 => Ok((&[RegClass::Int], &[types::R64])),
+ types::F16 => Ok((&[RegClass::Float], &[types::F16])),
types::F32 => Ok((&[RegClass::Float], &[types::F32])),
types::F64 => Ok((&[RegClass::Float], &[types::F64])),
+ types::F128 => Ok((&[RegClass::Float], &[types::F128])),
types::I128 => Ok((&[RegClass::Int, RegClass::Int], &[types::I64, types::I64])),
_ if ty.is_vector() => {
assert!(ty.bits() <= 128);
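The `rc_for_type` additions in the last hunk encode the register-class story for the new types: scalar `f16` and `f128` each occupy a single XMM (`Float`) register, exactly like `f32`/`f64`, while `i128` remains the only scalar split across a register pair. A condensed sketch of that mapping (names are illustrative, not the actual Cranelift API):

```rust
// Condensed, illustrative mapping of scalar types to register classes.
#[derive(Debug, PartialEq)]
enum RegClass {
    Int,
    Float,
}

fn reg_classes(ty: &str) -> &'static [RegClass] {
    match ty {
        "i8" | "i16" | "i32" | "i64" => &[RegClass::Int],
        // f16 and f128 ride in one XMM register, just like f32/f64.
        "f16" | "f32" | "f64" | "f128" => &[RegClass::Float],
        // i128 is the odd one out: it needs a pair of integer registers.
        "i128" => &[RegClass::Int, RegClass::Int],
        _ => panic!("unhandled type: {ty}"),
    }
}

fn main() {
    assert_eq!(reg_classes("f128"), &[RegClass::Float]);
    assert_eq!(reg_classes("i128"), &[RegClass::Int, RegClass::Int]);
}
```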