FEA: engine accepts dpnp.ndarray and dpt.usm_ndarray objects as input data. #62

Merged: 9 commits, Nov 29, 2022
5 changes: 2 additions & 3 deletions benchmark/ext_helpers/daal4py.py
@@ -27,13 +27,12 @@


 from sklearn.exceptions import NotSupportedByEngineError
-from sklearn_numba_dpex.kmeans.engine import KMeansEngine
+from sklearn.cluster._kmeans import KMeansCythonEngine
 
 # TODO: instead of relying on monkey patching the default engine, find a way to
 # register a distinct entry point that can load a distinct engine outside of setup.py
 # (impossible ?)
-class DAAL4PYEngine(KMeansEngine):
+class DAAL4PYEngine(KMeansCythonEngine):
     def prepare_fit(self, X, y=None, sample_weight=None):
         if sample_weight is not None and any(sample_weight != sample_weight[0]):
             raise NotSupportedByEngineError(
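For context, a minimal NumPy sketch of the uniformity check used in `prepare_fit` above (hypothetical values):

```python
import numpy as np

uniform = np.full(4, 2.0)
nonuniform = np.array([1.0, 2.0, 2.0, 2.0])

# `any(w != w[0])` is truthy exactly when some weight differs from the first
# one, i.e. when the weights are not uniform.
assert not any(uniform != uniform[0])
assert any(nonuniform != nonuniform[0])
```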
4 changes: 1 addition & 3 deletions benchmark/kmeans.py
@@ -98,8 +98,7 @@ def timeit(self, name, engine_provider=None, is_slow=False):

         print(
             f"Running {name} with parameters sample_weight={self.sample_weight} "
-            f"n_clusters={n_clusters} data_shape={X.shape} max_iter={max_iter} "
-            f"tol={tol} ..."
+            f"n_clusters={n_clusters} data_shape={X.shape} max_iter={max_iter}..."
         )

         with sklearn.config_context(engine_provider=engine_provider):
@@ -178,7 +177,6 @@ def _check_same_fit(self, estimator, name, max_iter, assert_allclose):
 init = "k-means++"  # "k-means++" or "random"
 n_clusters = 127
 max_iter = 100
-tol = 0
 skip_slow = True
 dtype = np.float32
 # NB: it seems that currently the estimators in the benchmark always return
157 changes: 109 additions & 48 deletions sklearn_numba_dpex/common/kernels.py
@@ -78,7 +78,7 @@ def initialize_to_zeros(data):


 @lru_cache
-def make_broadcast_division_1d_2d_kernel(size0, size1, work_group_size):
+def make_broadcast_division_1d_2d_axis0_kernel(size0, size1, work_group_size):
     global_size = math.ceil(size1 / work_group_size) * work_group_size
 
     # NB: inplace.  # Optimized for C-contiguous array and for

@@ -100,6 +100,34 @@ def broadcast_division(dividend_array, divisor_vector):
     return broadcast_division[global_size, work_group_size]
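For intuition, a NumPy sketch of what the in-place division presumably computes; the per-column broadcast is an assumption inferred from the signature and the new `axis0` naming, since the kernel body is collapsed in the diff:

```python
import numpy as np

dividend_array = np.arange(6, dtype=np.float32).reshape(2, 3)  # (size0, size1)
divisor_vector = np.array([1.0, 2.0, 4.0], dtype=np.float32)   # (size1,)

# Column j of the dividend is divided in place by divisor_vector[j].
dividend_array /= divisor_vector
```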


+@lru_cache
+def make_broadcast_ops_1d_2d_axis1_kernel(size0, size1, ops, work_group_size):
+    """`ops` must be a function that will be interpreted as a `dpex.func` and is
+    subject to the same rules. It is expected to take two scalar arguments and
+    return one scalar value. Lambda functions are advised against, since the
+    cache will not work with lambda functions."""
+
+    global_size = math.ceil(size1 / work_group_size) * work_group_size
+    ops = dpex.func(ops)
+
+    # NB: inplace.  # Optimized for C-contiguous array and for
+    # size1 >> preferred_work_group_size_multiple
+    @dpex.kernel
+    def broadcast_ops(left_operand_array, right_operand_vector):

Review comment (Member): I think it's worth indicating that the left operand is modified in place and that the right one isn't modified at all.

+        col_idx = dpex.get_global_id(zero_idx)
+
+        if col_idx >= size1:
+            return
+
+        for row_idx in range(size0):
+            left_operand_array[row_idx, col_idx] = ops(
+                left_operand_array[row_idx, col_idx], right_operand_vector[row_idx]
+            )
+
+    return broadcast_ops[global_size, work_group_size]
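A plain-Python sketch of the kernel's semantics, with a hypothetical `ops` (the left operand is updated in place, the right vector holds one entry per row and is read-only):

```python
import numpy as np

def ops(x, y):
    # Example binary op; the kernel accepts any two-scalar -> scalar function.
    return x - y

left = np.ones((2, 3), dtype=np.float32)          # (size0, size1), modified in place
right = np.array([10.0, 20.0], dtype=np.float32)  # (size0,), never written

# Loop equivalent of the kernel: one work item per column, looping over rows.
for col in range(left.shape[1]):
    for row in range(left.shape[0]):
        left[row, col] = ops(left[row, col], right[row])
```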


 @lru_cache
 def make_half_l2_norm_2d_axis0_kernel(size0, size1, work_group_size, dtype):
     global_size = math.ceil(size1 / work_group_size) * work_group_size

@@ -132,38 +160,9 @@ def half_l2_norm(
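The body of this kernel is collapsed in the diff; judging from the name alone, a NumPy sketch of the presumed column-wise computation (an assumption, not confirmed by the visible diff):

```python
import numpy as np

X = np.arange(6, dtype=np.float32).reshape(2, 3)  # (size0, size1)

# Presumed result: half of the squared L2 norm of each column.
half_l2_norm = 0.5 * (X ** 2).sum(axis=0)
```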


 @lru_cache
-def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dtype):
-    """Implement data_2d.sum(axis=1) or data_1d.sum().
-
-    numba_dpex does not provide tools such as `cuda.reduce`, so we implement a
-    reduction strategy from scratch. The strategy relies on the commutativity of
-    the operation used for the reduction, thus allowing the input to be reduced in
-    any order.
-
-    The strategy consists in performing local reductions in each work group using
-    local memory, where each work item combines two values, thus halving the number
-    of values and the number of active work items. At each iteration the work items
-    are discarded in a bracket-like manner. The work items with the greatest ids are
-    discarded first, and we rely on the fact that the remaining work items are
-    adjacent to optimize the RW operations.
-
-    Once the reduction is done in a work group, the result is written to global
-    memory, thus creating an intermediary result whose size is divided by
-    `2 * work_group_size`. This is repeated as many times as needed until only one
-    value remains in global memory.
-
-    Notes
-    -----
-    `work_group_size` is assumed to be a power of 2.
-
-    If `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not
-    None then the expected shape of input tensors is `(size0, size1)`, and the
-    reduction operation is equivalent to input.sum(axis=1). In this case, the
-    kernel is a good choice if `size1` >> `preferred_work_group_size_multiple`, and
-    if `size0` is of the same order of magnitude as
-    `preferred_work_group_size_multiple`. If not, other reduction implementations
-    might give better performance.
-    """
-    check_power_of_2(work_group_size)
-
+def _make_partial_sum_reduction_2d_axis1_kernel(
+    n_rows, work_group_size, fused_unary_func, dtype
+):
     # Number of iterations in each execution of the kernel:
     local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1)

@@ -172,12 +171,16 @@ def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dtype):
     minus_one_idx = np.int64(-1)
     two_as_a_long = np.int64(2)

-    is_1d = size1 is None
+    if fused_unary_func is None:

Review comment (Member): Can you document fused_unary_func, please?

Reply (@fcharras, Collaborator, Author, Nov 25, 2022): It's documented in the public function that follows. For readability it would be better, I think, to swap the two definitions, but that would have increased the diff and masked the true changes.

+        def fused_unary_func(x):
+            return x
+
+    fused_unary_func = dpex.func(fused_unary_func)

     # TODO: this set of kernel functions could be abstracted away to other coalescing
     # functions
-    if is_1d:
-        sum_axis_size = size0
-        n_rows = np.int64(1)
+    if n_rows is None:  # 1d
         local_values_size = work_group_size

         @dpex.func

@@ -186,7 +189,7 @@ def set_col_to_zero(array, i):

         @dpex.func
         def copy_col(from_array, from_col, to_array, to_col):
-            to_array[to_col] = from_array[from_col]
+            to_array[to_col] = fused_unary_func(from_array[from_col])

         @dpex.func
         def add_cols(

@@ -196,7 +199,9 @@ def add_cols(
             to_array,
             to_col,
         ):
-            to_array[to_col] = from_array[left_from_col] + from_array[right_from_col]
+            to_array[to_col] = fused_unary_func(
+                from_array[left_from_col]
+            ) + fused_unary_func(from_array[right_from_col])

         @dpex.func
         def add_cols_inplace(

@@ -211,9 +216,7 @@ def add_first_cols(from_array, to_array, to_col):
             to_array[to_col] = from_array[zero_idx] + from_array[one_idx]

     else:
-        sum_axis_size = size1
-        n_rows = size0
-        local_values_size = (size0, work_group_size)
+        local_values_size = (n_rows, work_group_size)

         @dpex.func
         def set_col_to_zero(array, i):

@@ -223,7 +226,7 @@ def set_col_to_zero(array, i):
         @dpex.func
         def copy_col(from_array, from_col, to_array, to_col):
             for row in range(n_rows):
-                to_array[row, to_col] = from_array[row, from_col]
+                to_array[row, to_col] = fused_unary_func(from_array[row, from_col])

         @dpex.func
         def add_cols(

@@ -234,9 +237,9 @@ def add_cols(
             to_col,
         ):
             for row in range(n_rows):
-                to_array[row, to_col] = (
-                    from_array[row, left_from_col] + from_array[row, right_from_col]
-                )
+                to_array[row, to_col] = fused_unary_func(
+                    from_array[row, left_from_col]
+                ) + fused_unary_func(from_array[row, right_from_col])

         @dpex.func
         def add_cols_inplace(

@@ -304,6 +307,62 @@ def partial_sum_reduction(
         if first_work_id:
             add_first_cols(local_values, result, group_id)

+    return partial_sum_reduction
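A simplified, sequentialized NumPy sketch of the local reduction each work group performs, as described in the reduction docstring (on device, the `n_active` work items of an iteration run in parallel):

```python
import numpy as np

work_group_size = 8  # power of 2, as the docstring requires
local_values = np.arange(work_group_size, dtype=np.float32)  # local memory

n_active = work_group_size // 2
while n_active >= 1:
    # Work items with the greatest ids drop out first; the remaining active
    # items are adjacent, which keeps memory accesses coalesced.
    for work_item_id in range(n_active):  # parallel on device
        local_values[work_item_id] += local_values[work_item_id + n_active]
    n_active //= 2

assert local_values[0] == float(np.arange(work_group_size).sum())
```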


+@lru_cache
+def make_sum_reduction_2d_axis1_kernel(
+    size0, size1, work_group_size, device, dtype, fused_unary_func=None
+):

Comment (@fcharras, Collaborator, Author, Nov 23, 2022): The changes to this kernel enable the new argument fused_unary_func, which allows fusing a unary op on the elements of the input before summing. It's used in this PR to compute the variance for scaling the tolerance.
"""Implement data_2d.sum(axis=1) or data_1d.sum()

numba_dpex does not provide tools such as `cuda.reduce` so we implement from scratch
a reduction strategy. The strategy relies on the commutativity of the operation used
for the reduction, thus allowing to reduce the input in any order.

The strategy consists in performing local reductions in each work group using local
memory where each work item combine two values, thus halving the number of values,
and the number of active work items. At each iteration the work items are discarded
in a bracket manner. The work items with the greatest ids are discarded first, and
we rely on the fact that the remaining work items are adjacents to optimize the RW
operations.

Once the reduction is done in a work group the result is written in global memory,
thus creating an intermediary result whose size is divided by
`2 * work_group_size`. This is repeated as many time as needed until only one value
remains in global memory.

If fused_unary_func is not None, it will be applied element-wise before summing.
It must be a function that will be interpreted as a dpex.func and is subject to the
same rules. It is expected to take one scalar argument and returning one scalar
value. lambda functions are advised against since the cache will not work with
lambda functions.

Notes
-----
`work_group_size` is assumed to be a power of 2.

if `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not None
then the expected shape of input tensors is `(size0, size1)`, and the reduction
operation is equivalent to input.sum(axis=1). In this case, the kernel is a good
choice if `size1` >> `preferred_work_group_size_multiple`, and if `size0` ranges in
the same order of magnitude than `preferred_work_group_size_multiple`. If not,
other reduction implementations might give better performances.
"""
+    check_power_of_2(work_group_size)
+    n_rows = size0 if size1 is not None else None
+    sum_axis_size = size0 if n_rows is None else size1
+
+    # fused_unary_func is applied elementwise during the first pass on data, in the
+    # first kernel execution only.
+    fused_func_kernel = _make_partial_sum_reduction_2d_axis1_kernel(
+        n_rows, work_group_size, fused_unary_func, dtype
+    )
+    # Subsequent kernel calls only sum the data.
+    nofunc_kernel = _make_partial_sum_reduction_2d_axis1_kernel(
+        n_rows, work_group_size, None, dtype

Review suggestion (Member):
-        n_rows, work_group_size, None, dtype
+        n_rows, work_group_size, fused_unary_func=None, dtype=dtype

+    )
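As the author's comment above notes, the fused unary op serves to compute a variance-like quantity for scaling the tolerance. A NumPy sketch of the idea, with hypothetical data and a hypothetical `squared` function standing in for fused_unary_func (the kernel folds it into the first reduction pass instead of materializing an intermediate array):

```python
import numpy as np

def squared(x):  # hypothetical fused_unary_func
    return x * x

X = np.linspace(-1.0, 1.0, 12, dtype=np.float32).reshape(3, 4)

# Unfused equivalent: materialize X ** 2, then reduce along axis 1.
expected = (X.astype(np.float64) ** 2).sum(axis=1)

# What the fused first pass computes, element by element, with no temporary:
fused = np.zeros(3)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        fused[i] += squared(float(X[i, j]))

assert np.allclose(expected, fused)
```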

+    # As many partial reductions as necessary are chained until only one element
+    # remains.
+    kernels_and_empty_tensors_pairs = []

@@ -312,18 +371,20 @@ def partial_sum_reduction(
     # running the reduction iteration. At this point the loop should stop and then a
     # single work item should iterate one time on the remaining values to finish the
     # reduction.
+    kernel = fused_func_kernel
     while n_groups > 1:
         n_groups = math.ceil(n_groups / (2 * work_group_size))
         global_size = n_groups * work_group_size
-        kernel = partial_sum_reduction[global_size, work_group_size]
-        result_shape = n_groups if is_1d else (n_rows, n_groups)
+        kernel = kernel[global_size, work_group_size]
+        result_shape = n_groups if n_rows is None else (n_rows, n_groups)
         # NB: here memory for partial results is allocated ahead of time and will
         # only be garbage collected when the instance of `sum_reduction` is garbage
         # collected. Thus it can be more efficient to re-use the same instance of
         # `sum_reduction` (e.g. within iterations of a loop) since it avoids
         # deallocation and reallocation every time.
         result = dpt.empty(result_shape, dtype=dtype, device=device)
         kernels_and_empty_tensors_pairs.append((kernel, result))
+        kernel = nofunc_kernel

     def sum_reduction(summands):
         # TODO: manually dispatch the kernels with a SyclQueue
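Putting the chaining logic together, a small NumPy sketch of how the partial reductions shrink the input across successive kernel launches (a simplified host-side model; the real implementation pre-allocates each intermediate with dpt.empty and dispatches kernels on device):

```python
import math
import numpy as np

work_group_size = 4  # power of 2
values = np.arange(37, dtype=np.float64)

# Each launch reduces every block of 2 * work_group_size values to a single
# value, shrinking the intermediate result until one element remains.
while values.size > 1:
    n_groups = math.ceil(values.size / (2 * work_group_size))
    block = 2 * work_group_size
    values = np.array(
        [values[g * block:(g + 1) * block].sum() for g in range(n_groups)]
    )

assert values[0] == sum(range(37))
```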