Skip to content

Commit

Permalink
[X86] AVX512FP16 instructions enabling 1/6
Browse files Browse the repository at this point in the history
1. Enable FP16 type support and basic declarations used by following patches.
2. Enable new instructions VMOVW and VMOVSH.

Ref.: https://software.intel.com/content/www/us/en/develop/download/intel-avx512-fp16-architecture-specification.html

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D105263
  • Loading branch information
phoebewang committed Aug 10, 2021
1 parent b978df4 commit 6f7f5b5
Show file tree
Hide file tree
Showing 73 changed files with 5,265 additions and 179 deletions.
2 changes: 2 additions & 0 deletions clang/docs/ClangCommandLineReference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3553,6 +3553,8 @@ X86

.. option:: -mavx512f, -mno-avx512f

.. option:: -mavx512fp16, -mno-avx512fp16

.. option:: -mavx512ifma, -mno-avx512ifma

.. option:: -mavx512pf, -mno-avx512pf
Expand Down
1 change: 1 addition & 0 deletions clang/docs/LanguageExtensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ targets pending ABI standardization:
* 64-bit ARM (AArch64)
* AMDGPU
* SPIR
* X86 (Only available under feature AVX512-FP16)

``_Float16`` will be supported on more targets as they define ABIs for it.

Expand Down
2 changes: 1 addition & 1 deletion clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ CUDA Support in Clang
X86 Support in Clang
--------------------

- ...
- Support for ``AVX512-FP16`` instructions has been added.

Internal API Changes
--------------------
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/BuiltinsX86.def
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,10 @@ TARGET_BUILTIN(__builtin_ia32_vp2intersect_d_512, "vV16iV16iUs*Us*", "nV:512:",
TARGET_BUILTIN(__builtin_ia32_vp2intersect_d_256, "vV8iV8iUc*Uc*", "nV:256:", "avx512vp2intersect,avx512vl")
TARGET_BUILTIN(__builtin_ia32_vp2intersect_d_128, "vV4iV4iUc*Uc*", "nV:128:", "avx512vp2intersect,avx512vl")

// AVX512 fp16 intrinsics
TARGET_BUILTIN(__builtin_ia32_loadsh128_mask, "V8xV8x*V8xUc", "nV:128:", "avx512fp16")
TARGET_BUILTIN(__builtin_ia32_storesh128_mask, "vV8x*V8xUc", "nV:128:", "avx512fp16")

// generic select intrinsics
TARGET_BUILTIN(__builtin_ia32_selectb_128, "V16cUsV16cV16c", "ncV:128:", "avx512bw,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectb_256, "V32cUiV32cV32c", "ncV:256:", "avx512bw,avx512vl")
Expand All @@ -1859,6 +1863,9 @@ TARGET_BUILTIN(__builtin_ia32_selectw_512, "V32sUiV32sV32s", "ncV:512:", "avx512
TARGET_BUILTIN(__builtin_ia32_selectd_128, "V4iUcV4iV4i", "ncV:128:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectd_256, "V8iUcV8iV8i", "ncV:256:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectd_512, "V16iUsV16iV16i", "ncV:512:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectph_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectph_256, "V16xUsV16xV16x", "ncV:256:", "avx512fp16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectph_512, "V32xUiV32xV32x", "ncV:512:", "avx512fp16")
TARGET_BUILTIN(__builtin_ia32_selectq_128, "V2OiUcV2OiV2Oi", "ncV:128:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectq_256, "V4OiUcV4OiV4Oi", "ncV:256:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectq_512, "V8OiUcV8OiV8Oi", "ncV:512:", "avx512f")
Expand All @@ -1868,6 +1875,7 @@ TARGET_BUILTIN(__builtin_ia32_selectps_512, "V16fUsV16fV16f", "ncV:512:", "avx51
TARGET_BUILTIN(__builtin_ia32_selectpd_128, "V2dUcV2dV2d", "ncV:128:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpd_256, "V4dUcV4dV4d", "ncV:256:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpd_512, "V8dUcV8dV8d", "ncV:512:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectsh_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16")
TARGET_BUILTIN(__builtin_ia32_selectss_128, "V4fUcV4fV4f", "ncV:128:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectsd_128, "V2dUcV2dV2d", "ncV:128:", "avx512f")

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -4165,6 +4165,8 @@ def mavx512dq : Flag<["-"], "mavx512dq">, Group<m_x86_Features_Group>;
def mno_avx512dq : Flag<["-"], "mno-avx512dq">, Group<m_x86_Features_Group>;
def mavx512er : Flag<["-"], "mavx512er">, Group<m_x86_Features_Group>;
def mno_avx512er : Flag<["-"], "mno-avx512er">, Group<m_x86_Features_Group>;
def mavx512fp16 : Flag<["-"], "mavx512fp16">, Group<m_x86_Features_Group>;
def mno_avx512fp16 : Flag<["-"], "mno-avx512fp16">, Group<m_x86_Features_Group>;
def mavx512ifma : Flag<["-"], "mavx512ifma">, Group<m_x86_Features_Group>;
def mno_avx512ifma : Flag<["-"], "mno-avx512ifma">, Group<m_x86_Features_Group>;
def mavx512pf : Flag<["-"], "mavx512pf">, Group<m_x86_Features_Group>;
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/Basic/Targets/X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasAVX512BF16 = true;
} else if (Feature == "+avx512er") {
HasAVX512ER = true;
} else if (Feature == "+avx512fp16") {
HasAVX512FP16 = true;
HasFloat16 = true;
} else if (Feature == "+avx512pf") {
HasAVX512PF = true;
} else if (Feature == "+avx512dq") {
Expand Down Expand Up @@ -668,6 +671,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__AVX512BF16__");
if (HasAVX512ER)
Builder.defineMacro("__AVX512ER__");
if (HasAVX512FP16)
Builder.defineMacro("__AVX512FP16__");
if (HasAVX512PF)
Builder.defineMacro("__AVX512PF__");
if (HasAVX512DQ)
Expand Down Expand Up @@ -856,6 +861,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
.Case("avx512vnni", true)
.Case("avx512bf16", true)
.Case("avx512er", true)
.Case("avx512fp16", true)
.Case("avx512pf", true)
.Case("avx512dq", true)
.Case("avx512bitalg", true)
Expand Down Expand Up @@ -948,6 +954,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
.Case("avx512vnni", HasAVX512VNNI)
.Case("avx512bf16", HasAVX512BF16)
.Case("avx512er", HasAVX512ER)
.Case("avx512fp16", HasAVX512FP16)
.Case("avx512pf", HasAVX512PF)
.Case("avx512dq", HasAVX512DQ)
.Case("avx512bitalg", HasAVX512BITALG)
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/X86.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
bool HasAVX512CD = false;
bool HasAVX512VPOPCNTDQ = false;
bool HasAVX512VNNI = false;
bool HasAVX512FP16 = false;
bool HasAVX512BF16 = false;
bool HasAVX512ER = false;
bool HasAVX512PF = false;
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12671,6 +12671,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
case X86::BI__builtin_ia32_storeups512_mask:
return EmitX86MaskedStore(*this, Ops, Align(1));

case X86::BI__builtin_ia32_storesh128_mask:
case X86::BI__builtin_ia32_storess128_mask:
case X86::BI__builtin_ia32_storesd128_mask:
return EmitX86MaskedStore(*this, Ops, Align(1));
Expand Down Expand Up @@ -12806,6 +12807,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
case X86::BI__builtin_ia32_loaddqudi512_mask:
return EmitX86MaskedLoad(*this, Ops, Align(1));

case X86::BI__builtin_ia32_loadsh128_mask:
case X86::BI__builtin_ia32_loadss128_mask:
case X86::BI__builtin_ia32_loadsd128_mask:
return EmitX86MaskedLoad(*this, Ops, Align(1));
Expand Down Expand Up @@ -13685,13 +13687,17 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
case X86::BI__builtin_ia32_selectq_128:
case X86::BI__builtin_ia32_selectq_256:
case X86::BI__builtin_ia32_selectq_512:
case X86::BI__builtin_ia32_selectph_128:
case X86::BI__builtin_ia32_selectph_256:
case X86::BI__builtin_ia32_selectph_512:
case X86::BI__builtin_ia32_selectps_128:
case X86::BI__builtin_ia32_selectps_256:
case X86::BI__builtin_ia32_selectps_512:
case X86::BI__builtin_ia32_selectpd_128:
case X86::BI__builtin_ia32_selectpd_256:
case X86::BI__builtin_ia32_selectpd_512:
return EmitX86Select(*this, Ops[0], Ops[1], Ops[2]);
case X86::BI__builtin_ia32_selectsh_128:
case X86::BI__builtin_ia32_selectss_128:
case X86::BI__builtin_ia32_selectsd_128: {
Value *A = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
Expand Down
74 changes: 62 additions & 12 deletions clang/lib/CodeGen/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,8 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase,
Hi = Integer;
} else if (k >= BuiltinType::Bool && k <= BuiltinType::LongLong) {
Current = Integer;
} else if (k == BuiltinType::Float || k == BuiltinType::Double) {
} else if (k == BuiltinType::Float || k == BuiltinType::Double ||
k == BuiltinType::Float16) {
Current = SSE;
} else if (k == BuiltinType::LongDouble) {
const llvm::fltSemantics *LDF = &getTarget().getLongDoubleFormat();
Expand Down Expand Up @@ -2943,7 +2944,7 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase,
Current = Integer;
else if (Size <= 128)
Lo = Hi = Integer;
} else if (ET == getContext().FloatTy) {
} else if (ET->isFloat16Type() || ET == getContext().FloatTy) {
Current = SSE;
} else if (ET == getContext().DoubleTy) {
Lo = Hi = SSE;
Expand Down Expand Up @@ -3396,27 +3397,76 @@ static bool ContainsFloatAtOffset(llvm::Type *IRType, unsigned IROffset,
return false;
}

/// ContainsHalfAtOffset - Return true if the specified LLVM IR type has a
/// half member at the specified offset. For example, {int,{half}} has a
/// half at offset 4. It is conservatively correct for this routine to return
/// false.
/// FIXME: Merge with ContainsFloatAtOffset
static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
const llvm::DataLayout &TD) {
// Base case if we find a float.
if (IROffset == 0 && IRType->isHalfTy())
return true;

// If this is a struct, recurse into the field at the specified offset.
if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
const llvm::StructLayout *SL = TD.getStructLayout(STy);
unsigned Elt = SL->getElementContainingOffset(IROffset);
IROffset -= SL->getElementOffset(Elt);
return ContainsHalfAtOffset(STy->getElementType(Elt), IROffset, TD);
}

// If this is an array, recurse into the field at the specified offset.
if (llvm::ArrayType *ATy = dyn_cast<llvm::ArrayType>(IRType)) {
llvm::Type *EltTy = ATy->getElementType();
unsigned EltSize = TD.getTypeAllocSize(EltTy);
IROffset -= IROffset / EltSize * EltSize;
return ContainsHalfAtOffset(EltTy, IROffset, TD);
}

return false;
}

/// GetSSETypeAtOffset - Return a type that will be passed by the backend in the
/// low 8 bytes of an XMM register, corresponding to the SSE class.
llvm::Type *X86_64ABIInfo::
GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
QualType SourceTy, unsigned SourceOffset) const {
// The only three choices we have are either double, <2 x float>, or float. We
// pass as float if the last 4 bytes is just padding. This happens for
// structs that contain 3 floats.
if (BitsContainNoUserData(SourceTy, SourceOffset*8+32,
SourceOffset*8+64, getContext()))
return llvm::Type::getFloatTy(getVMContext());
// If the high 32 bits are not used, we have three choices. Single half,
// single float or two halfs.
if (BitsContainNoUserData(SourceTy, SourceOffset * 8 + 32,
SourceOffset * 8 + 64, getContext())) {
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()))
return llvm::Type::getFloatTy(getVMContext());
if (ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()),
2);

return llvm::Type::getHalfTy(getVMContext());
}

// We want to pass as <2 x float> if the LLVM IR type contains a float at
// offset+0 and offset+4. Walk the LLVM IR type to find out if this is the
// offset+0 and offset+4. Walk the LLVM IR type to find out if this is the
// case.
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) &&
ContainsFloatAtOffset(IRType, IROffset+4, getDataLayout()))
ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getFloatTy(getVMContext()),
2);

// We want to pass as <4 x half> if the LLVM IR type contains a half at
// offset+0, +2, +4. Walk the LLVM IR type to find out if this is the case.
if (ContainsHalfAtOffset(IRType, IROffset, getDataLayout()) &&
ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()) &&
ContainsHalfAtOffset(IRType, IROffset + 4, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);

// We want to pass as <4 x half> if the LLVM IR type contains a mix of float
// and half.
// FIXME: Do we have a better representation for the mixed type?
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) ||
ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);

return llvm::Type::getDoubleTy(getVMContext());
}

Expand Down Expand Up @@ -3521,11 +3571,11 @@ GetX86_64ByValArgumentPair(llvm::Type *Lo, llvm::Type *Hi,
// struct.
if (HiStart != 8) {
// There are usually two sorts of types the ABI generation code can produce
// for the low part of a pair that aren't 8 bytes in size: float or
// for the low part of a pair that aren't 8 bytes in size: half, float or
// i8/i16/i32. This can also include pointers when they are 32-bit (X32 and
// NaCl).
// Promote these to a larger type.
if (Lo->isFloatTy())
if (Lo->isHalfTy() || Lo->isFloatTy())
Lo = llvm::Type::getDoubleTy(Lo->getContext());
else {
assert((Lo->isIntegerTy() || Lo->isPointerTy())
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Headers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(files
avx512dqintrin.h
avx512erintrin.h
avx512fintrin.h
avx512fp16intrin.h
avx512ifmaintrin.h
avx512ifmavlintrin.h
avx512pfintrin.h
Expand All @@ -28,6 +29,7 @@ set(files
avx512vlbwintrin.h
avx512vlcdintrin.h
avx512vldqintrin.h
avx512vlfp16intrin.h
avx512vlintrin.h
avx512vp2intersectintrin.h
avx512vlvp2intersectintrin.h
Expand Down
Loading

0 comments on commit 6f7f5b5

Please # to comment.