Skip to content

Parallelizing kernel matrix evaluation #3

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 85 additions & 20 deletions kernelmethods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import Iterable
from copy import copy
from itertools import product as iter_product
import os

import numpy as np
from scipy.sparse import issparse, lil_matrix
Expand Down Expand Up @@ -74,6 +75,8 @@ def __str__(self):


# aliasing others to __str__ for now
# TODO having a shorter alias such as p(2)/g(0.1) is convenient instead of
# polynomial(degree=2) or gaussian(sigma=0.1): {:s} for __format__()?
def __format__(self, _):
"""Representation"""

Expand Down Expand Up @@ -147,7 +150,8 @@ class KernelMatrix(object):
def __init__(self,
kernel,
normalized=True,
name='KernelMatrix'):
name='KernelMatrix',
n_cpus=None):
"""
Constructor.

Expand All @@ -163,6 +167,14 @@ def __init__(self,
name : str
short name to describe the nature of the kernel function

n_cpus : None or int
Allows parallelizing the evaluation of the full gram matrix, when the number of
samples is large enough to make the full computation of the KM too slow, or when
the evaluation of the kernel func on a single pair (i,j) is slow (rarely the case).
Default: None, which is serial evaluation of all the pairs.
If a number is specified, n=min(n_cpus, os.cpu_count()) will be used.
If os.cpu_count() is not successful, 2 will be chosen.

"""

if not isinstance(kernel, BaseKernelFunction):
Expand All @@ -184,14 +196,17 @@ def __init__(self,
# user-defined attribute dictionary
self._attr = dict()

self._setup_parallelization(n_cpus)

self._reset()


def attach_to(self,
sample_one,
name_one='sample',
sample_two=None,
name_two=None):
name_two=None,
n_cpus=None):
"""
Attach this kernel to a given sample.

Expand Down Expand Up @@ -461,9 +476,20 @@ def normed_km(self):
return self._normed_km


def _eval_pairs(self, pairs):
"""Helper to facilitate parallel processing in chunks of index pairs"""

print('{} pairs provided: start {}, end: {}'.format(len(pairs)),
pairs[0], pairs[-1])
for idx_one, idx_two in pairs:
self._eval_kernel(idx_one, idx_two)


def _eval_kernel(self, idx_one, idx_two):
"""Returns kernel value between samples identified by indices one and two"""

# print('within eval_kernel : {} {}'.format(idx_one, idx_two))

# maintaining only upper triangular parts, when attached to a single sample
# by ensuring the first index is always <= second index
if idx_one > idx_two and not self._two_samples:
Expand Down Expand Up @@ -581,6 +607,35 @@ def _compute_for_index_combinations(self, set_one, set_two):
dtype=self._sample.dtype).reshape(len(set_one), len(set_two))


def _setup_parallelization(self, n_cpus):
"""Sets the number of CPUs and other state-related flags."""

if n_cpus is not None:
query = os.cpu_count()
if query is None:
query = 2
print('Unable to query the num. CPUs - choosing {}'.format(query))
self._num_cpus = min(int(n_cpus), query)
if self._num_cpus <= 1:
print('num_cpus setup is <=1, skipping parallelization.')
self._parallelize = False
else:
self._parallelize = True
else:
self._parallelize = False
self._num_cpus = None


def _parallel_eval(self):
    """Parallelize the evaluation of KM on subsets of n(n+1)/2 pairs of indices.

    The upper-triangular index pairs (main diagonal included) of a matrix of
    shape ``self.shape`` are split into roughly ``self._num_cpus`` equal
    chunks, ``self._eval_kernel`` is evaluated on them in a process pool, and
    each value is stored at ``self._full_km[idx_one, idx_two]``.

    NOTE(review): the bound method ``self._eval_kernel`` is sent to the worker
    processes, so this presumably requires the instance to be picklable -
    confirm for all kernel functions in use.
    """

    from multiprocessing import Pool

    rows, cols = np.triu_indices(self.shape[0], m=self.shape[1])
    # plain Python ints in (row, col) tuples: stays 2-column even for one pair
    pairs = list(zip(rows.tolist(), cols.tolist()))
    if not pairs:
        return

    # chunksize must be a positive integer: ceil(num_pairs / num_cpus)
    chunk_size = max(1, -(-len(pairs) // self._num_cpus))

    with Pool(processes=self._num_cpus) as pool:
        # starmap unpacks each (idx_one, idx_two) into the two arguments of
        # _eval_kernel, and returns the values in the order of pairs
        values = pool.starmap(self._eval_kernel, pairs, chunksize=chunk_size)

    # workers operate on copies, so results must be stored in this process
    for (idx_one, idx_two), value in zip(pairs, values):
        self._full_km[idx_one, idx_two] = value


def _populate_fully(self, dense_fmt=False, fill_lower_tri=False):
"""Applies the kernel function on all pairs of points in a sample.

Expand All @@ -604,37 +659,47 @@ def _populate_fully(self, dense_fmt=False, fill_lower_tri=False):
# kernel matrix is symmetric (in a single sample case)
# so we need only compute half the matrix!
# computing the kernel for diagonal elements i,i as well
# as ix_two, even when equal to ix_one, refers to sample_two in the two_samples case
for ix_one in range(self.shape[0]): # number of rows!
for ix_two in range(ix_one, self.shape[1]): # from second sample!
self._full_km[ix_one, ix_two] = self._eval_kernel(ix_one, ix_two)
# as ix_two, even when equal to ix_one, refers to sample_two in
# the two_samples case
if self._parallelize:
self._parallel_eval()
else:
for ix_one in range(self.shape[0]): # number of rows!
for ix_two in range(ix_one, self.shape[1]): # from second sample!
self._full_km[ix_one, ix_two] = self._eval_kernel(ix_one, ix_two)
except:
raise RuntimeError('Unable to fully compute the kernel matrix!')
else:
self._populated_fully = True

if fill_lower_tri and not self._lower_tri_km_filled:
try:
# choosing k=-1 as main diag is already covered above (nested for loop)
ix_lower_tri = np.tril_indices(self.shape[0], m=self.shape[1], k=-1)

if not self._two_samples and self.shape[0] == self.shape[1]:
self._full_km[ix_lower_tri] = self._full_km.T[ix_lower_tri]
else:
# evaluating it for the lower triangle as well!
for ix_one, ix_two in zip(*ix_lower_tri):
self._full_km[ix_one, ix_two] = self._eval_kernel(ix_one, ix_two)
except:
raise RuntimeError('Unable to symmetrize the kernel matrix!')
else:
self._lower_tri_km_filled = True
self._fill_lower_tri()

if issparse(self._full_km) and dense_fmt:
self._full_km = self._full_km.todense()

return self._full_km


def _fill_lower_tri(self):
"""Helper method to fill the lower tri part of KM"""

try:
# choosing k=-1 as main diag is already covered above (nested for loop)
ix_lower_tri = np.tril_indices(self.shape[0], m=self.shape[1], k=-1)

if not self._two_samples and self.shape[0] == self.shape[1]:
self._full_km[ix_lower_tri] = self._full_km.T[ix_lower_tri]
else:
# evaluating it for the lower triangle as well!
for ix_one, ix_two in zip(*ix_lower_tri):
self._full_km[ix_one, ix_two] = self._eval_kernel(ix_one, ix_two)
except:
raise RuntimeError('Unable to symmetrize the kernel matrix!')
else:
self._lower_tri_km_filled = True


def __str__(self):
"""human readable presentation"""

Expand Down
33 changes: 31 additions & 2 deletions kernelmethods/tests/test_kernel_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
target_labels = np.random.choice(target_label_set, num_samples)

# suffix 1 to indicate one sample case
km1 = KernelMatrix(PolyKernel(degree=2, skip_input_checks=True))
typical_ker_func = PolyKernel(degree=2, skip_input_checks=True)
km1 = KernelMatrix(typical_ker_func)
km1.attach_to(sample_data)

max_num_elements = max_num_ker_eval = num_samples * (num_samples + 1) / 2
Expand Down Expand Up @@ -183,6 +184,33 @@ def test_attach_to_two_samples():
km2.attach_to(sample_data, sample_two=more_dims)


def test_parallelization():
    """Checks that parallel evaluation of the full KM is faster than serial.

    Times ``_populate_fully()`` on a kernel matrix built with ``n_cpus`` and
    on one built without, and fails if the median parallel time is not below
    the median serial time divided by ``n_cpus - 1``.

    NOTE(review): wall-clock comparisons are sensitive to machine load, and
    ``km_parallel.full`` is accessed before timing - confirm the timed calls
    still do the full computation.
    """

    n_cpus = 4
    large_n = 1000
    large_sample = np.random.rand(large_n, 3)

    km_parallel = KernelMatrix(typical_ker_func, n_cpus=n_cpus)
    km_parallel.attach_to(large_sample)
    km_parallel.full

    from timeit import repeat
    # number=1 is essential: timeit defaults to 1,000,000 executions of the
    # statement per repetition, which would never finish for a full KM
    parl = repeat('km_parallel._populate_fully()', repeat=4, number=1,
                  globals=locals())

    km_serial = KernelMatrix(typical_ker_func)
    km_serial.attach_to(large_sample)

    serial = repeat('km_serial._populate_fully()', repeat=4, number=1,
                    globals=locals())

    print('time taken: in paralle : {}, serial: {}'.format(np.median(parl),
                                                           np.median(serial)))
    if np.median(parl) >= np.median(serial)/(n_cpus-1):
        raise ValueError('parallelization with {} cpus has not saved time '
                         'for sample of size {}'.format(n_cpus, large_n))


def test_attributes():

km = KernelMatrix(LinearKernel())
Expand All @@ -196,4 +224,5 @@ def test_attributes():
assert attr in kma


# Run the tests only when this file is executed as a script, so that merely
# importing the module (e.g. during test collection) does not launch them.
if __name__ == '__main__':
    test_attributes()
    # test_attributes()
    test_parallelization()