Skip to content

Commit

Permalink
Add initial f16 and f128 support to the x64 backend
Browse files Browse the repository at this point in the history
  • Loading branch information
beetrees committed Jul 30, 2024
1 parent 12fc764 commit 20a8e4c
Show file tree
Hide file tree
Showing 20 changed files with 969 additions and 44 deletions.
30 changes: 28 additions & 2 deletions cranelift/codegen/src/isa/x64/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// extension annotations. Additionally, handling extension attributes this way allows clif
// functions that use them with the Winch calling convention to interact successfully with
// testing infrastructure.
// The results are also not packed if any of the types are `f16`. This is to simplify the
// implementation of `Inst::load`/`Inst::store` (which would otherwise require multiple
// instructions), and doesn't affect Winch itself as Winch doesn't support `f16` at all.
let uses_extension = params
.iter()
.any(|p| p.extension != ir::ArgumentExtension::None);
.any(|p| p.extension != ir::ArgumentExtension::None || p.value_type == types::F16);

for (ix, param) in params.iter().enumerate() {
let last_param = ix == params.len() - 1;
Expand Down Expand Up @@ -169,13 +172,23 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// https://godbolt.org/z/PhG3ob

if param.value_type.bits() > 64
&& !param.value_type.is_vector()
&& !(param.value_type.is_vector() || param.value_type.is_float())
&& !flags.enable_llvm_abi_extensions()
{
panic!(
"i128 args/return values not supported unless LLVM ABI extensions are enabled"
);
}
// As MSVC doesn't support f16/f128 there is no standard way to pass/return them with
// the Windows ABI. LLVM passes/returns them in XMM registers.
if matches!(param.value_type, types::F16 | types::F128)
&& is_fastcall
&& !flags.enable_llvm_abi_extensions()
{
panic!(
"f16/f128 args/return values not supported for windows_fastcall unless LLVM ABI extensions are enabled"
);
}

// Windows fastcall dictates that `__m128i` parameters to a function
// are passed indirectly as pointers, so handle that as a special
Expand Down Expand Up @@ -410,12 +423,20 @@ impl ABIMachineSpec for X64ABIMachineSpec {
// bits as well -- see `Inst::store()`).
let ty = match ty {
types::I8 | types::I16 | types::I32 => types::I64,
// Stack slots are always at least 8 bytes, so it's fine to load 4 bytes instead of only
// two.
types::F16 => types::F32,
_ => ty,
};
Inst::load(ty, mem, into_reg, ExtKind::None)
}

/// Generate a store of `from_reg` to the stack slot addressed by `mem`.
///
/// `f16` stores are widened to `f32`: `Inst::store` has no
/// single-instruction 16-bit XMM store, and stack slots are always at
/// least 8 bytes, so writing 4 bytes instead of 2 is safe (see
/// `gen_load_stack`).
fn gen_store_stack(mem: StackAMode, from_reg: Reg, ty: Type) -> Self::I {
    let store_ty = if ty == types::F16 { types::F32 } else { ty };
    Inst::store(store_ty, from_reg, mem)
}

Expand Down Expand Up @@ -502,6 +523,11 @@ impl ABIMachineSpec for X64ABIMachineSpec {
}

/// Generate a store of `from_reg` to the address `base + offset`.
///
/// As with the stack-slot helpers, `f16` is widened to an `f32` store
/// because a single-instruction 16-bit XMM store is not available (see
/// `gen_load_stack`).
fn gen_store_base_offset(base: Reg, offset: i32, from_reg: Reg, ty: Type) -> Self::I {
    let store_ty = match ty {
        types::F16 => types::F32,
        other => other,
    };
    Inst::store(store_ty, from_reg, Amode::imm_reg(offset, base))
}
Expand Down
38 changes: 31 additions & 7 deletions cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1644,7 +1644,7 @@
(rule (put_in_gpr val)
(if-let (value_type ty) val)
(if-let (type_register_class (RegisterClass.Xmm)) ty)
(bitcast_xmm_to_gpr ty (xmm_new (put_in_reg val))))
(bitcast_xmm_to_gpr (ty_bits ty) (xmm_new (put_in_reg val))))

;; Put a value into a `GprMem`.
;;
Expand Down Expand Up @@ -2252,8 +2252,10 @@

;; Performs an xor operation of the two operands specified.
;;
;; Scalar floats (`$F16`/`$F32`) plus `$F128`/`$F32X4` lower to `xorps`,
;; `$F64`/`$F64X2` to `xorpd`, and any other (integer) vector type falls back
;; to `pxor`. All three are full-width 128-bit register xors, so the result
;; is the same either way; the per-type split presumably keeps each value in
;; its preferred SSE execution domain — confirm against x64 optimization
;; guidance if it matters.
(decl x64_xor_vector (Type Xmm XmmMem) Xmm)
(rule 1 (x64_xor_vector $F16 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F32 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F64 x y) (x64_xorpd x y))
(rule 1 (x64_xor_vector $F128 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F32X4 x y) (x64_xorps x y))
(rule 1 (x64_xor_vector $F64X2 x y) (x64_xorpd x y))
(rule 0 (x64_xor_vector (multi_lane _ _) x y) (x64_pxor x y))
Expand Down Expand Up @@ -2304,6 +2306,9 @@
;; `f64` loads use `movsd` (scalar 64-bit load; the upper XMM lanes are zeroed).
(rule 2 (x64_load $F64 addr _ext_kind)
(x64_movsd_load addr))

;; `f128` loads use `movdqu`: a full 128-bit load with no alignment requirement.
(rule 2 (x64_load $F128 addr _ext_kind)
(x64_movdqu_load addr))

;; `f32x4` loads use `movups` (unaligned 128-bit float-domain load).
(rule 2 (x64_load $F32X4 addr _ext_kind)
(x64_movups_load addr))

Expand Down Expand Up @@ -2719,6 +2724,10 @@
(_ Unit (emit (MInst.Imm size simm64 dst))))
dst))

;; `f16` immediates.
;;
;; Materialize the constant bits in a GPR as an `i16` immediate, then move
;; the low 16 bits into an XMM register (via `bitcast_gpr_to_xmm 16`).
(rule 2 (imm $F16 (u64_nonzero bits))
(bitcast_gpr_to_xmm 16 (imm $I16 bits)))

;; `f32` immediates.
(rule 2 (imm $F32 (u64_nonzero bits))
(x64_movd_to_xmm (imm $I32 bits)))
Expand Down Expand Up @@ -2746,6 +2755,9 @@
(rule 0 (imm ty @ (multi_lane _bits _lanes) 0)
(xmm_to_reg (xmm_zero ty)))

;; Special case for `f16` zero immediates: zero the whole XMM register
;; directly rather than routing the constant through a GPR.
(rule 2 (imm ty @ $F16 (u64_zero)) (xmm_zero ty))

;; Special case for `f32` zero immediates
(rule 2 (imm ty @ $F32 (u64_zero)) (xmm_zero ty))

Expand Down Expand Up @@ -5022,18 +5034,30 @@

;;;; Casting ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(decl bitcast_xmm_to_gpr (Type Xmm) Gpr)
(rule (bitcast_xmm_to_gpr $F32 src)
;; Move the low `n` bits of an XMM register into a GPR, where `n` is the
;; first argument (a bit width, not a Type).
(decl bitcast_xmm_to_gpr (u8 Xmm) Gpr)
;; 16 bits: `pextrw` extracts word 0, zero-extended into the GPR.
(rule (bitcast_xmm_to_gpr 16 src)
(x64_pextrw src 0))
;; 32 bits: `movd`.
(rule (bitcast_xmm_to_gpr 32 src)
(x64_movd_to_gpr src))
(rule (bitcast_xmm_to_gpr $F64 src)
(rule (bitcast_xmm_to_gpr 64 src)
(x64_movq_to_gpr src))

(decl bitcast_gpr_to_xmm (Type Gpr) Xmm)
(rule (bitcast_gpr_to_xmm $I32 src)
;; Move all 128 bits of an XMM register into a pair of GPRs
;; (low 64 bits first).
;;
;; `movq` extracts the low 64 bits; `pshufd` with immediate 0b11101110
;; (dword order 2,3,2,3) copies the upper 64 bits into the lower half so a
;; second `movq` can extract them.
(decl bitcast_xmm_to_gprs (Xmm) ValueRegs)
(rule (bitcast_xmm_to_gprs src)
(value_regs (x64_movq_to_gpr src) (x64_movq_to_gpr (x64_pshufd src 0b11101110))))

;; Move the low `n` bits of a GPR into an XMM register, where `n` is the
;; first argument (a bit width, not a Type).
(decl bitcast_gpr_to_xmm (u8 Gpr) Xmm)
;; 16 bits: `pinsrw` inserts the low word into lane 0 of a fresh XMM;
;; presumably the remaining lanes are unspecified (`xmm_uninit_value`) —
;; callers must not rely on them.
(rule (bitcast_gpr_to_xmm 16 src)
(x64_pinsrw (xmm_uninit_value) src 0))
;; 32 bits: `movd`.
(rule (bitcast_gpr_to_xmm 32 src)
(x64_movd_to_xmm src))
(rule (bitcast_gpr_to_xmm $I64 src)
(rule (bitcast_gpr_to_xmm 64 src)
(x64_movq_to_xmm src))

;; Move a pair of GPRs (low 64 bits first) into a full 128-bit XMM register.
;;
;; Each half is moved into its own XMM with `movq` (which zeroes the upper
;; lanes), then `punpcklqdq` interleaves the two low quadwords: the first
;; register becomes bits 0..64 of the result and the second bits 64..128.
(decl bitcast_gprs_to_xmm (ValueRegs) Xmm)
(rule (bitcast_gprs_to_xmm src)
(x64_punpcklqdq (x64_movq_to_xmm (value_regs_get_gpr src 0)) (x64_movq_to_xmm (value_regs_get_gpr src 1))))

;;;; Stack Addresses ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(decl stack_addr_impl (StackSlot Offset32) Gpr)
Expand Down
3 changes: 2 additions & 1 deletion cranelift/codegen/src/isa/x64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1428,10 +1428,11 @@ pub(crate) fn emit(
let op = match *ty {
types::F64 => SseOpcode::Movsd,
types::F32 => SseOpcode::Movsd,
types::F16 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movaps,
types::F64X2 => SseOpcode::Movapd,
ty => {
debug_assert!(ty.is_vector() && ty.bytes() == 16);
debug_assert!((ty.is_float() || ty.is_vector()) && ty.bytes() == 16);
SseOpcode::Movdqa
}
};
Expand Down
13 changes: 9 additions & 4 deletions cranelift/codegen/src/isa/x64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,12 @@ impl Inst {
}
RegClass::Float => {
let opcode = match ty {
types::F16 => panic!("loading a f16 requires multiple instructions"),
types::F32 => SseOpcode::Movss,
types::F64 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movups,
types::F64X2 => SseOpcode::Movupd,
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
_ => unimplemented!("unable to load type: {}", ty),
};
Inst::xmm_unary_rm_r(opcode, RegMem::mem(from_addr), to_reg)
Expand All @@ -650,11 +651,12 @@ impl Inst {
RegClass::Int => Inst::mov_r_m(OperandSize::from_ty(ty), from_reg, to_addr),
RegClass::Float => {
let opcode = match ty {
types::F16 => panic!("storing a f16 requires multiple instructions"),
types::F32 => SseOpcode::Movss,
types::F64 => SseOpcode::Movsd,
types::F32X4 => SseOpcode::Movups,
types::F64X2 => SseOpcode::Movupd,
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
_ => unimplemented!("unable to store type: {}", ty),
};
Inst::xmm_mov_r_m(opcode, from_reg, to_addr)
Expand Down Expand Up @@ -1621,6 +1623,7 @@ impl PrettyPrint for Inst {
let suffix = match *ty {
types::F64 => "sd",
types::F32 => "ss",
types::F16 => "ss",
types::F32X4 => "aps",
types::F64X2 => "apd",
_ => "dqa",
Expand Down Expand Up @@ -2605,9 +2608,9 @@ impl MachInst for Inst {
// those, which may write more lanes that we need, but are specified to have
// zero-latency.
let opcode = match ty {
types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
types::F16 | types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
types::F64X2 => SseOpcode::Movapd,
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqa,
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqa,
_ => unimplemented!("unable to move type: {}", ty),
};
Inst::xmm_unary_rm_r(opcode, RegMem::reg(src_reg), dst_reg)
Expand All @@ -2628,8 +2631,10 @@ impl MachInst for Inst {
types::I64 => Ok((&[RegClass::Int], &[types::I64])),
types::R32 => panic!("32-bit reftype pointer should never be seen on x86-64"),
types::R64 => Ok((&[RegClass::Int], &[types::R64])),
types::F16 => Ok((&[RegClass::Float], &[types::F16])),
types::F32 => Ok((&[RegClass::Float], &[types::F32])),
types::F64 => Ok((&[RegClass::Float], &[types::F64])),
types::F128 => Ok((&[RegClass::Float], &[types::F128])),
types::I128 => Ok((&[RegClass::Int, RegClass::Int], &[types::I64, types::I64])),
_ if ty.is_vector() => {
assert!(ty.bits() <= 128);
Expand Down
Loading

0 comments on commit 20a8e4c

Please sign in to comment.