Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Revert "Split translate function"
Browse files Browse the repository at this point in the history
This reverts commit 3f64767.
  • Loading branch information
thomasfaingnaert committed May 25, 2020
1 parent 3f64767 commit 2acd20b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 23 deletions.
8 changes: 4 additions & 4 deletions src/device/matmul_kernels/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function matmul_impl(a, b, c, d,

@unroll for i = 1 : NUM_FRAGMENTS_M
@unroll for j = 1 : NUM_FRAGMENTS_N
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
end
end
Expand Down Expand Up @@ -84,15 +84,15 @@ function matmul_impl(a, b, c, d,
a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)

@unroll for i = 1 : NUM_FRAGMENTS_M
a_tile = translate_const(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
end

# (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)

@unroll for j = 1 : NUM_FRAGMENTS_N
b_tile = translate_const(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
end

Expand All @@ -114,7 +114,7 @@ function matmul_impl(a, b, c, d,

@unroll for i = 1 : NUM_FRAGMENTS_M
@unroll for j = 1 : NUM_FRAGMENTS_N
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
end
end
Expand Down
20 changes: 1 addition & 19 deletions src/device/tiling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export translate
"""
translate(tile::Tile{names, T}, offset::NamedTuple{names, T})
Translate (i.e. move) a [`Tile`](@ref) by an `offset`.
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
# Arguments
- `tile`: The [`Tile`](@ref) to translate.
Expand All @@ -132,24 +132,6 @@ end

# Convenience method: accept the offset as a plain tuple, label its entries with
# the tile's dimension names, and forward to the NamedTuple method of `translate`.
@inline function translate(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size}
    named_offset = NamedTuple{names}(offset)
    return translate(tile, named_offset)
end

export translate_const

"""
    translate_const(tile::Tile{size, names, T}, offset::NamedTuple{names, T})
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
# Arguments
- `tile`: The [`Tile`](@ref) to translate.
- `offset`: The `offset` in each dimension.
"""
@inline function translate_const(tile::Tile{size, names, T}, offset::NamedTuple{names, T}) where {names, T, size}
    # Add the requested shift to the tile's existing offset, entry by entry.
    shifted = map(+, tile.offset, offset)

    # The base is carried over unchanged; only the offset differs in the result.
    return Tile{size, names, T}(tile.base, shifted)
end

# Convenience method: accept the offset as a plain tuple, label its entries with
# the tile's dimension names, and forward to the NamedTuple method of `translate_const`.
@inline function translate_const(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size}
    named_offset = NamedTuple{names}(offset)
    return translate_const(tile, named_offset)
end

# -------------
# TileIterators
# -------------
Expand Down

0 comments on commit 2acd20b

Please sign in to comment.