diff --git a/src/metal.jl b/src/metal.jl
index cd8c6f3f..9b0f332e 100644
--- a/src/metal.jl
+++ b/src/metal.jl
@@ -37,6 +37,15 @@ llvm_datalayout(target::MetalCompilerTarget) =
 
 needs_byval(job::CompilerJob{MetalCompilerTarget}) = false
 
+module AS
+    const Device = 1
+    const Constant = 2
+    const ThreadGroup = 3
+    const Thread = 4
+    const ThreadGroup_ImgBlock = 5
+    const Ray = 6
+end
+
 
 ## job
 
@@ -45,7 +54,7 @@ needs_byval(job::CompilerJob{MetalCompilerTarget}) = false
 runtime_slug(job::CompilerJob{MetalCompilerTarget}) = "metal-macos$(job.config.target.macos)"
 
 isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String) =
-    return startswith(fn, "air.")
+    return startswith(fn, "air.") || startswith(fn, "julia.air.")
 
 function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function)
     entry_fn = LLVM.name(entry)
@@ -199,12 +208,16 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
     end
 
     # perform codegen passes that would normally run during machine code emission
-    # XXX: codegen passes don't seem available in the new pass manager yet
-    @dispose pm=ModulePassManager() begin
-        expand_reductions!(pm)
-        run!(pm, mod)
+    if LLVM.has_oldpm()
+        # XXX: codegen passes don't seem available in the new pass manager yet
+        @dispose pm=ModulePassManager() begin
+            expand_reductions!(pm)
+            run!(pm, mod)
+        end
     end
 
+    legalize_atomics!(job, mod)
+
     return functions(mod)[entry_fn]
 end
 
@@ -214,6 +227,272 @@ end
     return nothing
 end
 
+function legalize_atomics!(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
+    # gather atomic instructions
+    worklist = LLVM.Instruction[]
+    for f in functions(mod), bb in blocks(f), inst in instructions(bb)
+        is_atomic(inst) || continue
+        push!(worklist, inst)
+    end
+    isempty(worklist) && return
+
+    T_void = LLVM.VoidType()
+    T_i32 = LLVM.Int32Type()
+    T_i1 = LLVM.Int1Type()
+
+    mem_name = Dict(
+        AS.Device => "global",
+        AS.ThreadGroup => "local",
+    )
+
+    mem_flags = Dict(
+        :none => 0,
+        :device => 1,
+        :threadgroup => 2,
+        :texture => 4,
+        :threadgroup_imageblock => 8,
+        :object_data => 16,
+    )
+
+    llvm_to_metal_ordering = Dict(
+        LLVM.API.LLVMAtomicOrderingMonotonic => 0,              # relaxed
+        #LLVM.API.LLVMAtomicOrderingConsume => 1,               # not supported
+        LLVM.API.LLVMAtomicOrderingAcquire => 2,
+        LLVM.API.LLVMAtomicOrderingRelease => 3,
+        LLVM.API.LLVMAtomicOrderingAcquireRelease => 4,
+        LLVM.API.LLVMAtomicOrderingSequentiallyConsistent => 5,
+    )
+    function memory_order(ord::LLVM.API.LLVMAtomicOrdering)
+        if !haskey(llvm_to_metal_ordering, ord)
+            error("Unsupported memory order: $(ord)")
+        end
+        llvm_to_metal_ordering[ord]
+    end
+    memory_order(inst) = memory_order(LLVM.ordering(inst))
+
+    thread_scopes = Dict(
+        :thread => 0,
+        :threadgroup => 1,
+        :device => 2,
+        :simdgroup => 4,
+    )
+    function thread_scope(inst, as=nothing)
+        if syncscope(inst) == SyncScope("system")
+            # when using the default syncscope, figure out a sane one by looking at the
+            # type of memory used, just like Metal does
+            if as === nothing
+                # be conservative
+                thread_scopes[:device]
+            elseif as == AS.Device
+                thread_scopes[:device]
+            elseif as == AS.ThreadGroup
+                thread_scopes[:threadgroup]
+            else
+                error("Do not know how to determine syncscope for address space $as")
+            end
+        else
+            scope = syncscope(inst)
+            if scope == SyncScope("thread")
+                thread_scopes[:thread]
+            elseif scope == SyncScope("threadgroup")
+                thread_scopes[:threadgroup]
+            elseif scope == SyncScope("device")
+                thread_scopes[:device]
+            elseif scope == SyncScope("simdgroup")
+                thread_scopes[:simdgroup]
+            else
+                error("Unsupported syncscope for instruction $inst")
+            end
+        end
+    end
+
+    # legalize atomic operations
+    for inst in worklist
+        bb = LLVM.parent(inst)
+        fun = LLVM.parent(bb)
+
+        if isa(inst, LLVM.StoreInst)
+            val, dst = operands(inst)
+            val_typ = value_type(val)
+            dst_typ = value_type(dst)
+            as = addrspace(dst_typ)
+
+            haskey(mem_name, as) ||
+                error("Unsupported address space for atomic operations: $as")
+            intr_name = "air.atomic.$(mem_name[as]).store.$(string(val_typ))"
+            intr_ty = LLVM.FunctionType(T_void, [dst_typ, val_typ, T_i32, T_i32, T_i1])
+            intr = LLVM.Function(mod, intr_name, intr_ty)
+
+            scope = ConstantInt(T_i32, thread_scope(inst, as))
+            order = ConstantInt(T_i32, memory_order(inst))
+            unknown = ConstantInt(T_i1, 1)
+
+            @dispose builder=LLVM.IRBuilder() begin
+                position!(builder, inst)
+                new_inst = call!(builder, intr_ty, intr, [dst, val, order, scope, unknown])
+                replace_uses!(inst, new_inst)
+                erase!(inst)
+            end
+        elseif isa(inst, LLVM.LoadInst)
+            src, = operands(inst)
+            src_typ = value_type(src)
+            val_typ = value_type(inst)
+            as = addrspace(src_typ)
+
+            haskey(mem_name, as) ||
+                error("Unsupported address space for atomic operations: $as")
+            intr_name = "air.atomic.$(mem_name[as]).load.$(string(val_typ))"
+            intr_ty = LLVM.FunctionType(val_typ, [src_typ, T_i32, T_i32, T_i1])
+            intr = LLVM.Function(mod, intr_name, intr_ty)
+
+            scope = ConstantInt(T_i32, thread_scope(inst, as))
+            order = ConstantInt(T_i32, memory_order(inst))
+            unknown = ConstantInt(T_i1, 1)
+
+            @dispose builder=LLVM.IRBuilder() begin
+                position!(builder, inst)
+                new_inst = call!(builder, intr_ty, intr, [src, order, scope, unknown])
+                replace_uses!(inst, new_inst)
+                erase!(inst)
+            end
+
+        elseif isa(inst, LLVM.AtomicRMWInst)
+            ptr, val = operands(inst)
+            ptr_typ = value_type(ptr)
+            val_typ = value_type(val)
+            as = addrspace(ptr_typ)
+
+            op = binop(inst)
+            # XXX: we don't know the sign of integer operands, so default to signed.
+            #      it shouldn't matter for 2's complement...
+            opstr = if op == LLVM.API.LLVMAtomicRMWBinOpXchg
+                "xchg"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpAdd
+                "add.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpSub
+                "sub.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpAnd
+                "and.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpOr
+                "or.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpXor
+                "xor.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpMax
+                "max.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpMin
+                "min.s"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpUMax
+                "max.u"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpUMin
+                "min.u"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpFAdd
+                "add"
+            elseif op == LLVM.API.LLVMAtomicRMWBinOpFSub
+                "sub"
+            else
+                error("Unsupported atomic operation: $(binop(inst))")
+            end
+            # Metal 2.4: min/max also supported on unsigned i64
+
+            haskey(mem_name, as) ||
+                error("Unsupported address space for atomic operations: $as")
+            intr_name = "air.atomic.$(mem_name[as]).$opstr.$(string(val_typ))"
+            intr_ty = LLVM.FunctionType(val_typ, [ptr_typ, val_typ, T_i32, T_i32, T_i1])
+            intr = LLVM.Function(mod, intr_name, intr_ty)
+
+            scope = ConstantInt(T_i32, thread_scope(inst, as))
+            order = ConstantInt(T_i32, memory_order(inst))
+            unknown = ConstantInt(T_i1, 1)
+
+            @dispose builder=LLVM.IRBuilder() begin
+                position!(builder, inst)
+                new_inst = call!(builder, intr_ty, intr, [ptr, val, order, scope, unknown])
+                replace_uses!(inst, new_inst)
+                erase!(inst)
+            end
+        elseif isa(inst, LLVM.AtomicCmpXchgInst)
+            ptr, cmp, val = operands(inst)
+            ptr_typ = value_type(ptr)
+            val_typ = value_type(val)
+            as = addrspace(ptr_typ)
+
+            # LLVM's cmpxchg returns the loaded value and a success flag, while AIR
+            # only returns the flag and communicates the value through a box
+            cmp_box_typ = LLVM.PointerType(val_typ)
+
+            haskey(mem_name, as) ||
+                error("Unsupported address space for atomic operations: $as")
+            # XXX: Metal only supports weak cmpxchg, but UnsafeAtomics.jl only emits strong.
+            intr_name = "air.atomic.$(mem_name[as]).cmpxchg.weak.$(string(val_typ))"
+            intr_ty = LLVM.FunctionType(T_i1, [ptr_typ, cmp_box_typ, val_typ, T_i32, T_i32, T_i32, T_i1])
+            intr = LLVM.Function(mod, intr_name, intr_ty)
+
+            scope = ConstantInt(T_i32, thread_scope(inst, as))
+            success = ConstantInt(T_i32, memory_order(success_ordering(inst)))
+            failure = ConstantInt(T_i32, memory_order(failure_ordering(inst)))
+            unknown = ConstantInt(T_i1, 1)
+
+            @dispose builder=LLVM.IRBuilder() begin
+                # allocate boxes for the result struct and cmp value
+                firstbb = first(blocks(fun))
+                firstinst = first(instructions(firstbb))
+                position!(builder, firstinst)
+                struct_ty = LLVM.StructType([val_typ, T_i1])
+                struct_box = alloca!(builder, struct_ty)
+                cmp_box = alloca!(builder, val_typ)
+
+                # emit a call to the AIR intrinsic
+                position!(builder, inst)
+                store!(builder, cmp, cmp_box)
+                new_inst = call!(builder, intr_ty, intr, [ptr, cmp_box, val, success, failure, scope, unknown])
+                new_val = load!(builder, val_typ, cmp_box)
+
+                # return a struct
+                val_field = struct_gep!(builder, struct_ty, struct_box, 0)
+                store!(builder, new_val, val_field)
+                success_field = struct_gep!(builder, struct_ty, struct_box, 1)
+                store!(builder, new_inst, success_field)
+                struct_val = load!(builder, struct_ty, struct_box)
+                replace_uses!(inst, struct_val)
+                erase!(inst)
+            end
+
+            # optimize unnecessary boxes away
+            run!("instcombine", fun)
+        elseif isa(inst, LLVM.FenceInst)
+            # only on 3.2+
+            # cannot derive the scope...
+
+            intr_name = "air.atomic.thread_fence"
+            intr_ty = LLVM.FunctionType(T_void, [T_i32, T_i32, T_i32])
+            intr = LLVM.Function(mod, intr_name, intr_ty)
+
+            # XXX: we do not know which address spaces to fence, as that information
+            #      is not encoded in LLVM IR, so we're conservative and target all ASs.
+            #      Also see: https://llvm.org/docs/MemoryModelRelaxationAnnotations.html
+            #                https://llvm.org/docs/AMDGPUUsage.html#fence-and-address-spaces
+            # TODO: mmra
+            # XXX: the AS (or the MMRA, if missing) should provide the AS scope;
+            #      the thread scope should come from the syncscope
+            flags = ConstantInt(T_i32, mem_flags[:device])
+            order = ConstantInt(T_i32, memory_order(inst))
+            scope = ConstantInt(T_i32, thread_scope(inst))
+
+            @dispose builder=LLVM.IRBuilder() begin
+                position!(builder, inst)
+                new_inst = call!(builder, intr_ty, intr, [flags, order, scope])
+                replace_uses!(inst, new_inst)
+                erase!(inst)
+            end
+
+        else
+            error("Unsupported atomic operation: $inst")
+            # TODO: backtrace
+        end
+    end
+
+    return
+end
+
 
 # generic pointer removal
 #
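
Reviewer note, not part of the patch: every intrinsic name constructed above follows the same scheme, "air.atomic.{memory}.{op}.{type}", where the memory name comes from the pointer's address space and the operation may carry a signedness suffix. A minimal sketch of that scheme, assuming the AS module introduced by this diff is in scope; the helper names here are made up purely for illustration:

# Illustrative only: mirrors the naming convention used by legalize_atomics! above.
air_atomic_mem(as) = as == AS.Device      ? "global" :
                     as == AS.ThreadGroup ? "local"  :
                     error("unsupported address space $as")

air_atomic_name(op, as, typ) = "air.atomic.$(air_atomic_mem(as)).$op.$typ"

air_atomic_name("add.s", AS.Device, "i32")          # "air.atomic.global.add.s.i32"
air_atomic_name("xchg", AS.ThreadGroup, "i32")      # "air.atomic.local.xchg.i32"
air_atomic_name("cmpxchg.weak", AS.Device, "i32")   # "air.atomic.global.cmpxchg.weak.i32"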
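
Similarly, for anyone wanting to eyeball what the rewrite is expected to produce, here is a standalone sketch (again not part of the patch) that hand-builds the IR shape for a seq_cst atomic load from device memory, using the same LLVM.jl calls the pass itself relies on. The module and function names are arbitrary, and the order/scope constants follow the tables in legalize_atomics!:

using LLVM

@dispose ctx=Context() begin
    demo = LLVM.Module("atomic_load_demo")
    T_i32 = LLVM.Int32Type()
    T_i1  = LLVM.Int1Type()
    T_ptr = LLVM.PointerType(T_i32, 1)   # AS.Device

    # declaration of the AIR intrinsic, as emitted by legalize_atomics!
    intr_ty = LLVM.FunctionType(T_i32, [T_ptr, T_i32, T_i32, T_i1])
    intr = LLVM.Function(demo, "air.atomic.global.load.i32", intr_ty)

    # a caller performing the equivalent of a `load atomic ... seq_cst` on device memory
    fun_ty = LLVM.FunctionType(T_i32, [T_ptr])
    fun = LLVM.Function(demo, "demo_load", fun_ty)
    @dispose builder=LLVM.IRBuilder() begin
        position!(builder, BasicBlock(fun, "entry"))
        val = call!(builder, intr_ty, intr,
                    [first(parameters(fun)),
                     ConstantInt(T_i32, 5),   # memory order: seq_cst
                     ConstantInt(T_i32, 2),   # thread scope: device
                     ConstantInt(T_i1, 1)])   # trailing i1 flag (`unknown` in the pass)
        ret!(builder, val)
    end
    println(string(demo))
end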