From c0e956b68a224677678a2e310cf6a713edcbc058 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Wed, 23 Nov 2022 18:18:36 +0100 Subject: [PATCH 1/7] When using KMeans with sklearn_numba_dpex engine, it's now possible to pass dpnp.ndarray and dpt.usm_ndarray objects as input data. --- benchmark/ext_helpers/daal4py.py | 47 +- benchmark/kmeans.py | 4 +- sklearn_numba_dpex/common/kernels.py | 157 +++++-- sklearn_numba_dpex/kmeans/drivers.py | 79 +++- sklearn_numba_dpex/kmeans/engine.py | 424 ++++++++++-------- .../kmeans/tests/test_kmeans.py | 47 +- sklearn_numba_dpex/testing/config.py | 23 +- 7 files changed, 490 insertions(+), 291 deletions(-) diff --git a/benchmark/ext_helpers/daal4py.py b/benchmark/ext_helpers/daal4py.py index a2955dc..a94ce3d 100644 --- a/benchmark/ext_helpers/daal4py.py +++ b/benchmark/ext_helpers/daal4py.py @@ -1,33 +1,38 @@ import warnings import sklearn + # HACK: daal4py will fail to import with too recent versions of sklearn because # of this missing attribute. Let's pretend it still exists. - -if not hasattr(sklearn.neighbors._base, "_check_weights"): - warnings.warn( - f"The current version of scikit-learn ( =={sklearn.__version__} ) is too " - "recent to ensure good compatibility with sklearn intelex, who only supports " - "sklearn >=0.22, <1.1 . Use very cautiously, things might not work as " - "expected...", - RuntimeWarning, - ) - try: +try: + if not hasattr(sklearn.neighbors._base, "_check_weights"): + warnings.warn( + f"The current version of scikit-learn ( =={sklearn.__version__} ) is too " + "recent to ensure good compatibility with sklearn intelex, who only supports " + "sklearn >=0.22, <1.1 . Use very cautiously, things might not work as " + "expected...", + RuntimeWarning, + ) + _daal4py_kmeans_compat_mode = True sklearn.neighbors._base._check_weights = None - from daal4py.sklearn.cluster._k_means_0_23 import KMeans - finally: + from daal4py.sklearn.cluster._k_means_0_23 import ( + KMeans, + _daal4py_compute_starting_centroids, + getFPType, + ) +finally: + if _daal4py_kmeans_compat_mode: del sklearn.neighbors._base._check_weights -else: - from daal4py.sklearn.cluster._k_means_0_23 import KMeans + from sklearn.exceptions import NotSupportedByEngineError -from sklearn_numba_dpex.kmeans.engine import KMeansEngine +from sklearn.cluster._kmeans import KMeansCythonEngine # TODO: instead of relying on monkey patching the default engine, find a way to # register a distinct entry point that can load a distinct engine outside of setup.py # (impossible ?) -class DAAL4PYEngine(KMeansEngine): +class DAAL4PYEngine(KMeansCythonEngine): def prepare_fit(self, X, y=None, sample_weight=None): if sample_weight is not None and any(sample_weight != sample_weight[0]): raise NotSupportedByEngineError( @@ -37,7 +42,15 @@ def prepare_fit(self, X, y=None, sample_weight=None): return super().prepare_fit(X, y, sample_weight) def init_centroids(self, X): - return super(KMeansEngine, self).init_centroids(X) + _, centroids = _daal4py_compute_starting_centroids( + X, + getFPType(X), + self.estimator.n_clusters, + self.estimator.init, + self.estimator.verbose, + self.random_state, + ) + return centroids def kmeans_single(self, X, sample_weight, centers_init): diff --git a/benchmark/kmeans.py b/benchmark/kmeans.py index c26276d..b40b716 100644 --- a/benchmark/kmeans.py +++ b/benchmark/kmeans.py @@ -98,8 +98,7 @@ def timeit(self, name, engine_provider=None, is_slow=False): print( f"Running {name} with parameters sample_weight={self.sample_weight} " - f"n_clusters={n_clusters} data_shape={X.shape} max_iter={max_iter} " - f"tol={tol} ..." + f"n_clusters={n_clusters} data_shape={X.shape} max_iter={max_iter}..." ) with sklearn.config_context(engine_provider=engine_provider): @@ -178,7 +177,6 @@ def _check_same_fit(self, estimator, name, max_iter, assert_allclose): init = "k-means++" # "k-means++" or "random" n_clusters = 127 max_iter = 100 - tol = 0 skip_slow = True dtype = np.float32 # NB: it seems that currently the estimators in the benchmark always return diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py index 0b801dd..b0ab71e 100644 --- a/sklearn_numba_dpex/common/kernels.py +++ b/sklearn_numba_dpex/common/kernels.py @@ -78,7 +78,7 @@ def initialize_to_zeros(data): @lru_cache -def make_broadcast_division_1d_2d_kernel(size0, size1, work_group_size): +def make_broadcast_division_1d_2d_axis0_kernel(size0, size1, work_group_size): global_size = math.ceil(size1 / work_group_size) * work_group_size # NB: inplace. # Optimized for C-contiguous array and for @@ -100,6 +100,42 @@ def broadcast_division(dividend_array, divisor_vector): return broadcast_division[global_size, work_group_size] +@lru_cache +def make_broadcast_ops_1d_2d_axis1_kernel(size0, size1, ops, work_group_size): + global_size = math.ceil(size1 / work_group_size) * work_group_size + + if ops == "plus": + + @dpex.func + def ops(augend, addend): + return augend + addend + + elif ops == "minus": + + @dpex.func + def ops(minuend, subtrahend): + return minuend - subtrahend + + else: + raise ValueError(f"Invalid ops: {ops} .") + + # NB: inplace. # Optimized for C-contiguous array and for + # size1 >> preferred_work_group_size_multiple + @dpex.kernel + def broadcast_ops(left_operand_array, right_operand_vector): + col_idx = dpex.get_global_id(zero_idx) + + if col_idx >= size1: + return + + for row_idx in range(size0): + left_operand_array[row_idx, col_idx] = ops( + left_operand_array[row_idx, col_idx], right_operand_vector[row_idx] + ) + + return broadcast_ops[global_size, work_group_size] + + @lru_cache def make_half_l2_norm_2d_axis0_kernel(size0, size1, work_group_size, dtype): global_size = math.ceil(size1 / work_group_size) * work_group_size @@ -132,38 +168,9 @@ def half_l2_norm( @lru_cache -def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dtype): - """Implement data_2d.sum(axis=1) or data_1d.sum() - - numba_dpex does not provide tools such as `cuda.reduce` so we implement from scratch - a reduction strategy. The strategy relies on the commutativity of the operation used - for the reduction, thus allowing to reduce the input in any order. - - The strategy consists in performing local reductions in each work group using local - memory where each work item combine two values, thus halving the number of values, - and the number of active work items. At each iteration the work items are discarded - in a bracket manner. The work items with the greatest ids are discarded first, and - we rely on the fact that the remaining work items are adjacents to optimize the RW - operations. - - Once the reduction is done in a work group the result is written in global memory, - thus creating an intermediary result whose size is divided by - `2 * work_group_size`. This is repeated as many time as needed until only one value - remains in global memory. - - Notes - ----- - `work_group_size` is assumed to be a power of 2. - - if `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not None - then the expected shape of input tensors is `(size0, size1)`, and the reduction - operation is equivalent to input.sum(axis=1). In this case, the kernel is a good - choice if `size1` >> `preferred_work_group_size_multiple`, and if `size0` ranges in - the same order of magnitude than `preferred_work_group_size_multiple`. If not, - other reduction implementations might give better performances. - """ - check_power_of_2(work_group_size) - +def _make_partial_sum_reduction_2d_axis1_kernel( + n_rows, work_group_size, fused_unary_func, dtype +): # Number of iteration in each execution of the kernel: local_n_iterations = np.int64(math.floor(math.log2(work_group_size)) - 1) @@ -172,12 +179,17 @@ def make_sum_reduction_2d_axis1_kernel(size0, size1, work_group_size, device, dt minus_one_idx = np.int64(-1) two_as_a_long = np.int64(2) - is_1d = size1 is None + if fused_unary_func is not None: + fused_unary_func = dpex.func(fused_unary_func) + else: + + @dpex.func + def fused_unary_func(x): + return x + # TODO: this set of kernel functions could be abstracted away to other coalescing # functions - if is_1d: - sum_axis_size = size0 - n_rows = np.int64(1) + if n_rows is None: # 1d local_values_size = work_group_size @dpex.func @@ -186,7 +198,7 @@ def set_col_to_zero(array, i): @dpex.func def copy_col(from_array, from_col, to_array, to_col): - to_array[to_col] = from_array[from_col] + to_array[to_col] = fused_unary_func(from_array[from_col]) @dpex.func def add_cols( @@ -196,7 +208,9 @@ def add_cols( to_array, to_col, ): - to_array[to_col] = from_array[left_from_col] + from_array[right_from_col] + to_array[to_col] = fused_unary_func( + from_array[left_from_col] + ) + fused_unary_func(from_array[right_from_col]) @dpex.func def add_cols_inplace( @@ -211,9 +225,7 @@ def add_first_cols(from_array, to_array, to_col): to_array[to_col] = from_array[zero_idx] + from_array[one_idx] else: - sum_axis_size = size1 - n_rows = size0 - local_values_size = (size0, work_group_size) + local_values_size = (n_rows, work_group_size) @dpex.func def set_col_to_zero(array, i): @@ -223,7 +235,7 @@ def set_col_to_zero(array, i): @dpex.func def copy_col(from_array, from_col, to_array, to_col): for row in range(n_rows): - to_array[row, to_col] = from_array[row, from_col] + to_array[row, to_col] = fused_unary_func(from_array[row, from_col]) @dpex.func def add_cols( @@ -234,9 +246,9 @@ def add_cols( to_col, ): for row in range(n_rows): - to_array[row, to_col] = ( - from_array[row, left_from_col] + from_array[row, right_from_col] - ) + to_array[row, to_col] = fused_unary_func( + from_array[row, left_from_col] + ) + fused_unary_func(from_array[row, right_from_col]) @dpex.func def add_cols_inplace( @@ -304,6 +316,53 @@ def partial_sum_reduction( if first_work_id: add_first_cols(local_values, result, group_id) + return partial_sum_reduction + + +@lru_cache +def make_sum_reduction_2d_axis1_kernel( + size0, size1, work_group_size, device, dtype, fused_unary_func=None +): + """Implement data_2d.sum(axis=1) or data_1d.sum() + + numba_dpex does not provide tools such as `cuda.reduce` so we implement from scratch + a reduction strategy. The strategy relies on the commutativity of the operation used + for the reduction, thus allowing to reduce the input in any order. + + The strategy consists in performing local reductions in each work group using local + memory where each work item combine two values, thus halving the number of values, + and the number of active work items. At each iteration the work items are discarded + in a bracket manner. The work items with the greatest ids are discarded first, and + we rely on the fact that the remaining work items are adjacents to optimize the RW + operations. + + Once the reduction is done in a work group the result is written in global memory, + thus creating an intermediary result whose size is divided by + `2 * work_group_size`. This is repeated as many time as needed until only one value + remains in global memory. + + Notes + ----- + `work_group_size` is assumed to be a power of 2. + + if `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not None + then the expected shape of input tensors is `(size0, size1)`, and the reduction + operation is equivalent to input.sum(axis=1). In this case, the kernel is a good + choice if `size1` >> `preferred_work_group_size_multiple`, and if `size0` ranges in + the same order of magnitude than `preferred_work_group_size_multiple`. If not, + other reduction implementations might give better performances. + """ + check_power_of_2(work_group_size) + n_rows = size0 if size1 is not None else None + sum_axis_size = size0 if n_rows is None else size1 + + fused_func_kernel = _make_partial_sum_reduction_2d_axis1_kernel( + n_rows, work_group_size, fused_unary_func, dtype + ) + nofunc_kernel = _make_partial_sum_reduction_2d_axis1_kernel( + n_rows, work_group_size, None, dtype + ) + # As many partial reductions as necessary are chained until only one element # remains. kernels_and_empty_tensors_pairs = [] @@ -312,11 +371,12 @@ def partial_sum_reduction( # running the reduction iteration. At this point the loop should stop and then a # single work item should iterates one time on the remaining values to finish the # reduction. + kernel = fused_func_kernel while n_groups > 1: n_groups = math.ceil(n_groups / (2 * work_group_size)) global_size = n_groups * work_group_size - kernel = partial_sum_reduction[global_size, work_group_size] - result_shape = n_groups if is_1d else (n_rows, n_groups) + kernel = kernel[global_size, work_group_size] + result_shape = n_groups if n_rows is None else (n_rows, n_groups) # NB: here memory for partial results is allocated ahead of time and will only # be garbage collected when the instance of `sum_reduction` is garbage # collected. Thus it can be more efficient to re-use a same instance of @@ -324,6 +384,7 @@ def partial_sum_reduction( # and reallocation every time. result = dpt.empty(result_shape, dtype=dtype, device=device) kernels_and_empty_tensors_pairs.append((kernel, result)) + kernel = nofunc_kernel def sum_reduction(summands): # TODO: manually dispatch the kernels with a SyclQueue diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py index 74434fb..7c7c940 100644 --- a/sklearn_numba_dpex/kmeans/drivers.py +++ b/sklearn_numba_dpex/kmeans/drivers.py @@ -14,7 +14,8 @@ from sklearn_numba_dpex.common.kernels import ( make_initialize_to_zeros_2d_kernel, make_initialize_to_zeros_3d_kernel, - make_broadcast_division_1d_2d_kernel, + make_broadcast_division_1d_2d_axis0_kernel, + make_broadcast_ops_1d_2d_axis1_kernel, make_half_l2_norm_2d_axis0_kernel, make_sum_reduction_2d_axis1_kernel, make_argmin_reduction_1d_kernel, @@ -100,7 +101,7 @@ def lloyd( dtype=compute_dtype, ) - broadcast_division_kernel = make_broadcast_division_1d_2d_kernel( + broadcast_division_kernel = make_broadcast_division_1d_2d_axis0_kernel( size0=n_features, size1=n_clusters, work_group_size=max_work_group_size, @@ -324,7 +325,77 @@ def lloyd( # inertia is now a 1-sized numpy array, we transform it into a scalar: inertia = inertia[0] - return assignments_idx, inertia, centroids_t.T, n_iteration + return assignments_idx, inertia, centroids_t, n_iteration + + +def prepare_data_for_lloyd(X_t, init, tol, copy_x): + n_features, n_samples = X_t.shape + compute_dtype = X_t.dtype.type + + device = X_t.device.sycl_device + max_work_group_size = device.max_work_group_size + + sum_axis1_kernel = make_sum_reduction_2d_axis1_kernel( + X_t.shape[0], + X_t.shape[1], + device.max_work_group_size, + device=device, + dtype=compute_dtype, + ) + + X_mean = (sum_axis1_kernel(X_t) / compute_dtype(n_samples))[:, 0] + + if (X_mean == 0).astype(int).sum() == len(X_mean): + X_mean = None + else: + X_t = dpt.asarray(X_t, copy=copy_x) + broadcast_X_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_samples, + ops="minus", + work_group_size=max_work_group_size, + ) + + broadcast_X_minus_X_mean(X_t, X_mean) + + if isinstance(init, dpt.usm_ndarray): + n_clusters = init.shape[1] + broadcast_init_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_clusters, + ops="minus", + work_group_size=max_work_group_size, + ) + broadcast_init_minus_X_mean(init, X_mean) + + variance_kernel = make_sum_reduction_2d_axis1_kernel( + n_features * n_samples, + None, + max_work_group_size, + device=device, + dtype=compute_dtype, + fused_unary_func=lambda x: x * x, + ) + variance = variance_kernel(dpt.reshape(X_t, -1)) / n_features + tol = variance * tol + + return X_t, X_mean, init, tol + + +def restore_data_after_lloyd(X_t, X_mean): + n_features, n_samples = X_t.shape + + device = X_t.device.sycl_device + max_work_group_size = device.max_work_group_size + + X_t = dpt.asarray(X_t, copy=False) + broadcast_X_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_samples, + ops="plus", + work_group_size=max_work_group_size, + ) + broadcast_X_plus_X_mean(X_t, X_mean) def _relocate_empty_clusters( @@ -663,4 +734,4 @@ def kmeans_plusplus( centers_t[:, c] = X_t[:, center_index] center_indices[c] = center_index - return centers_t.T, center_indices + return centers_t, center_indices diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py index 41338e5..f8c381c 100644 --- a/sklearn_numba_dpex/kmeans/engine.py +++ b/sklearn_numba_dpex/kmeans/engine.py @@ -1,14 +1,29 @@ -import warnings +import numbers +import contextlib +import importlib import numpy as np import dpctl import dpctl.tensor as dpt +import sklearn +import sklearn.utils.validation as sklearn_validation from sklearn.cluster._kmeans import KMeansCythonEngine +from sklearn.utils import check_random_state, check_array +from sklearn.utils.validation import _is_arraylike_not_scalar -from sklearn.exceptions import NotSupportedByEngineError, DataConversionWarning +from sklearn.exceptions import NotSupportedByEngineError -from .drivers import lloyd, get_labels_inertia, get_euclidean_distances, kmeans_plusplus +from sklearn_numba_dpex.testing.config import override_attr_context + +from .drivers import ( + prepare_data_for_lloyd, + lloyd, + restore_data_after_lloyd, + get_labels_inertia, + get_euclidean_distances, + kmeans_plusplus, +) class _IgnoreSampleWeight: @@ -51,32 +66,29 @@ class KMeansEngine(KMeansCythonEngine): def __init__(self, estimator): self.device = dpctl.SyclDevice(self._CONFIG.get("device")) + + # NB: numba_dpex kernels only currently supports working with C memory layout + # (see https://github.com/IntelPython/numba-dpex/issues/767) but our KMeans + # implementation is hypothetized to be more efficient with the F-memory layout. + # As a workaround the kernels work with the transpose of X, X_t, where X_t + # is created with a C layout, which results in equivalent memory access + # patterns than with a F layout for X. + # TODO: when numba_dpex supports inputs with F-layout: + # - use X rather than X_t and adapt the codebase (better for readability and + # more consistent with sklearn notations) + # - test the performances with both layouts and use the best performing layout. + order = self._CONFIG.get("order", "F") + if order != "F": + raise ValueError( + "Kernels compiled by numba_dpex called on an input array with " + "the Fortran memory layout silently return incorrect results: " + "https://github.com/IntelPython/numba-dpex/issues/767" + ) + self.order = order super().__init__(estimator) def prepare_fit(self, X, y=None, sample_weight=None): estimator = self.estimator - try: - # This pass of data validation only aims at detecting input types that are - # supported by the sklearn engine but not by KMeansEngine. For those inputs - # we raise a NotSupportedByEngineError exception. - # estimator._validate_data is called again in super().prepare_fit later on - # and will raise ValueError or TypeError for data that is not even - # compatible with the sklearn engine. - estimator._validate_data( - X, - accept_sparse=False, - dtype=None, - force_all_finite=False, - ensure_2d=False, - allow_nd=True, - ensure_min_samples=0, - ensure_min_features=0, - estimator=estimator, - ) - except Exception as e: - raise NotSupportedByEngineError( - "The sklearn_nunmba_dpex engine for KMeans does not support the format of the inputed data." - ) from e algorithm = estimator.algorithm if algorithm not in ("lloyd", "auto", "full"): @@ -84,64 +96,97 @@ def prepare_fit(self, X, y=None, sample_weight=None): f"The sklearn_nunmba_dpex engine for KMeans only support the Lloyd algorithm, {algorithm} is not supported." ) - self.sample_weight = sample_weight + # NB: self.estimator.copy_x is enforced later on only if X_mean has at + # least one non-null value + X = self._validate_data(X) + estimator._check_params_vs_input(X) - return super().prepare_fit(X, y, sample_weight) + self.sample_weight = self._check_sample_weight(sample_weight, X) - def init_centroids(self, X): - init = self.init + init = self.estimator.init + init_is_array_like = _is_arraylike_not_scalar(init) + if init_is_array_like: + init = self._check_init(init, X) + + X_t, X_mean, self.init, self.tol = prepare_data_for_lloyd( + X.T, init, estimator.tol, estimator.copy_x + ) - if isinstance(init, str) and init == "k-means++": - centers, _ = self._kmeans_plusplus(X) + self.X_mean = X_mean - else: - centers = self.estimator._init_centroids( - X, - x_squared_norms=self.x_squared_norms, - init=init, - random_state=self.random_state, - ) - return centers + self.random_state = check_random_state(estimator.random_state) - def _kmeans_plusplus(self, X): - n_clusters = self.estimator.n_clusters + return X_t.T, y, self.sample_weight - if X.shape[0] < n_clusters: - raise ValueError( - f"n_samples={X.shape[0]} should be >= n_clusters={n_clusters}." - ) + def unshift_centers(self, X, best_centers): + if (X_mean := self.X_mean) is None: + return - X, sample_weight, _, output_dtype = self._check_inputs( - X, self.sample_weight, cluster_centers=None - ) + best_centers += dpt.asnumpy(X_mean.get_array()) - X_t, sample_weight, _ = self._load_transposed_data_to_device( - X, sample_weight, cluster_centers=None - ) + # NB: self.estimator.copy_x being set to False does not mean that no copy + # actually happened, only that no copy was forced if it was not necessary + # with respect to what device, dtype and order that are required at compute + # time. Nevertheless, there's no simple way to check if a copy happened + # without assumptions on the type of the raw input submitted by the user, + # but at the moment it is unknown what those assumptions could be. + # As a result, the following instructions are ran every time, even if it + # isn't useful when a copy has been made. + # TODO: is there a set of assumptions that exhaustively describe the set + # of accepted inputs, and also enables checking if a copy happened or not + # in a simple way ? + if not self.estimator.copy_x: + restore_data_after_lloyd(X.T, X_mean) - centers, center_indices = kmeans_plusplus( - X_t, sample_weight, n_clusters, self.random_state - ) + def init_centroids(self, X): + init = self.init + n_clusters = self.estimator.n_clusters - centers = dpt.asnumpy(centers).astype(output_dtype, copy=False) - center_indices = dpt.asnumpy(center_indices) - return centers, center_indices + if isinstance(init, dpt.usm_ndarray): + centers_t = init - def kmeans_single(self, X, sample_weight, centers_init): - X, sample_weight, cluster_centers, output_dtype = self._check_inputs( - X, sample_weight, centers_init - ) + elif isinstance(init, str) and init == "k-means++": + centers_t, _ = self._kmeans_plusplus(X) + + elif callable(init): + centers = init(X, self.estimator.n_clusters, random_state=self.random_state) + centers_t = self._check_init(centers, X) + + else: + # NB: sampling without replacement must be executed sequentially so + # it's better adapted to CPU + centers_idx = self.random_state.choice( + X.shape[0], size=n_clusters, replace=False + ) + # Poor man's fancy indexing + # TODO: write a kernel ? or replace with better equivalent when available ? + centers_t = dpt.concat( + [dpt.expand_dims(X[center_idx], axes=1) for center_idx in centers_idx], + axis=1, + ) - use_uniform_weights = (sample_weight == sample_weight[0]).all() + return centers_t - X_t, sample_weight, centroids_t = self._load_transposed_data_to_device( - X, sample_weight, cluster_centers + def _kmeans_plusplus(self, X): + n_clusters = self.estimator.n_clusters + + centers_t, center_indices = kmeans_plusplus( + X.T, self.sample_weight, n_clusters, self.random_state ) + return centers_t, center_indices + + def kmeans_single(self, X, sample_weight, centers_init_t): + # ???: using `.all()` often segfaults + # TODO: minimal reproducer and issue at dpnp + # or write a kernel ? + use_uniform_weights = (sample_weight == sample_weight[0]).astype( + int + ).sum() == len(sample_weight) assignments_idx, inertia, best_centroids, n_iteration = lloyd( - X_t, + X.T, sample_weight, - centroids_t, + centers_init_t, use_uniform_weights, self.estimator.max_iter, self.estimator.verbose, @@ -156,155 +201,160 @@ def kmeans_single(self, X, sample_weight, centers_init): inertia, # XXX: having a C-contiguous centroid array is expected in sklearn in some # unit test and by the cython engine. - np.ascontiguousarray( - dpt.asnumpy(best_centroids).astype(output_dtype, copy=False) - ), + # ???: rather that returning whatever dtype the driver returns (which might + # depends on device support for float64), shouldn't we cast to a dtype that + # is always consistent with the input ? (e.g. cast to float64 if the input + # was given as float64 ?) But what assumptions can we make on the input + # so we can infer its input dtype without risking triggering a copy of it ? + np.ascontiguousarray(dpt.asnumpy(best_centroids.T)), n_iteration, ) + def prepare_prediction(self, X, sample_weight): + X = self._validate_data(X, reset=False) + sample_weight = self._check_sample_weight(sample_weight, X) + return X, sample_weight + def get_labels(self, X, sample_weight): - labels, _ = self._get_labels_inertia(X, with_inertia=False) + # TODO: sample_weight actually not used for get_labels. Fix in sklearn ? + labels, _ = self._get_labels_inertia(X, sample_weight, with_inertia=False) return dpt.asnumpy(labels).astype(np.int32, copy=False) def get_score(self, X, sample_weight): _, inertia = self._get_labels_inertia(X, sample_weight, with_inertia=True) return inertia - def _get_labels_inertia( - self, X, sample_weight=_IgnoreSampleWeight, with_inertia=True - ): - X, sample_weight, centers, output_dtype = self._check_inputs( - X, - sample_weight=sample_weight, - cluster_centers=self.estimator.cluster_centers_, - ) - - if sample_weight is _IgnoreSampleWeight: - sample_weight = None - - X_t, sample_weight, centroids_t = self._load_transposed_data_to_device( - X, sample_weight, centers + def _get_labels_inertia(self, X, sample_weight, with_inertia=True): + cluster_centers = self._check_init( + self.estimator.cluster_centers_, X, copy=False ) assignments_idx, inertia = get_labels_inertia( - X_t, centroids_t, sample_weight, with_inertia + X.T, cluster_centers, sample_weight, with_inertia ) if with_inertia: # inertia is a 1-sized numpy array, we transform it into a scalar: - inertia = inertia.astype(output_dtype)[0] + inertia = inertia[0] return assignments_idx, inertia + def prepare_transform(self, X): + # TODO: fix fit_transform in sklearn: need to call prepare_transform + # inbetween fit and transform ? or remove prepare_transform ? + return X + def get_euclidean_distances(self, X): - X, _, Y, output_dtype = self._check_inputs( - X, - sample_weight=_IgnoreSampleWeight, - cluster_centers=self.estimator.cluster_centers_, + X = self._validate_data(X, reset=False) + cluster_centers = self._check_init( + self.estimator.cluster_centers_, X, copy=False ) - - X_t, _, Y_t = self._load_transposed_data_to_device(X, None, Y) - - euclidean_distances = get_euclidean_distances(X_t, Y_t) - - return dpt.asnumpy(euclidean_distances).astype(output_dtype, copy=False) - - def _check_inputs(self, X, sample_weight, cluster_centers): - + euclidean_distances = get_euclidean_distances(X.T, cluster_centers) + return dpt.asnumpy(euclidean_distances) + + def _validate_data(self, X, reset=True): + accepted_dtypes = [np.float32] + # NB: one could argue that `float32` is a better default, but sklearn defaults + # to `np.float64` and we is apply the same for consistence. + if self.device.has_aspect_fp64: + accepted_dtypes = [np.float64, np.float32] + else: + accepted_dtypes = [np.float32] + + with _validate_with_array_api(): + try: + X = self.estimator._validate_data( + X, + accept_sparse=False, + dtype=accepted_dtypes, + order=self.order, + copy=False, + reset=reset, + force_all_finite=True, + estimator=self.estimator, + ) + return X + except TypeError as type_error: + if "A sparse matrix was passed, but dense data is required" in str( + type_error + ): + raise NotSupportedByEngineError from TypeError + + def _check_sample_weight(self, sample_weight, X): + """Adapted from sklearn.utils.validation._check_sample_weight to be compatible + with Array API dispatching""" + n_samples = X.shape[0] + dtype = X.dtype if sample_weight is None: - sample_weight = np.ones(len(X), dtype=X.dtype) - - X, sample_weight, cluster_centers, output_dtype = self._set_dtype( - X, sample_weight, cluster_centers - ) - - return X, sample_weight, cluster_centers, output_dtype - - def _set_dtype(self, X, sample_weight, cluster_centers): - output_dtype = compute_dtype = np.dtype(X.dtype).type - copy = True - if (compute_dtype != np.float32) and (compute_dtype != np.float64): - text = ( - f"KMeans has been set to compute with type {compute_dtype} but only " - f"the types float32 and float64 are supported. The computations and " - f"outputs will default back to float32 type." - ) - output_dtype = compute_dtype = np.float32 - elif (compute_dtype == np.float64) and not self.device.has_aspect_fp64: - text = ( - f"KMeans is set to compute with type {compute_dtype} but this type is " - f"not supported by the device {self.device.name}. The computations " - f"will default back to float32 type." - ) - compute_dtype = np.float32 - + sample_weight = dpt.ones(n_samples, dtype=dtype, device=self.device) + elif isinstance(sample_weight, numbers.Number): + sample_weight = dpt.full(n_samples, 1, dtype=dtype, device=self.device) else: - copy = False - - if copy: - text += ( - f" A copy of the data casted to type {compute_dtype} will be created. " - f"To save memory and suppress this warning, ensure that the dtype of " - f"the input data matches the dtype required for computations." - ) - warnings.warn(text, DataConversionWarning) - # TODO: instead of triggering a copy on the host side, we could use the - # dtype to allocate a shared USM buffer and fill it with casted values from - # X. In this case we should only warn when: - # (dtype == np.float64) and not self.has_aspect_fp64 - # The other cases would not trigger any additional memory copies. - X = X.astype(compute_dtype) - - if cluster_centers is not None and ( - (cluster_centers_dtype := cluster_centers.dtype) != compute_dtype - ): - warnings.warn( - f"The centers have been passed with type {cluster_centers_dtype} but " - f"type {compute_dtype} is expected. A copy will be created with the " - f"correct type {compute_dtype}. Ensure that the centers are passed " - f"with the correct dtype to save memory and suppress this warning.", - DataConversionWarning, - ) - cluster_centers = cluster_centers.astype(compute_dtype) - - if (sample_weight is not _IgnoreSampleWeight) and ( - sample_weight.dtype != compute_dtype - ): - warnings.warn( - f"sample_weight has been passed with type {sample_weight.dtype} but " - f"type {compute_dtype} is expected. A copy will be created with the " - f"correct type {compute_dtype}. Ensure that sample_weight is passed " - f"with the correct dtype to save memory and suppress this warning.", - DataConversionWarning, + with _validate_with_array_api(): + sample_weight = check_array( + sample_weight, + accept_sparse=False, + order="C", + dtype=dtype, + force_all_finite=True, + ensure_2d=False, + allow_nd=False, + estimator=self.estimator, + input_name="sample_weight", + ) + + if sample_weight.ndim != 1: + raise ValueError("Sample weights must be 1D array or scalar") + + if sample_weight.shape != (n_samples,): + raise ValueError( + "sample_weight.shape == {}, expected {}!".format( + sample_weight.shape, (n_samples,) + ) + ) + sample_weight = dpt.asarray(sample_weight, device=self.device) + + return sample_weight + + def _check_init(self, init, X, copy=False): + with _validate_with_array_api(): + init = check_array( + init, + dtype=X.dtype, + accept_sparse=False, + copy=False, + order=self.order, + force_all_finite=True, + ensure_2d=True, + estimator=self.estimator, + input_name="init", ) - sample_weight = sample_weight.astype(compute_dtype) + self.estimator._validate_center_shape(X, init) + init_t = dpt.asarray(init.T, order="C", copy=False) + return init_t - return X, sample_weight, cluster_centers, output_dtype - def _load_transposed_data_to_device(self, X, sample_weight, cluster_centers): - # Transfer the input data to device memory, - # TODO: let the user pass directly dpt or dpnp arrays to avoid copies. +def _get_namespace(*arrays): + return dpt, True - # NB: numba_dpex kernels only currently supports inputs with a C memory layout - # (see https://github.com/IntelPython/numba-dpex/issues/767) but our KMeans - # implementation is hypothetized to be more efficient with the F-memory layout. - # As a workaround the kernels work with the transpose of X, X_t, where X_t - # is created with a C layout, which results in equivalent memory access - # patterns than with a F layout for X. - # TODO: when numba_dpex supports inputs with F-layout: - # - use X rather than X_t and adapt the codebase (better for readability and - # more consistent with sklearn notations) - # - test the performances with both layouts and use the best performing layout. - X_t = dpt.asarray(X.T, order="C", device=self.device) - assert ( - X_t.strides[1] == 1 - ) # C memory layout, equivalent to Fortran layout on transposed +def _asarray_with_order(array, dtype, order, copy=None, xp=None): + return dpt.asarray(array, dtype=dtype, order=order, copy=copy) - if sample_weight is not None: - sample_weight = dpt.from_numpy(sample_weight, device=self.device) - if cluster_centers is not None: - cluster_centers = dpt.from_numpy(cluster_centers.T, device=self.device) - - return X_t, sample_weight, cluster_centers +@contextlib.contextmanager +def _validate_with_array_api(): + # TODO: when https://github.com/IntelPython/dpctl/issues/997 and + # https://github.com/scikit-learn/scikit-learn/issues/25000 and are solved + # remove those hacks. + with sklearn.config_context( + array_api_dispatch=True, + assume_finite=True # workaround 1: disable force_all_finite + # workaround 2: monkey patch get_namespace and _asarray_with_order to force + # dpctl.tensor array namespace + ), override_attr_context( + sklearn_validation, + get_namespace=_get_namespace, + _asarray_with_order=_asarray_with_order, + ): + yield diff --git a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py index 73b4886..57dafd5 100644 --- a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py +++ b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py @@ -142,7 +142,7 @@ def test_euclidean_distance(dtype): expected = np.sqrt(((a - b) ** 2).sum()) - estimator = KMeans() + estimator = KMeans(n_clusters=len(b)) estimator.cluster_centers_ = b engine = KMeansEngine(estimator) @@ -162,16 +162,16 @@ def test_inertia(dtype): sample_weight = rng.standard_normal(100, dtype=dtype) centers = rng.standard_normal((5, 10), dtype=dtype) - estimator = KMeans() + estimator = KMeans(n_clusters=len(centers)) estimator.cluster_centers_ = centers engine = KMeansEngine(estimator) - - labels = engine.get_labels(X, sample_weight) + X_prepared, sample_weight_prepared = engine.prepare_prediction(X, sample_weight) + labels = engine.get_labels(X_prepared, sample_weight_prepared) distances = ((X - centers[labels]) ** 2).sum(axis=1) expected = np.sum(distances * sample_weight) - inertia = engine.get_score(X, sample_weight) + inertia = engine.get_score(X_prepared, sample_weight_prepared) rtol = 1e-4 if dtype == np.float32 else 1e-6 assert_allclose(inertia, expected, rtol=rtol) @@ -336,8 +336,10 @@ def _get_score_with_centers(centers): kmeans.set_params(random_state=random_state) engine = KMeansEngine(kmeans) - engine.prepare_fit(X) - engine_kmeans_plusplus_centers = engine.init_centroids(X) + X_prepared, *_ = engine.prepare_fit(X) + engine_kmeans_plusplus_centers = engine.init_centroids(X_prepared) + engine_kmeans_plusplus_centers = dpt.asnumpy(engine_kmeans_plusplus_centers.T) + engine.unshift_centers(X_prepared, engine_kmeans_plusplus_centers) scores_engine_kmeans_plusplus.append( _get_score_with_centers(engine_kmeans_plusplus_centers) ) @@ -350,9 +352,9 @@ def _get_score_with_centers(centers): # loop. E.g., for 200 iterations with dtype float32: # # [ - # -1786.160542907715, # np.mean(scores_random_init) - # -886.2205599975586, # np.mean(scores_vanilla_kmeans_plusplus) - # -876.5628140258789, # np.mean(scores_engine_kmeans_plusplus) + # -1786.16057, # np.mean(scores_random_init) + # -886.220595, # np.mean(scores_vanilla_kmeans_plusplus) + # -876.56282806, # np.mean(scores_engine_kmeans_plusplus) # ] assert_allclose( @@ -364,7 +366,7 @@ def _get_score_with_centers(centers): [ -1827.22702, -1027.674243, - -865.257397, + -865.257501, ], ) @@ -380,13 +382,15 @@ def test_kmeans_plusplus_output(dtype): sample_weight = default_rng(random_state).random(X.shape[0], dtype=dtype) estimator = KMeans( - init="k-means++", - n_clusters=n_clusters_sklearn_test, - random_state=random_state, + init="k-means++", n_clusters=n_clusters_sklearn_test, random_state=random_state ) engine = KMeansEngine(estimator) - engine.prepare_fit(X, sample_weight=sample_weight) - centers, indices = engine._kmeans_plusplus(X) + X_prepared, *_ = engine.prepare_fit(X, sample_weight=sample_weight) + + centers, indices = engine._kmeans_plusplus(X_prepared) + centers = dpt.asnumpy(centers.T) + engine.unshift_centers(X_prepared, centers) + indices = dpt.asnumpy(indices) # Check there are the correct number of indices and that all indices are # positive and within the number of samples @@ -414,15 +418,16 @@ def test_kmeans_plusplus_dataorder(): init="k-means++", n_clusters=n_clusters_sklearn_test, random_state=random_state ) engine = KMeansEngine(estimator) - engine.prepare_fit(X_sklearn_test) - centers_c = engine.init_centroids(X_sklearn_test) + X_sklearn_test_prepared, *_ = engine.prepare_fit(X_sklearn_test) + centers_c = engine.init_centroids(X_sklearn_test_prepared) + centers_c = dpt.asnumpy(centers_c.T) X_fortran = np.asfortranarray(X_sklearn_test) - # The engine is re-created to reset random state engine = KMeansEngine(estimator) - engine.prepare_fit(X_sklearn_test) - centers_fortran = engine.init_centroids(X_fortran) + X_fortran_prepared, *_ = engine.prepare_fit(X_fortran) + centers_fortran = engine.init_centroids(X_fortran_prepared) + centers_fortran = dpt.asnumpy(centers_fortran.T) assert_allclose(centers_c, centers_fortran) diff --git a/sklearn_numba_dpex/testing/config.py b/sklearn_numba_dpex/testing/config.py index e69f483..efa235a 100644 --- a/sklearn_numba_dpex/testing/config.py +++ b/sklearn_numba_dpex/testing/config.py @@ -31,14 +31,15 @@ def override_attr_context(obj, **attrs): overriden. The initial values are restored when exiting the context. Trying to override attributes that don't exist will result in an AttributeError""" - - attrs_before = dict() - for attr_name, attr_value in attrs.items(): - # raise AttributeError if obj does not have the attribute attr_name - attrs_before[attr_name] = getattr(obj, attr_name) - setattr(obj, attr_name, attr_value) - - yield - - for attr_name, attr_value in attrs_before.items(): - setattr(obj, attr_name, attr_value) + try: + attrs_before = dict() + for attr_name, attr_value in attrs.items(): + # raise AttributeError if obj does not have the attribute attr_name + attrs_before[attr_name] = getattr(obj, attr_name) + setattr(obj, attr_name, attr_value) + + yield + + finally: + for attr_name, attr_value in attrs_before.items(): + setattr(obj, attr_name, attr_value) From c21dcd2573a9a12af3daa6fa673d1d9f1b40515d Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Thu, 24 Nov 2022 12:44:32 +0100 Subject: [PATCH 2/7] Add tests with dpnp and dpt.usm_ndarray inputs --- .../kmeans/tests/test_kmeans.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py index 57dafd5..f76322b 100644 --- a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py +++ b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py @@ -38,11 +38,17 @@ def test_dpnp_implements_argpartition(): ) +@pytest.mark.parametrize( + "array_constr", + [np.asarray, dpt.asarray, dpnp.asarray], + ids=["numpy", "dpctl", "dpnp"], +) @pytest.mark.parametrize("dtype", float_dtype_params) -def test_kmeans_same_results(dtype): +def test_kmeans_same_results(dtype, array_constr): random_seed = 42 X, _ = make_blobs(random_state=random_seed) X = X.astype(dtype) + X_array = array_constr(X, dtype=dtype) kmeans_vanilla = KMeans( random_state=random_seed, algorithm="lloyd", max_iter=1, init="random" @@ -53,7 +59,7 @@ def test_kmeans_same_results(dtype): kmeans_vanilla.fit(X) with config_context(engine_provider="sklearn_numba_dpex"): - kmeans_engine.fit(X) + kmeans_engine.fit(X_array) # ensure same results assert_array_equal(kmeans_vanilla.labels_, kmeans_engine.labels_) @@ -63,7 +69,7 @@ def test_kmeans_same_results(dtype): # test fit_predict y_labels = kmeans_vanilla.fit_predict(X) with config_context(engine_provider="sklearn_numba_dpex"): - y_labels_engine = kmeans_engine.fit_predict(X) + y_labels_engine = kmeans_engine.fit_predict(X_array) assert_array_equal(y_labels, y_labels_engine) assert_array_equal(kmeans_vanilla.labels_, kmeans_engine.labels_) assert_allclose(kmeans_vanilla.cluster_centers_, kmeans_engine.cluster_centers_) @@ -72,7 +78,7 @@ def test_kmeans_same_results(dtype): # test fit_transform y_transform = kmeans_vanilla.fit_transform(X) with config_context(engine_provider="sklearn_numba_dpex"): - y_transform_engine = kmeans_engine.fit_transform(X) + y_transform_engine = kmeans_engine.fit_transform(X_array) assert_allclose(y_transform, y_transform_engine) assert_array_equal(kmeans_vanilla.labels_, kmeans_engine.labels_) assert_allclose(kmeans_vanilla.cluster_centers_, kmeans_engine.cluster_centers_) @@ -81,19 +87,19 @@ def test_kmeans_same_results(dtype): # # test predict method (returns labels) y_labels = kmeans_vanilla.predict(X) with config_context(engine_provider="sklearn_numba_dpex"): - y_labels_engine = kmeans_engine.predict(X) + y_labels_engine = kmeans_engine.predict(X_array) assert_array_equal(y_labels, y_labels_engine) # test score method (returns negative inertia for each sample) y_scores = kmeans_vanilla.score(X) with config_context(engine_provider="sklearn_numba_dpex"): - y_scores_engine = kmeans_engine.score(X) + y_scores_engine = kmeans_engine.score(X_array) assert_allclose(y_scores, y_scores_engine) # test transform method (returns euclidean distances) y_transform = kmeans_vanilla.transform(X) with config_context(engine_provider="sklearn_numba_dpex"): - y_transform_engine = kmeans_engine.transform(X) + y_transform_engine = kmeans_engine.transform(X_array) assert_allclose(y_transform, y_transform_engine) @@ -372,12 +378,17 @@ def _get_score_with_centers(centers): @pytest.mark.parametrize("dtype", float_dtype_params) -def test_kmeans_plusplus_output(dtype): +@pytest.mark.parametrize( + "array_constr", + [np.asarray, dpt.asarray, dpnp.asarray], + ids=["numpy", "dpctl", "dpnp"], +) +def test_kmeans_plusplus_output(array_constr, dtype): """Test adapted from sklearn's test_kmeans_plusplus_output""" random_state = 42 # Check for the correct number of seeds and all positive values - X = X_sklearn_test.astype(dtype) + X = array_constr(X_sklearn_test, dtype=dtype) sample_weight = default_rng(random_state).random(X.shape[0], dtype=dtype) @@ -400,13 +411,13 @@ def test_kmeans_plusplus_output(dtype): # Check for the correct number of seeds and that they are bound by the data assert centers.shape[0] == n_clusters_sklearn_test - assert (centers.max(axis=0) <= X.max(axis=0)).all() - assert (centers.min(axis=0) >= X.min(axis=0)).all() + assert (centers.max(axis=0) <= X_sklearn_test.max(axis=0)).all() + assert (centers.min(axis=0) >= X_sklearn_test.min(axis=0)).all() # NB: dtype can change depending on the device, so we accept all valid dtypes. assert centers.dtype.type in {np.float32, np.float64} # Check that indices correspond to reported centers - assert_allclose(X[indices].astype(dtype), centers) + assert_allclose(X_sklearn_test[indices].astype(dtype), centers) def test_kmeans_plusplus_dataorder(): From 03f97b253b9bb40e51946838ac8d7104b87057bd Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Thu, 24 Nov 2022 15:07:47 +0100 Subject: [PATCH 3/7] Clarity, commenting, fix caching with lambda functions --- sklearn_numba_dpex/common/kernels.py | 38 +++---- sklearn_numba_dpex/kmeans/drivers.py | 152 +++++++++++++++------------ sklearn_numba_dpex/kmeans/engine.py | 29 ++--- 3 files changed, 116 insertions(+), 103 deletions(-) diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py index b0ab71e..8c0e390 100644 --- a/sklearn_numba_dpex/common/kernels.py +++ b/sklearn_numba_dpex/common/kernels.py @@ -102,22 +102,14 @@ def broadcast_division(dividend_array, divisor_vector): @lru_cache def make_broadcast_ops_1d_2d_axis1_kernel(size0, size1, ops, work_group_size): - global_size = math.ceil(size1 / work_group_size) * work_group_size - - if ops == "plus": - - @dpex.func - def ops(augend, addend): - return augend + addend - - elif ops == "minus": - - @dpex.func - def ops(minuend, subtrahend): - return minuend - subtrahend + """ + ops must be a function that will be interpreted as a dpex.func and is subject to + the same rules. It is expected to take two scalar arguments and return one scalar + value. lambda functions are advised against since the cache will not work with lamda + functions.""" - else: - raise ValueError(f"Invalid ops: {ops} .") + global_size = math.ceil(size1 / work_group_size) * work_group_size + ops = dpex.func(ops) # NB: inplace. # Optimized for C-contiguous array and for # size1 >> preferred_work_group_size_multiple @@ -179,14 +171,13 @@ def _make_partial_sum_reduction_2d_axis1_kernel( minus_one_idx = np.int64(-1) two_as_a_long = np.int64(2) - if fused_unary_func is not None: - fused_unary_func = dpex.func(fused_unary_func) - else: + if fused_unary_func is None: - @dpex.func def fused_unary_func(x): return x + fused_unary_func = dpex.func(fused_unary_func) + # TODO: this set of kernel functions could be abstracted away to other coalescing # functions if n_rows is None: # 1d @@ -341,6 +332,12 @@ def make_sum_reduction_2d_axis1_kernel( `2 * work_group_size`. This is repeated as many time as needed until only one value remains in global memory. + If fused_unary_func is not None, it will be applied element-wise before summing. + It must be a function that will be interpreted as a dpex.func and is subject to the + same rules. It is expected to take one scalar argument and returning one scalar + value. lambda functions are advised against since the cache will not work with + lambda functions. + Notes ----- `work_group_size` is assumed to be a power of 2. @@ -356,9 +353,12 @@ def make_sum_reduction_2d_axis1_kernel( n_rows = size0 if size1 is not None else None sum_axis_size = size0 if n_rows is None else size1 + # fused_unary_func is applied elementwise during the first pass on data, in the + # first kernel execution only. fused_func_kernel = _make_partial_sum_reduction_2d_axis1_kernel( n_rows, work_group_size, fused_unary_func, dtype ) + # subsequent kernel calls only sum the data. nofunc_kernel = _make_partial_sum_reduction_2d_axis1_kernel( n_rows, work_group_size, None, dtype ) diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py index 7c7c940..2f62b3d 100644 --- a/sklearn_numba_dpex/kmeans/drivers.py +++ b/sklearn_numba_dpex/kmeans/drivers.py @@ -328,76 +328,6 @@ def lloyd( return assignments_idx, inertia, centroids_t, n_iteration -def prepare_data_for_lloyd(X_t, init, tol, copy_x): - n_features, n_samples = X_t.shape - compute_dtype = X_t.dtype.type - - device = X_t.device.sycl_device - max_work_group_size = device.max_work_group_size - - sum_axis1_kernel = make_sum_reduction_2d_axis1_kernel( - X_t.shape[0], - X_t.shape[1], - device.max_work_group_size, - device=device, - dtype=compute_dtype, - ) - - X_mean = (sum_axis1_kernel(X_t) / compute_dtype(n_samples))[:, 0] - - if (X_mean == 0).astype(int).sum() == len(X_mean): - X_mean = None - else: - X_t = dpt.asarray(X_t, copy=copy_x) - broadcast_X_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( - n_features, - n_samples, - ops="minus", - work_group_size=max_work_group_size, - ) - - broadcast_X_minus_X_mean(X_t, X_mean) - - if isinstance(init, dpt.usm_ndarray): - n_clusters = init.shape[1] - broadcast_init_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( - n_features, - n_clusters, - ops="minus", - work_group_size=max_work_group_size, - ) - broadcast_init_minus_X_mean(init, X_mean) - - variance_kernel = make_sum_reduction_2d_axis1_kernel( - n_features * n_samples, - None, - max_work_group_size, - device=device, - dtype=compute_dtype, - fused_unary_func=lambda x: x * x, - ) - variance = variance_kernel(dpt.reshape(X_t, -1)) / n_features - tol = variance * tol - - return X_t, X_mean, init, tol - - -def restore_data_after_lloyd(X_t, X_mean): - n_features, n_samples = X_t.shape - - device = X_t.device.sycl_device - max_work_group_size = device.max_work_group_size - - X_t = dpt.asarray(X_t, copy=False) - broadcast_X_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( - n_features, - n_samples, - ops="plus", - work_group_size=max_work_group_size, - ) - broadcast_X_plus_X_mean(X_t, X_mean) - - def _relocate_empty_clusters( n_empty_clusters, X_t, @@ -472,6 +402,88 @@ def _relocate_empty_clusters( ) +def prepare_data_for_lloyd(X_t, init, tol, copy_x): + n_features, n_samples = X_t.shape + compute_dtype = X_t.dtype.type + + device = X_t.device.sycl_device + max_work_group_size = device.max_work_group_size + + sum_axis1_kernel = make_sum_reduction_2d_axis1_kernel( + X_t.shape[0], + X_t.shape[1], + device.max_work_group_size, + device=device, + dtype=compute_dtype, + ) + + X_mean = (sum_axis1_kernel(X_t) / compute_dtype(n_samples))[:, 0] + + if (X_mean == 0).astype(int).sum() == len(X_mean): + X_mean = None + else: + X_t = dpt.asarray(X_t, copy=copy_x) + broadcast_X_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_samples, + ops=_minus, + work_group_size=max_work_group_size, + ) + + broadcast_X_minus_X_mean(X_t, X_mean) + + if isinstance(init, dpt.usm_ndarray): + n_clusters = init.shape[1] + broadcast_init_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_clusters, + ops=_minus, + work_group_size=max_work_group_size, + ) + broadcast_init_minus_X_mean(init, X_mean) + + variance_kernel = make_sum_reduction_2d_axis1_kernel( + n_features * n_samples, + None, + max_work_group_size, + device=device, + dtype=compute_dtype, + fused_unary_func=_square, + ) + variance = variance_kernel(dpt.reshape(X_t, -1)) / n_features + tol = variance * tol + + return X_t, X_mean, init, tol + + +def _square(x): + return x * x + + +def restore_data_after_lloyd(X_t, X_mean): + n_features, n_samples = X_t.shape + + device = X_t.device.sycl_device + max_work_group_size = device.max_work_group_size + + X_t = dpt.asarray(X_t, copy=False) + broadcast_X_plus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( + n_features, + n_samples, + ops=_plus, + work_group_size=max_work_group_size, + ) + broadcast_X_plus_X_mean(X_t, X_mean) + + +def _minus(x, y): + return x - y + + +def _plus(x, y): + return x + y + + def get_labels_inertia(X_t, centroids_t, sample_weight, with_inertia): compute_dtype = X_t.dtype.type n_features, n_samples = X_t.shape diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py index f8c381c..1a0030c 100644 --- a/sklearn_numba_dpex/kmeans/engine.py +++ b/sklearn_numba_dpex/kmeans/engine.py @@ -62,6 +62,11 @@ class KMeansEngine(KMeansCythonEngine): """ + # This class attribute can alter globally the attributes `device` and `order` of + # future instances. It is only used for testing purposes, using + # `sklearn_numba_dpex.testing.config.override_attr_context` context, for instance + # in the benchmark script. + # For normal usage, the compute will follow the __compute_follow_data__ principle. _CONFIG = dict() def __init__(self, estimator): @@ -96,8 +101,6 @@ def prepare_fit(self, X, y=None, sample_weight=None): f"The sklearn_nunmba_dpex engine for KMeans only support the Lloyd algorithm, {algorithm} is not supported." ) - # NB: self.estimator.copy_x is enforced later on only if X_mean has at - # least one non-null value X = self._validate_data(X) estimator._check_params_vs_input(X) @@ -132,7 +135,7 @@ def unshift_centers(self, X, best_centers): # but at the moment it is unknown what those assumptions could be. # As a result, the following instructions are ran every time, even if it # isn't useful when a copy has been made. - # TODO: is there a set of assumptions that exhaustively describe the set + # TODO: is there a set of assumptions that exhaustively describes the set # of accepted inputs, and also enables checking if a copy happened or not # in a simple way ? if not self.estimator.copy_x: @@ -154,7 +157,7 @@ def init_centroids(self, X): else: # NB: sampling without replacement must be executed sequentially so - # it's better adapted to CPU + # it's better done on CPU centers_idx = self.random_state.choice( X.shape[0], size=n_clusters, replace=False ) @@ -255,7 +258,7 @@ def get_euclidean_distances(self, X): def _validate_data(self, X, reset=True): accepted_dtypes = [np.float32] # NB: one could argue that `float32` is a better default, but sklearn defaults - # to `np.float64` and we is apply the same for consistence. + # to `np.float64` and we apply the same for consistency. if self.device.has_aspect_fp64: accepted_dtypes = [np.float64, np.float32] else: @@ -278,11 +281,11 @@ def _validate_data(self, X, reset=True): if "A sparse matrix was passed, but dense data is required" in str( type_error ): - raise NotSupportedByEngineError from TypeError + raise NotSupportedByEngineError from type_error def _check_sample_weight(self, sample_weight, X): """Adapted from sklearn.utils.validation._check_sample_weight to be compatible - with Array API dispatching""" + with Array API dispatch""" n_samples = X.shape[0] dtype = X.dtype if sample_weight is None: @@ -312,7 +315,6 @@ def _check_sample_weight(self, sample_weight, X): sample_weight.shape, (n_samples,) ) ) - sample_weight = dpt.asarray(sample_weight, device=self.device) return sample_weight @@ -330,7 +332,7 @@ def _check_init(self, init, X, copy=False): input_name="init", ) self.estimator._validate_center_shape(X, init) - init_t = dpt.asarray(init.T, order="C", copy=False) + init_t = dpt.asarray(init.T, order="C", copy=False, device=self.device) return init_t @@ -338,12 +340,11 @@ def _get_namespace(*arrays): return dpt, True -def _asarray_with_order(array, dtype, order, copy=None, xp=None): - return dpt.asarray(array, dtype=dtype, order=order, copy=copy) - - @contextlib.contextmanager -def _validate_with_array_api(): +def _validate_with_array_api(device): + def _asarray_with_order(array, dtype, order, copy=None, xp=None): + return dpt.asarray(array, dtype=dtype, order=order, copy=copy, device=device) + # TODO: when https://github.com/IntelPython/dpctl/issues/997 and # https://github.com/scikit-learn/scikit-learn/issues/25000 and are solved # remove those hacks. From d11d91f04772d158b36c4c53e6bbdc72f5ea855d Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Fri, 25 Nov 2022 14:33:22 +0100 Subject: [PATCH 4/7] Clarity and commenting. Co-authored-by: Julien Jerphanion --- sklearn_numba_dpex/common/_utils.py | 12 ++++++++++ sklearn_numba_dpex/common/kernels.py | 12 ++++++---- sklearn_numba_dpex/kmeans/drivers.py | 24 +++++++------------ .../kmeans/tests/test_kmeans.py | 8 +++---- 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/sklearn_numba_dpex/common/_utils.py b/sklearn_numba_dpex/common/_utils.py index 01a6eaf..241f250 100644 --- a/sklearn_numba_dpex/common/_utils.py +++ b/sklearn_numba_dpex/common/_utils.py @@ -5,3 +5,15 @@ def check_power_of_2(e): if e != 2 ** (math.log2(e)): raise ValueError(f"Expected a power of 2, got {e}") return e + + +def _square(x): + return x * x + + +def _minus(x, y): + return x - y + + +def _plus(x, y): + return x + y diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py index 8c0e390..ee1e582 100644 --- a/sklearn_numba_dpex/common/kernels.py +++ b/sklearn_numba_dpex/common/kernels.py @@ -81,7 +81,8 @@ def initialize_to_zeros(data): def make_broadcast_division_1d_2d_axis0_kernel(size0, size1, work_group_size): global_size = math.ceil(size1 / work_group_size) * work_group_size - # NB: inplace. # Optimized for C-contiguous array and for + # NB: the left operand is modified inplace, the right operand is only read into. + # Optimized for C-contiguous array and for # size1 >> preferred_work_group_size_multiple @dpex.kernel def broadcast_division(dividend_array, divisor_vector): @@ -106,12 +107,14 @@ def make_broadcast_ops_1d_2d_axis1_kernel(size0, size1, ops, work_group_size): ops must be a function that will be interpreted as a dpex.func and is subject to the same rules. It is expected to take two scalar arguments and return one scalar value. lambda functions are advised against since the cache will not work with lamda - functions.""" + functions. sklearn_numba_dpex.common._utils expose some pre-defined `ops`. + """ global_size = math.ceil(size1 / work_group_size) * work_group_size ops = dpex.func(ops) - # NB: inplace. # Optimized for C-contiguous array and for + # NB: the left operand is modified inplace, the right operand is only read into. + # Optimized for C-contiguous array and for # size1 >> preferred_work_group_size_multiple @dpex.kernel def broadcast_ops(left_operand_array, right_operand_vector): @@ -336,7 +339,8 @@ def make_sum_reduction_2d_axis1_kernel( It must be a function that will be interpreted as a dpex.func and is subject to the same rules. It is expected to take one scalar argument and returning one scalar value. lambda functions are advised against since the cache will not work with - lambda functions. + lambda functions. sklearn_numba_dpex.common._utils expose some pre-defined + `fused_unary_funcs`. Notes ----- diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py index 2f62b3d..92cdd8c 100644 --- a/sklearn_numba_dpex/kmeans/drivers.py +++ b/sklearn_numba_dpex/kmeans/drivers.py @@ -35,6 +35,8 @@ make_reduce_centroid_data_kernel, ) +from sklearn_numba_dpex.common._utils import _square, _plus, _minus + def lloyd( X_t, @@ -420,6 +422,10 @@ def prepare_data_for_lloyd(X_t, init, tol, copy_x): X_mean = (sum_axis1_kernel(X_t) / compute_dtype(n_samples))[:, 0] if (X_mean == 0).astype(int).sum() == len(X_mean): + # If the data is already centered, there's no need to perform shift/unshift + # steps. In this case, X_mean is set to None, thus carrying the information + # that the data was already centered, and the shift/unshift steps will be + # skipped. X_mean = None else: X_t = dpt.asarray(X_t, copy=copy_x) @@ -443,9 +449,9 @@ def prepare_data_for_lloyd(X_t, init, tol, copy_x): broadcast_init_minus_X_mean(init, X_mean) variance_kernel = make_sum_reduction_2d_axis1_kernel( - n_features * n_samples, - None, - max_work_group_size, + size0=n_features * n_samples, + size1=None, + work_group_size=max_work_group_size, device=device, dtype=compute_dtype, fused_unary_func=_square, @@ -456,10 +462,6 @@ def prepare_data_for_lloyd(X_t, init, tol, copy_x): return X_t, X_mean, init, tol -def _square(x): - return x * x - - def restore_data_after_lloyd(X_t, X_mean): n_features, n_samples = X_t.shape @@ -476,14 +478,6 @@ def restore_data_after_lloyd(X_t, X_mean): broadcast_X_plus_X_mean(X_t, X_mean) -def _minus(x, y): - return x - y - - -def _plus(x, y): - return x + y - - def get_labels_inertia(X_t, centroids_t, sample_weight, with_inertia): compute_dtype = X_t.dtype.type n_features, n_samples = X_t.shape diff --git a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py index f76322b..cb822fb 100644 --- a/sklearn_numba_dpex/kmeans/tests/test_kmeans.py +++ b/sklearn_numba_dpex/kmeans/tests/test_kmeans.py @@ -343,8 +343,8 @@ def _get_score_with_centers(centers): kmeans.set_params(random_state=random_state) engine = KMeansEngine(kmeans) X_prepared, *_ = engine.prepare_fit(X) - engine_kmeans_plusplus_centers = engine.init_centroids(X_prepared) - engine_kmeans_plusplus_centers = dpt.asnumpy(engine_kmeans_plusplus_centers.T) + engine_kmeans_plusplus_centers_t = engine.init_centroids(X_prepared) + engine_kmeans_plusplus_centers = dpt.asnumpy(engine_kmeans_plusplus_centers_t.T) engine.unshift_centers(X_prepared, engine_kmeans_plusplus_centers) scores_engine_kmeans_plusplus.append( _get_score_with_centers(engine_kmeans_plusplus_centers) @@ -398,8 +398,8 @@ def test_kmeans_plusplus_output(array_constr, dtype): engine = KMeansEngine(estimator) X_prepared, *_ = engine.prepare_fit(X, sample_weight=sample_weight) - centers, indices = engine._kmeans_plusplus(X_prepared) - centers = dpt.asnumpy(centers.T) + centers_t, indices = engine._kmeans_plusplus(X_prepared) + centers = dpt.asnumpy(centers_t.T) engine.unshift_centers(X_prepared, centers) indices = dpt.asnumpy(indices) From ac649baf97213f85e0b89009732e179558eeb736 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Fri, 25 Nov 2022 14:46:53 +0100 Subject: [PATCH 5/7] Implement compute follow data --- sklearn_numba_dpex/kmeans/engine.py | 31 ++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py index 12b8e2b..b53dbb5 100644 --- a/sklearn_numba_dpex/kmeans/engine.py +++ b/sklearn_numba_dpex/kmeans/engine.py @@ -3,6 +3,7 @@ import importlib import numpy as np +import dpnp import dpctl import dpctl.tensor as dpt @@ -26,7 +27,7 @@ ) -class _IgnoreSampleWeight: +class _DeviceUnset: pass @@ -70,7 +71,7 @@ class KMeansEngine(KMeansCythonEngine): _CONFIG = dict() def __init__(self, estimator): - self.device = dpctl.SyclDevice(self._CONFIG.get("device")) + self.device = self._CONFIG.get("device", _DeviceUnset) # NB: numba_dpex kernels only currently supports working with C memory layout # (see https://github.com/IntelPython/numba-dpex/issues/767) but our KMeans @@ -256,15 +257,25 @@ def get_euclidean_distances(self, X): return dpt.asnumpy(euclidean_distances) def _validate_data(self, X, reset=True): + if isinstance(X, dpnp.ndarray): + X = X.get_array() + + if self.device is not _DeviceUnset: + device = dpctl.SyclDevice(self.device) + elif isinstance(X, dpt.usm_ndarray): + device = X.device.sycl_device + else: + device = dpctl.SyclDevice() + accepted_dtypes = [np.float32] # NB: one could argue that `float32` is a better default, but sklearn defaults # to `np.float64` and we apply the same for consistency. - if self.device.has_aspect_fp64: + if device.has_aspect_fp64: accepted_dtypes = [np.float64, np.float32] else: accepted_dtypes = [np.float32] - with _validate_with_array_api(self.device): + with _validate_with_array_api(device): try: X = self.estimator._validate_data( X, @@ -288,12 +299,13 @@ def _check_sample_weight(self, sample_weight, X): with Array API dispatch""" n_samples = X.shape[0] dtype = X.dtype + device = X.device.sycl_device if sample_weight is None: - sample_weight = dpt.ones(n_samples, dtype=dtype, device=self.device) + sample_weight = dpt.ones(n_samples, dtype=dtype, device=device) elif isinstance(sample_weight, numbers.Number): - sample_weight = dpt.full(n_samples, 1, dtype=dtype, device=self.device) + sample_weight = dpt.full(n_samples, 1, dtype=dtype, device=device) else: - with _validate_with_array_api(self.device): + with _validate_with_array_api(device): sample_weight = check_array( sample_weight, accept_sparse=False, @@ -319,7 +331,8 @@ def _check_sample_weight(self, sample_weight, X): return sample_weight def _check_init(self, init, X, copy=False): - with _validate_with_array_api(self.device): + device = X.device.sycl_device + with _validate_with_array_api(device): init = check_array( init, dtype=X.dtype, @@ -332,7 +345,7 @@ def _check_init(self, init, X, copy=False): input_name="init", ) self.estimator._validate_center_shape(X, init) - init_t = dpt.asarray(init.T, order="C", copy=False, device=self.device) + init_t = dpt.asarray(init.T, order="C", copy=False, device=device) return init_t From c1698bddb18912fd85585cbbedd4e3370dd814c4 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Tue, 29 Nov 2022 12:00:09 +0100 Subject: [PATCH 6/7] Clarity and commenting. Co-authored-by: Julien Jerphanion --- sklearn_numba_dpex/common/kernels.py | 174 +++++++++++++-------------- sklearn_numba_dpex/kmeans/drivers.py | 14 ++- sklearn_numba_dpex/kmeans/engine.py | 14 ++- 3 files changed, 108 insertions(+), 94 deletions(-) diff --git a/sklearn_numba_dpex/common/kernels.py b/sklearn_numba_dpex/common/kernels.py index ee1e582..af0432e 100644 --- a/sklearn_numba_dpex/common/kernels.py +++ b/sklearn_numba_dpex/common/kernels.py @@ -162,6 +162,93 @@ def half_l2_norm( return half_l2_norm[global_size, work_group_size] +@lru_cache +def make_sum_reduction_2d_axis1_kernel( + size0, size1, work_group_size, device, dtype, fused_unary_func=None +): + """Implement data_2d.sum(axis=1) or data_1d.sum() + + numba_dpex does not provide tools such as `cuda.reduce` so we implement from scratch + a reduction strategy. The strategy relies on the commutativity of the operation used + for the reduction, thus allowing to reduce the input in any order. + + The strategy consists in performing local reductions in each work group using local + memory where each work item combine two values, thus halving the number of values, + and the number of active work items. At each iteration the work items are discarded + in a bracket manner. The work items with the greatest ids are discarded first, and + we rely on the fact that the remaining work items are adjacents to optimize the RW + operations. + + Once the reduction is done in a work group the result is written in global memory, + thus creating an intermediary result whose size is divided by + `2 * work_group_size`. This is repeated as many time as needed until only one value + remains in global memory. + + If fused_unary_func is not None, it will be applied element-wise before summing. + It must be a function that will be interpreted as a dpex.func and is subject to the + same rules. It is expected to take one scalar argument and returning one scalar + value. lambda functions are advised against since the cache will not work with + lambda functions. sklearn_numba_dpex.common._utils expose some pre-defined + `fused_unary_funcs`. + + Notes + ----- + `work_group_size` is assumed to be a power of 2. + + if `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not None + then the expected shape of input tensors is `(size0, size1)`, and the reduction + operation is equivalent to input.sum(axis=1). In this case, the kernel is a good + choice if `size1` >> `preferred_work_group_size_multiple`, and if `size0` ranges in + the same order of magnitude than `preferred_work_group_size_multiple`. If not, + other reduction implementations might give better performances. + """ + check_power_of_2(work_group_size) + n_rows = size0 if size1 is not None else None + sum_axis_size = size0 if n_rows is None else size1 + + # fused_unary_func is applied elementwise during the first pass on data, in the + # first kernel execution only. + fused_func_kernel = _make_partial_sum_reduction_2d_axis1_kernel( + n_rows, work_group_size, fused_unary_func, dtype + ) + # subsequent kernel calls only sum the data. + nofunc_kernel = _make_partial_sum_reduction_2d_axis1_kernel( + n_rows, work_group_size, fused_unary_func=None, dtype=dtype + ) + + # As many partial reductions as necessary are chained until only one element + # remains. + kernels_and_empty_tensors_pairs = [] + n_groups = sum_axis_size + # TODO: at some point, the cost of scheduling the kernel is more than the cost of + # running the reduction iteration. At this point the loop should stop and then a + # single work item should iterates one time on the remaining values to finish the + # reduction. + kernel = fused_func_kernel + while n_groups > 1: + n_groups = math.ceil(n_groups / (2 * work_group_size)) + global_size = n_groups * work_group_size + kernel = kernel[global_size, work_group_size] + result_shape = n_groups if n_rows is None else (n_rows, n_groups) + # NB: here memory for partial results is allocated ahead of time and will only + # be garbage collected when the instance of `sum_reduction` is garbage + # collected. Thus it can be more efficient to re-use a same instance of + # `sum_reduction` (e.g within iterations of a loop) since it avoid deallocation + # and reallocation every time. + result = dpt.empty(result_shape, dtype=dtype, device=device) + kernels_and_empty_tensors_pairs.append((kernel, result)) + kernel = nofunc_kernel + + def sum_reduction(summands): + # TODO: manually dispatch the kernels with a SyclQueue + for kernel, result in kernels_and_empty_tensors_pairs: + kernel(summands, result) + summands = result + return result + + return sum_reduction + + @lru_cache def _make_partial_sum_reduction_2d_axis1_kernel( n_rows, work_group_size, fused_unary_func, dtype @@ -313,93 +400,6 @@ def partial_sum_reduction( return partial_sum_reduction -@lru_cache -def make_sum_reduction_2d_axis1_kernel( - size0, size1, work_group_size, device, dtype, fused_unary_func=None -): - """Implement data_2d.sum(axis=1) or data_1d.sum() - - numba_dpex does not provide tools such as `cuda.reduce` so we implement from scratch - a reduction strategy. The strategy relies on the commutativity of the operation used - for the reduction, thus allowing to reduce the input in any order. - - The strategy consists in performing local reductions in each work group using local - memory where each work item combine two values, thus halving the number of values, - and the number of active work items. At each iteration the work items are discarded - in a bracket manner. The work items with the greatest ids are discarded first, and - we rely on the fact that the remaining work items are adjacents to optimize the RW - operations. - - Once the reduction is done in a work group the result is written in global memory, - thus creating an intermediary result whose size is divided by - `2 * work_group_size`. This is repeated as many time as needed until only one value - remains in global memory. - - If fused_unary_func is not None, it will be applied element-wise before summing. - It must be a function that will be interpreted as a dpex.func and is subject to the - same rules. It is expected to take one scalar argument and returning one scalar - value. lambda functions are advised against since the cache will not work with - lambda functions. sklearn_numba_dpex.common._utils expose some pre-defined - `fused_unary_funcs`. - - Notes - ----- - `work_group_size` is assumed to be a power of 2. - - if `size1` is None then the kernel expects 1d tensor inputs. If `size1` is not None - then the expected shape of input tensors is `(size0, size1)`, and the reduction - operation is equivalent to input.sum(axis=1). In this case, the kernel is a good - choice if `size1` >> `preferred_work_group_size_multiple`, and if `size0` ranges in - the same order of magnitude than `preferred_work_group_size_multiple`. If not, - other reduction implementations might give better performances. - """ - check_power_of_2(work_group_size) - n_rows = size0 if size1 is not None else None - sum_axis_size = size0 if n_rows is None else size1 - - # fused_unary_func is applied elementwise during the first pass on data, in the - # first kernel execution only. - fused_func_kernel = _make_partial_sum_reduction_2d_axis1_kernel( - n_rows, work_group_size, fused_unary_func, dtype - ) - # subsequent kernel calls only sum the data. - nofunc_kernel = _make_partial_sum_reduction_2d_axis1_kernel( - n_rows, work_group_size, None, dtype - ) - - # As many partial reductions as necessary are chained until only one element - # remains. - kernels_and_empty_tensors_pairs = [] - n_groups = sum_axis_size - # TODO: at some point, the cost of scheduling the kernel is more than the cost of - # running the reduction iteration. At this point the loop should stop and then a - # single work item should iterates one time on the remaining values to finish the - # reduction. - kernel = fused_func_kernel - while n_groups > 1: - n_groups = math.ceil(n_groups / (2 * work_group_size)) - global_size = n_groups * work_group_size - kernel = kernel[global_size, work_group_size] - result_shape = n_groups if n_rows is None else (n_rows, n_groups) - # NB: here memory for partial results is allocated ahead of time and will only - # be garbage collected when the instance of `sum_reduction` is garbage - # collected. Thus it can be more efficient to re-use a same instance of - # `sum_reduction` (e.g within iterations of a loop) since it avoid deallocation - # and reallocation every time. - result = dpt.empty(result_shape, dtype=dtype, device=device) - kernels_and_empty_tensors_pairs.append((kernel, result)) - kernel = nofunc_kernel - - def sum_reduction(summands): - # TODO: manually dispatch the kernels with a SyclQueue - for kernel, result in kernels_and_empty_tensors_pairs: - kernel(summands, result) - summands = result - return result - - return sum_reduction - - @lru_cache def make_argmin_reduction_1d_kernel(size, work_group_size, device, dtype): """Implement 1d argmin with the same strategy than for make_sum_reduction_2d_axis1_kernel.""" diff --git a/sklearn_numba_dpex/kmeans/drivers.py b/sklearn_numba_dpex/kmeans/drivers.py index 92cdd8c..f9ee12a 100644 --- a/sklearn_numba_dpex/kmeans/drivers.py +++ b/sklearn_numba_dpex/kmeans/drivers.py @@ -405,6 +405,13 @@ def _relocate_empty_clusters( def prepare_data_for_lloyd(X_t, init, tol, copy_x): + """It can be more numerically accurate to center the data first. If copy_x is True, + then the original data is not modified. If False, the original data is modified, + and put back later on (see `restore_data_after_lloyd`), but small numerical + differences may be introduced by subtracting and then adding the data mean. Note + that if the original data is not C-contiguous, a copy will be made even if copy_x + is False.""" + n_features, n_samples = X_t.shape compute_dtype = X_t.dtype.type @@ -420,14 +427,15 @@ def prepare_data_for_lloyd(X_t, init, tol, copy_x): ) X_mean = (sum_axis1_kernel(X_t) / compute_dtype(n_samples))[:, 0] - - if (X_mean == 0).astype(int).sum() == len(X_mean): + X_mean_is_zeroed = (X_mean == 0).astype(int).sum() == len(X_mean) + if X_mean_is_zeroed: # If the data is already centered, there's no need to perform shift/unshift # steps. In this case, X_mean is set to None, thus carrying the information # that the data was already centered, and the shift/unshift steps will be # skipped. X_mean = None else: + # subtract the mean of x for more accurate distance computations X_t = dpt.asarray(X_t, copy=copy_x) broadcast_X_minus_X_mean = make_broadcast_ops_1d_2d_axis1_kernel( n_features, @@ -475,6 +483,8 @@ def restore_data_after_lloyd(X_t, X_mean): ops=_plus, work_group_size=max_work_group_size, ) + # The feature wise mean of X X_mean that had been substracted in + # `prepare_data_for_lloyd` is re-added. broadcast_X_plus_X_mean(X_t, X_mean) diff --git a/sklearn_numba_dpex/kmeans/engine.py b/sklearn_numba_dpex/kmeans/engine.py index b53dbb5..ce094f9 100644 --- a/sklearn_numba_dpex/kmeans/engine.py +++ b/sklearn_numba_dpex/kmeans/engine.py @@ -67,7 +67,7 @@ class KMeansEngine(KMeansCythonEngine): # future instances. It is only used for testing purposes, using # `sklearn_numba_dpex.testing.config.override_attr_context` context, for instance # in the benchmark script. - # For normal usage, the compute will follow the __compute_follow_data__ principle. + # For normal usage, the compute will follow the *compute follows data* principle. _CONFIG = dict() def __init__(self, estimator): @@ -164,6 +164,7 @@ def init_centroids(self, X): ) # Poor man's fancy indexing # TODO: write a kernel ? or replace with better equivalent when available ? + # Relevant issue: https://github.com/IntelPython/dpctl/issues/1003 centers_t = dpt.concat( [dpt.expand_dims(X[center_idx], axes=1) for center_idx in centers_idx], axis=1, @@ -183,9 +184,9 @@ def kmeans_single(self, X, sample_weight, centers_init_t): # ???: using `.all()` often segfaults # TODO: minimal reproducer and issue at dpnp # or write a kernel ? - use_uniform_weights = (sample_weight == sample_weight[0]).astype( - int - ).sum() == len(sample_weight) + use_uniform_weights = ( + (sample_weight == sample_weight[0]).astype(int).sum() + ) == len(sample_weight) assignments_idx, inertia, best_centroids, n_iteration = lloyd( X.T, @@ -221,6 +222,7 @@ def prepare_prediction(self, X, sample_weight): def get_labels(self, X, sample_weight): # TODO: sample_weight actually not used for get_labels. Fix in sklearn ? + # Relevant issue: https://github.com/scikit-learn/scikit-learn/issues/25066 labels, _ = self._get_labels_inertia(X, sample_weight, with_inertia=False) return dpt.asnumpy(labels).astype(np.int32, copy=False) @@ -303,7 +305,9 @@ def _check_sample_weight(self, sample_weight, X): if sample_weight is None: sample_weight = dpt.ones(n_samples, dtype=dtype, device=device) elif isinstance(sample_weight, numbers.Number): - sample_weight = dpt.full(n_samples, 1, dtype=dtype, device=device) + sample_weight = dpt.full( + n_samples, sample_weight, dtype=dtype, device=device + ) else: with _validate_with_array_api(device): sample_weight = check_array( From eb4e3ab7a72a5aa3b09e545af14b9070ff2139f5 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Tue, 29 Nov 2022 13:15:05 +0100 Subject: [PATCH 7/7] linting --- benchmark/ext_helpers/daal4py.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/ext_helpers/daal4py.py b/benchmark/ext_helpers/daal4py.py index fa336bc..9ed0f58 100644 --- a/benchmark/ext_helpers/daal4py.py +++ b/benchmark/ext_helpers/daal4py.py @@ -29,6 +29,7 @@ from sklearn.exceptions import NotSupportedByEngineError from sklearn.cluster._kmeans import KMeansCythonEngine + # TODO: instead of relying on monkey patching the default engine, find a way to # register a distinct entry point that can load a distinct engine outside of setup.py # (impossible ?)