Rework occupancy, re-enable grid-stride broadcast #367

Merged 2 commits on Jul 22, 2021

61 changes: 35 additions & 26 deletions src/device/execution.jl
@@ -28,38 +28,47 @@ host to influence how the kernel is executed. The following keyword arguments are

- `target::AbstractArray`: specify which array object to use for determining execution
properties (defaults to the first argument `arg0`).
- `total_threads::Int`: how many threads should be launched _in total_. The actual number of
threads and blocks is determined using a heuristic. Defaults to the length of `arg0` if
no other keyword arguments that influence the launch configuration are specified.
- `elements::Int`: how many elements will be processed by this kernel. In most
circumstances, this will correspond to the total number of threads that need to be
launched, unless the kernel supports a variable number of elements to process per
iteration. Defaults to the length of `arg0` if no other keyword arguments that influence
the launch configuration are specified.
- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
launched. This cannot be used in combination with the `total_threads` argument.
- `name::String`: inform the back end about the name of the kernel to be executed.
This can be used to emit better diagnostics, and is useful with anonymous kernels.
launched. This cannot be used in combination with the `elements` argument.
- `name::String`: inform the back end about the name of the kernel to be executed. This can
be used to emit better diagnostics, and is useful with anonymous kernels.
"""
function gpu_call(kernel::F, args::Vararg{Any,N};
target::AbstractArray=first(args),
total_threads::Union{Int,Nothing}=nothing,
elements::Union{Int,Nothing}=nothing,
threads::Union{Int,Nothing}=nothing,
blocks::Union{Int,Nothing}=nothing,
name::Union{String,Nothing}=nothing) where {F,N}
# non-trivial default values for launch configuration
if total_threads===nothing && threads===nothing && blocks===nothing
total_threads = length(target)
elseif total_threads===nothing
if elements===nothing && threads===nothing && blocks===nothing
elements = length(target)
elseif elements===nothing
if threads === nothing
threads = 1
end
if blocks === nothing
blocks = 1
end
elseif threads!==nothing || blocks!==nothing
error("Cannot specify both total_threads and threads/blocks configuration")
error("Cannot specify both elements and threads/blocks configuration")
end

if total_threads !== nothing
@assert total_threads > 0
heuristic = launch_heuristic(backend(target), kernel, args...)
config = launch_configuration(backend(target), heuristic, total_threads)
# the number of elements to process needs to be passed to the kernel somehow, so there's
# no easy way to do this without passing additional arguments or changing the context.
# both are expensive, so require manual use of `launch_heuristic` for those kernels.
elements_per_thread = 1

if elements !== nothing
@assert elements > 0
heuristic = launch_heuristic(backend(target), kernel, args...;
elements, elements_per_thread)
config = launch_configuration(backend(target), heuristic;
elements, elements_per_thread)
gpu_call(backend(target), kernel, args, config.threads, config.blocks; name=name)
else
@assert threads > 0
@@ -68,29 +77,29 @@ function gpu_call(kernel::F, args::Vararg{Any,N};
end
end
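
For illustration only (not part of this diff): a minimal launch of an anonymous
kernel over every element of a hypothetical GPU array `x`, using the `elements`
keyword documented above.

gpu_call(x; elements=length(x)) do ctx, x
    i = linear_index(ctx)
    i > length(x) && return  # the heuristic may launch more threads than elements
    @inbounds x[i] = 0
    return
end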

# how many threads and blocks this kernel needs to fully saturate the GPU.
# this can be specialised if more sophisticated heuristics are available.
# how many threads and blocks `kernel` needs to be launched with, passing arguments `args`,
# to fully saturate the GPU. `elements` indicates the number of elements that need to be
# processed, while `elements_per_thread` indicates the number of elements this kernel can
# process per thread (i.e. more than 1 if it's a grid-stride kernel, or 1 otherwise).
#
# the `maximize_blocksize` indicates whether the kernel benefits from a large block size
# this heuristic should be specialized for the back-end, ideally using an API for maximizing
# the occupancy of the launch configuration (like CUDA's occupancy API).
function launch_heuristic(backend::AbstractGPUBackend, kernel, args...;
maximize_blocksize=false)
elements::Int, elements_per_thread::Int)
return (threads=256, blocks=32)
end
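
Back-ends are expected to specialize this fallback using their occupancy APIs. A
minimal sketch of what that could look like, with a hypothetical back-end type and
occupancy helpers (none of these names are real API):

function launch_heuristic(backend::MyGPUBackend, kernel, args...;
                          elements::Int, elements_per_thread::Int)
    # stand-ins for a driver occupancy query, e.g. CUDA's occupancy API
    threads = max_threads_for_occupancy(backend, kernel)  # hypothetical helper
    blocks  = min_blocks_for_occupancy(backend, kernel)   # hypothetical helper
    return (threads=threads, blocks=blocks)
end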

# determine how many threads and blocks to actually launch given upper limits.
# returns a tuple of blocks, threads, and elements_per_thread (which is always 1
# unless specified that the kernel can handle a number of elements per thread)
function launch_configuration(backend::AbstractGPUBackend, heuristic,
elements::Int, elements_per_thread::Int=1)
function launch_configuration(backend::AbstractGPUBackend, heuristic;
elements::Int, elements_per_thread::Int)
threads = clamp(elements, 1, heuristic.threads)
blocks = max(cld(elements, threads), 1)

# FIXME: use grid-stride loop when we can't launch the number of blocks we need

if false && elements_per_thread > 1 && blocks > heuristic.blocks
if elements_per_thread > 1 && blocks > heuristic.blocks
# we want to launch more blocks than required, so prefer a grid-stride loop instead
# NOTE: this does not seem to improve performance
nelem = clamp(cld(blocks, heuristic.blocks), 1, elements_per_thread)
nelem = clamp(fld(blocks, heuristic.blocks), 1, elements_per_thread)
blocks = cld(blocks, nelem)
(threads=threads, blocks=blocks, elements_per_thread=nelem)
else
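To make the re-enabled grid-stride branch concrete, a worked example with assumed
numbers (not taken from this PR): the generic heuristic of 256 threads and 32
blocks, one million elements, and a kernel that can process any number of
elements per thread.

heuristic = (threads=256, blocks=32)
threads = clamp(1_000_000, 1, heuristic.threads)                 # 256
blocks  = max(cld(1_000_000, threads), 1)                        # 3907, far above 32
nelem   = clamp(fld(blocks, heuristic.blocks), 1, typemax(Int))  # 122
blocks  = cld(blocks, nelem)                                     # 33
# result: 33 blocks of 256 threads, each striding over up to 122 elements
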
4 changes: 2 additions & 2 deletions src/host/abstractarray.jl
@@ -116,7 +116,7 @@ function Base.copyto!(dest::AnyGPUArray, dstart::Integer,

gpu_call(linear_copy_kernel!,
dest, dstart, src, sstart, n;
total_threads=n)
elements=n)
return dest
end

@@ -188,7 +188,7 @@ function Base.copyto!(dest::AnyGPUArray{<:Any, N}, destcrange::CartesianIndices{
src_offsets = first(srccrange) - oneunit(CartesianIndex{N})
gpu_call(cartesian_copy_kernel!,
dest, dest_offsets, src, src_offsets, shape, len;
total_threads=len)
elements=len)
dest
end

4 changes: 2 additions & 2 deletions src/host/base.jl
@@ -6,7 +6,7 @@ function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
if length(b) == 0
return b
end
gpu_call(b, a, o, p, m, n; total_threads=n) do ctx, b, a, o, p, m, n
gpu_call(b, a, o, p, m, n; elements=n) do ctx, b, a, o, p, m, n
j = linear_index(ctx)
j > n && return
d = (j - 1) * p + 1
@@ -29,7 +29,7 @@ function Base.repeat(a::AbstractGPUVector, m::Int)
if length(b) == 0
return b
end
gpu_call(b, a, o, m; total_threads=m) do ctx, b, a, o, m
gpu_call(b, a, o, m; elements=m) do ctx, b, a, o, m
i = linear_index(ctx)
i > m && return
c = (i - 1)*o + 1
16 changes: 12 additions & 4 deletions src/host/broadcast.jl
@@ -60,8 +60,12 @@ end
end
return
end
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1)
config = launch_configuration(backend(dest), heuristic, length(dest), typemax(Int))
elements = length(dest)
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
threads=config.threads, blocks=config.blocks)

@@ -121,8 +125,12 @@ function Base.map!(f, dest::BroadcastGPUArray, xs::AbstractArray...)
end
return
end
heuristic = launch_heuristic(backend(dest), map_kernel, dest, bc, 1)
config = launch_configuration(backend(dest), heuristic, common_length, typemax(Int))
elements = common_length
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), map_kernel, dest, bc, 1;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(map_kernel, dest, bc, config.elements_per_thread;
threads=config.threads, blocks=config.blocks)

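The kernel bodies above are collapsed in this diff. For orientation, a hedged
sketch of the grid-stride shape that an `elements_per_thread` of `typemax(Int)`
advertises, assuming GPUArrays' `linear_index` and `global_size` device functions:

function stride_copy_kernel(ctx, dest, src, nelem)
    i = linear_index(ctx)       # this thread's first element
    stride = global_size(ctx)   # total number of launched threads
    for _ in 1:nelem
        i > length(dest) && return
        @inbounds dest[i] = src[i]
        i += stride
    end
    return
end
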
6 changes: 3 additions & 3 deletions src/host/construction.jl
@@ -24,7 +24,7 @@ end
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
res = similar(T, dims)
fill!(res, zero(U))
gpu_call(identity_kernel, res, size(res, 1), s.λ; total_threads=minimum(dims))
gpu_call(identity_kernel, res, size(res, 1), s.λ; elements=minimum(dims))
res
end

@@ -34,7 +34,7 @@

function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
fill!(A, zero(T))
gpu_call(identity_kernel, A, size(A, 1), s.λ; total_threads=minimum(size(A)))
gpu_call(identity_kernel, A, size(A, 1), s.λ; elements=minimum(size(A)))
A
end

@@ -43,7 +43,7 @@ function _one(unit::T, x::AbstractGPUMatrix) where {T}
m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
I = similar(x, T)
fill!(I, zero(T))
gpu_call(identity_kernel, I, m, unit; total_threads=m)
gpu_call(identity_kernel, I, m, unit; elements=m)
I
end

2 changes: 1 addition & 1 deletion src/host/indexing.jl
@@ -149,7 +149,7 @@ function _setindex!(dest::AbstractGPUArray, src, Is...)
AT = typeof(dest).name.wrapper
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
total_threads=len)
elements=len)
return dest
end

2 changes: 1 addition & 1 deletion src/host/random.jl
@@ -94,7 +94,7 @@ end
function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
threads = (length(A) - 1) ÷ 2 + 1
length(A) == 0 && return
gpu_call(A, rng.state; total_threads = threads) do ctx, a, randstates
gpu_call(A, rng.state; elements = threads) do ctx, a, randstates
idx = 2*(linear_index(ctx) - 1) + 1
U1 = gpu_rand(T, ctx, randstates)
U2 = gpu_rand(T, ctx, randstates)
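The halved count above reflects that each thread fills two slots (`idx` and
`idx + 1`) from the pair `U1`/`U2`, presumably a Box-Muller transform, so the
element count is a ceiling division:

threads = (length(A) - 1) ÷ 2 + 1   # equivalent to cld(length(A), 2)
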
16 changes: 8 additions & 8 deletions src/host/uniformscaling.jl
@@ -34,15 +34,15 @@ for (t1, t2) in unittriangularwrappers
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
gpu_call(kernel_unittriangular, B, J, one(eltype(B)), min_size; total_threads=min_size)
gpu_call(kernel_unittriangular, B, J, one(eltype(B)), min_size; elements=min_size)
return $t2(B)
end

function (-)(J::UniformScaling, A::$t1{T, <:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .- parent(A)
min_size = minimum(size(B))
gpu_call(kernel_unittriangular, B, J, -one(eltype(B)), min_size; total_threads=min_size)
gpu_call(kernel_unittriangular, B, J, -one(eltype(B)), min_size; elements=min_size)
return $t2(B)
end
end
@@ -54,15 +54,15 @@ for t in genericwrappers
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return $t(B)
end

function (-)(J::UniformScaling, A::$t{T, <:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .- parent(A)
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return $t(B)
end
end
@@ -73,15 +73,15 @@ function (+)(A::Hermitian{T,<:AbstractGPUMatrix}, J::UniformScaling{<:Complex})
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return B
end

function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,<:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .-parent(A)
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return B
end

@@ -90,14 +90,14 @@ function (+)(A::AbstractGPUMatrix{T}, J::UniformScaling) where T
B = similar(A, typeof(oneunit(T) + J))
copyto!(B, A)
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return B
end

function (-)(J::UniformScaling, A::AbstractGPUMatrix{T}) where T
B = similar(A, typeof(J - oneunit(T)))
B .= .-A
min_size = minimum(size(B))
gpu_call(kernel_generic, B, J, min_size; total_threads=min_size)
gpu_call(kernel_generic, B, J, min_size; elements=min_size)
return B
end
2 changes: 1 addition & 1 deletion test/testsuite/gpuinterface.jl
@@ -10,7 +10,7 @@
end
@test all(x-> x == 2, Array(x))

gpu_call(x; total_threads=N) do ctx, x
gpu_call(x; elements=N) do ctx, x
x[linear_index(ctx)] = 2
return
end