
feat: support multiple devices for ensemble #160

Open · wants to merge 2 commits into master
5 changes: 2 additions & 3 deletions setup.py
@@ -81,14 +81,13 @@ def run(self):
"Operating System :: Unix",
"Operating System :: MacOS",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10"
],
keywords=["Deep Learning", "PyTorch", "Ensemble Learning"],
packages=find_packages(),
cmdclass=cmdclass,
python_requires=">=3.6",
python_requires=">=3.8",
install_requires=install_requires,
)
91 changes: 52 additions & 39 deletions torchensemble/_base.py
@@ -6,6 +6,8 @@
import numpy as np
import torch.nn as nn

from typing import Dict, List, Union

from . import _constants as const
from .utils.io import split_data_target
from .utils.logging import get_tb_logger
@@ -60,11 +62,11 @@ class BaseModule(nn.Module):

def __init__(
self,
estimator,
n_estimators,
estimator_args=None,
cuda=True,
n_jobs=None,
estimator: nn.Module,
n_estimators: int,
estimator_args: Dict = None,
device: Union[str, List[str]] = "cuda",
n_jobs: int = None,
):
super(BaseModule, self).__init__()
self.base_estimator_ = estimator
@@ -78,12 +80,33 @@ def __init__(
)
warnings.warn(msg, RuntimeWarning)

self.device = torch.device("cuda" if cuda else "cpu")
# Specify the running device(s) for the base estimators.
if isinstance(device, str):
    # A single string: every base estimator runs on the same device.
    self.device = [torch.device(device)] * n_estimators
elif isinstance(device, list):
    if len(device) != n_estimators:
        msg = "The length of the `device` list should equal `n_estimators`."
        raise ValueError(msg)
    # One entry per base estimator, e.g., ["cuda:0", "cuda:1"].
    self.device = [torch.device(d) for d in device]
else:
    msg = (
        "The argument `device` should be a string or a list of"
        " strings, got {} instead."
    )
    raise ValueError(msg.format(type(device)))

self.n_jobs = n_jobs
self.logger = logging.getLogger()
self.tb_logger = get_tb_logger()

self.estimators_ = nn.ModuleList()

self._criterion = None

self.optimizer_name = None
self.optimizer_args = None

self.scheduler_name = None
self.scheduler_args = None
self.use_scheduler_ = False

def __len__(self):
@@ -102,7 +125,7 @@ def __getitem__(self, index):
def _decide_n_outputs(self, train_loader):
"""Decide the number of outputs according to the `train_loader`."""

def _make_estimator(self):
def _make_estimator(self, idx):
"""Make and configure a copy of `self.base_estimator_`."""

# Call `deepcopy` to make a base estimator
@@ -117,7 +140,7 @@ def _make_estimator(self):
else:
estimator = self.base_estimator_(**self.estimator_args)

return estimator.to(self.device)
return estimator.to(self.device[idx])

def _validate_parameters(self, epochs, log_interval):
"""Validate hyper-parameters on training the ensemble."""
@@ -185,9 +208,9 @@ def predict(self, *x):
x_device = []
for data in x:
if isinstance(data, torch.Tensor):
x_device.append(data.to(self.device))
x_device.append(data.to("cpu"))
elif isinstance(data, np.ndarray):
x_device.append(torch.Tensor(data).to(self.device))
x_device.append(torch.Tensor(data).to("cpu"))
else:
msg = (
"The type of input X should be one of {{torch.Tensor,"
@@ -206,23 +229,13 @@ def __init__(
n_estimators=10,
depth=5,
lamda=1e-3,
cuda=False,
device="cuda",
n_jobs=None,
):
super(BaseModule, self).__init__()
self.base_estimator_ = BaseTree
self.n_estimators = n_estimators
super().__init__(BaseTree, n_estimators, {}, device, n_jobs)
self.depth = depth
self.lamda = lamda

self.device = torch.device("cuda" if cuda else "cpu")
self.n_jobs = n_jobs
self.logger = logging.getLogger()
self.tb_logger = get_tb_logger()

self.estimators_ = nn.ModuleList()
self.use_scheduler_ = False

def _decidce_n_inputs(self, train_loader):
"""Decide the input dimension according to the `train_loader`."""
for _, elem in enumerate(train_loader):
@@ -231,7 +244,7 @@ def _decidce_n_inputs(self, train_loader):
data = data.view(n_samples, -1)
return data.size(1)

def _make_estimator(self):
def _make_estimator(self, idx):
"""Make and configure a soft decision tree."""
estimator = BaseTree(
input_dim=self.n_inputs,
@@ -241,7 +254,7 @@ def _make_estimator(self):
cuda=(self.device[idx].type == "cuda"),
)

return estimator.to(self.device)
return estimator.to(self.device[idx])


class BaseClassifier(BaseModule):
@@ -263,7 +276,7 @@ def _decide_n_outputs(self, train_loader):
else:
labels = []
for _, elem in enumerate(train_loader):
_, target = split_data_target(elem, self.device)
_, target = split_data_target(elem, "cpu")
labels.append(target)
labels = torch.unique(torch.cat(labels))
n_outputs = labels.size(0)
@@ -279,7 +292,7 @@ def evaluate(self, test_loader, return_loss=False):
loss = 0.0

for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)
data, target = split_data_target(elem, "cpu")

output = self.forward(*data)

@@ -371,28 +384,28 @@ def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False):
self.leaf_node_num_, self.output_dim, bias=False
)

def forward(self, X, is_training_data=False):
_mu, _penalty = self._forward(X)
def forward(self, x, is_training_data=False):
_mu, _penalty = self._forward(x)
y_pred = self.leaf_nodes(_mu)

# When `X` is the training data, the model also returns the penalty
# When `x` is the training data, the model also returns the penalty
# to compute the training loss.
if is_training_data:
return y_pred, _penalty
else:
return y_pred

def _forward(self, X):
def _forward(self, x):
"""Implementation on the data forwarding process."""

batch_size = X.size()[0]
X = self._data_augment(X)
batch_size = x.size(0)
x = self._data_augment(x)

path_prob = self.inner_nodes(X)
path_prob = self.inner_nodes(x)
path_prob = torch.unsqueeze(path_prob, dim=2)
path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)

_mu = X.data.new(batch_size, 1, 1).fill_(1.0)
_mu = x.data.new(batch_size, 1, 1).fill_(1.0)
_penalty = torch.tensor(0.0).to(self.device)

# Iterate through internal nodes in each layer to compute the final path
@@ -437,14 +450,14 @@ def _cal_penalty(self, layer_idx, _mu, _path_prob):

return penalty

def _data_augment(self, X):
def _data_augment(self, x):
"""Add a constant input `1` onto the front of each sample."""
batch_size = X.size()[0]
X = X.view(batch_size, -1)
batch_size = x.size(0)
x = x.view(batch_size, -1)
bias = torch.ones(batch_size, 1).to(self.device)
X = torch.cat((bias, X), 1)
x = torch.cat((bias, x), 1)

return X
return x

def _validate_parameters(self):

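To make the new device handling above concrete, here is a minimal, self-contained sketch of the normalization performed in `BaseModule.__init__` (the `normalize_devices` helper is hypothetical and only mirrors the logic of this diff):

```python
from typing import List, Union

import torch


def normalize_devices(device: Union[str, List[str]], n_estimators: int) -> List[torch.device]:
    """Expand a single device string, or validate a per-estimator device list."""
    if isinstance(device, str):
        # One string: every base estimator runs on the same device.
        return [torch.device(device)] * n_estimators
    if isinstance(device, list):
        if len(device) != n_estimators:
            raise ValueError("The length of the `device` list should equal `n_estimators`.")
        # One entry per base estimator, e.g. ["cuda:0", "cuda:1", "cpu"].
        return [torch.device(d) for d in device]
    raise ValueError(
        "The argument `device` should be a string or a list of strings, "
        "got {} instead.".format(type(device))
    )


print(normalize_devices("cpu", 3))                 # three copies of device('cpu')
print(normalize_devices(["cuda:0", "cuda:1"], 2))  # [device('cuda:0'), device('cuda:1')]
```

With this normalization, `self.device[idx]` is always valid inside `_make_estimator`, whether the caller passed a single string or a full list.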
18 changes: 9 additions & 9 deletions torchensemble/_constants.py
@@ -13,10 +13,10 @@
The dictionary of hyper-parameters used to instantiate base
estimators. This parameter will have no effect if ``estimator`` is a
base estimator object after instantiation.
cuda : bool, default=True
device : string or List, default='cuda'

- If ``True``, use GPU to train and evaluate the ensemble.
- If ``False``, use CPU to train and evaluate the ensemble.
- If :obj:`string`, all base estimators run on the device specified by the string.
- If :obj:`List`, each base estimator runs on the device ``device[base_estimator_index]``; the list length must equal ``n_estimators``.
n_jobs : int, default=None
The number of workers for training the ensemble. This input
argument is used for parallel ensemble methods such as
@@ -46,10 +46,10 @@
The dictionary of hyper-parameters used to instantiate base
estimators. This parameter will have no effect if ``estimator`` is a
base estimator object after instantiation.
cuda : bool, default=True
device : string or List, default='cuda'

- If ``True``, use GPU to train and evaluate the ensemble.
- If ``False``, use CPU to train and evaluate the ensemble.
- If :obj:`string`, all base estimators run on the device specified by the string.
- If :obj:`List`, each base estimator runs on the device ``device[base_estimator_index]``; the list length must equal ``n_estimators``.

Attributes
----------
@@ -70,10 +70,10 @@
The coefficient of the regularization term when training neural
trees, proposed in the paper: `Distilling a neural network into a
soft decision tree <https://arxiv.org/abs/1711.09784>`_.
cuda : bool, default=True
device : string or List, default='cuda'

- If ``True``, use GPU to train and evaluate the ensemble.
- If ``False``, use CPU to train and evaluate the ensemble.
- If :obj:`string`, all base estimators run on the device specified by the string.
- If :obj:`List`, each base estimator runs on the device ``device[base_estimator_index]``; the list length must equal ``n_estimators``.
n_jobs : int, default=None
The number of workers for training the ensemble. This input
argument is used for parallel ensemble methods such as
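As a usage sketch of the documented `device` argument (assuming this branch is installed; `MLP` is a placeholder base estimator and `train_loader` is assumed to exist):

```python
import torch.nn as nn
from torchensemble import VotingClassifier


class MLP(nn.Module):
    # Placeholder base estimator for illustration only.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))


# String form: all base estimators share one device.
ensemble = VotingClassifier(estimator=MLP, n_estimators=4, device="cuda:0")

# List form: one device per base estimator (length must equal n_estimators).
ensemble = VotingClassifier(
    estimator=MLP,
    n_estimators=4,
    device=["cuda:0", "cuda:0", "cuda:1", "cuda:1"],
)
ensemble.set_optimizer("Adam", lr=1e-3)
# ensemble.fit(train_loader, epochs=10)
```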
10 changes: 5 additions & 5 deletions torchensemble/utils/io.py
@@ -76,12 +76,12 @@ def load(model, save_dir="./", map_location=None, logger=None):
model.n_inputs = state["n_inputs"]

# Pre-allocate and load all base estimators
for _ in range(n_estimators):
model.estimators_.append(model._make_estimator())
for idx in range(n_estimators):
model.estimators_.append(model._make_estimator(idx))
model.load_state_dict(model_params)


def split_data_target(element, device, logger=None):
def split_data_target(element, device="cpu", logger=None):
"""Split elements in dataloader according to pre-defined rules."""
if not (isinstance(element, list) or isinstance(element, tuple)):
msg = (
@@ -98,8 +98,8 @@ def split_data_target(element, device, logger=None):
elif len(element) > 2:
# Dataloader with multiple inputs and one target
data, target = element[:-1], element[-1]
data_device = [tensor.to(device) for tensor in data]
return data_device, target.to(device)
data = [tensor.to(device) for tensor in data]
return data, target.to(device)
else:
# Dataloader with invalid input
msg = (
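A small sketch of how the updated `split_data_target` default behaves under this change (the batch layout is illustrative):

```python
import torch
from torchensemble.utils import io

# A dataloader element with two inputs and one target.
batch = (torch.randn(8, 3), torch.randn(8, 5), torch.randint(0, 2, (8,)))

# With this patch, omitting `device` keeps everything on the CPU; each
# estimator moves the data to its own device later inside `forward`.
data, target = io.split_data_target(batch)
print([tensor.device for tensor in data], target.device)  # all on cpu

# Passing a device explicitly still works as before.
data, target = io.split_data_target(batch, device="cpu")
```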
37 changes: 18 additions & 19 deletions torchensemble/voting.py
@@ -5,6 +5,7 @@
"""


import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -98,7 +99,7 @@ def __init__(self, voting_strategy="soft", **kwargs):
implemented_strategies = {"soft", "hard"}
if voting_strategy not in implemented_strategies:
msg = (
"Voting strategy {} is not implemented, "
"Voting strategy `{}` is not implemented, "
"please choose from {}."
)
raise ValueError(
@@ -112,11 +113,10 @@ def __init__(self, voting_strategy="soft", **kwargs):
"classifier_forward",
)
def forward(self, *x):

outputs = [
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
for estimator in self.estimators_
]
outputs = []
for estimator, device in zip(self.estimators_, self.device):
    # Move a copy of the input onto this estimator's device.
    per_estimator_x = tuple(
        tensor.to(device) for tensor in copy.deepcopy(x)
    )
    output = F.softmax(
        op.unsqueeze_tensor(estimator(*per_estimator_x)), dim=1
    )
    # Collect every prediction on the CPU before voting.
    outputs.append(output.to("cpu"))

if self.voting_strategy == "soft":
proba = op.average(outputs)
@@ -164,8 +164,8 @@ def fit(

# Instantiate a pool of base estimators, optimizers, and schedulers.
estimators = []
for _ in range(self.n_estimators):
estimators.append(self._make_estimator())
for idx in range(self.n_estimators):
estimators.append(self._make_estimator(idx))

optimizers = []
for i in range(self.n_estimators):
@@ -180,18 +180,19 @@
optimizers[0], self.scheduler_name, **self.scheduler_args
)

# Check the training criterion
if not hasattr(self, "_criterion"):
# Check the training criterion, use the cross-entropy loss by default
if not hasattr(self, "_criterion") or self._criterion is None:
self._criterion = nn.CrossEntropyLoss()

# Utils
best_acc = 0.0

# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
def _forward(estimators, devices, *data):
    outputs = []
    for estimator, device in zip(estimators, devices):
        # Each estimator evaluates the batch on its own device and
        # returns its class probabilities on the CPU.
        per_estimator_x = tuple(
            tensor.to(device) for tensor in copy.deepcopy(data)
        )
        outputs.append(
            F.softmax(estimator(*per_estimator_x), dim=1).to("cpu")
        )

if self.voting_strategy == "soft":
proba = op.average(outputs)
@@ -230,7 +231,7 @@ def _forward(estimators, *x):
idx,
epoch,
log_interval,
self.device,
self.device[idx],
True,
)
for idx, (estimator, optimizer) in enumerate(
@@ -251,10 +252,8 @@ def _forward(estimators, *x):
correct = 0
total = 0
for _, elem in enumerate(test_loader):
data, target = io.split_data_target(
elem, self.device
)
output = _forward(estimators, *data)
data, target = io.split_data_target(elem, "cpu")
output = _forward(estimators, self.device, *data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
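The `forward` and `fit` changes above follow one pattern: each estimator gets a copy of the batch on its own device, and its class probabilities are brought back to the CPU before averaging. A minimal, library-free sketch of that pattern (toy linear estimators; swap the device list for real GPUs as available):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Two toy estimators pinned to (possibly different) devices.
devices = [torch.device("cpu"), torch.device("cpu")]  # e.g. ["cuda:0", "cuda:1"]
estimators = [nn.Linear(4, 3).to(device) for device in devices]

x = torch.randn(5, 4)  # the incoming batch stays on the CPU

outputs = []
for estimator, device in zip(estimators, devices):
    # Move the batch to the estimator's device, run it, and bring the
    # probabilities back to the CPU so they can be averaged together.
    logits = estimator(x.to(device))
    outputs.append(F.softmax(logits, dim=1).to("cpu"))

proba = sum(outputs) / len(outputs)  # soft-voting average on the CPU
print(proba.shape)                   # torch.Size([5, 3])
```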