diff --git a/.all-contributorsrc b/.all-contributorsrc
index b9d5af8..7b043f2 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -70,6 +70,15 @@
         "code"
       ]
     },
+    {
+      "login": "e-eight",
+      "name": "Soham Pal",
+      "avatar_url": "https://avatars.githubusercontent.com/u/3883241?v=4",
+      "profile": "https://soham.dev",
+      "contributions": [
+        "code"
+      ]
+    },
     {
       "login": "by256",
       "name": "Batuhan Yildirim",
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 02923ef..9f6e690 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -8,6 +8,8 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
Matt Gadd 💻 |
+ Soham Pal 💻 |
Batuhan Yildirim 💻 |
Matt Gadd 💻 |
Wenjie Zhang 💻 ⚠️ |
diff --git a/torchensemble/_constants.py b/torchensemble/_constants.py
index 33db14f..e2c4376 100644
--- a/torchensemble/_constants.py
+++ b/torchensemble/_constants.py
@@ -64,7 +64,7 @@
optimizer_name : string
The name of the optimizer, should be one of {``Adadelta``, ``Adagrad``,
``Adam``, ``AdamW``, ``Adamax``, ``ASGD``, ``RMSprop``, ``Rprop``,
- ``SGD``}.
+ ``SGD``, ``LBFGS``}.
**kwargs : keyword arguments
Keyword arguments on setting the optimizer, should be in the form:
``lr=1e-3, weight_decay=5e-4, ...``. These keyword arguments
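For reference, a hedged sketch of how the documented `optimizer_name`/`**kwargs` pair is used with the newly listed `LBFGS` entry through the ensemble-level `set_optimizer` API; `MLP` stands in for any `nn.Module` base estimator and is not part of this diff.

```python
from torchensemble import VotingClassifier

# Illustrative only: MLP is a placeholder nn.Module classifier.
model = VotingClassifier(estimator=MLP, n_estimators=5)
# Keyword arguments are forwarded to the underlying torch.optim.LBFGS.
model.set_optimizer("LBFGS", history_size=7, max_iter=10)
```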
diff --git a/torchensemble/bagging.py b/torchensemble/bagging.py
index 7b0f008..26cd94e 100644
--- a/torchensemble/bagging.py
+++ b/torchensemble/bagging.py
@@ -6,19 +6,17 @@
"""
+import warnings
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-import warnings
from joblib import Parallel, delayed
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["BaggingClassifier", "BaggingRegressor"]
@@ -59,11 +57,20 @@ def _parallel_fit_per_epoch(
sampling_data = [tensor[sampling_mask] for tensor in data]
sampling_target = target[sampling_mask]
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ sampling_output = estimator(*sampling_data)
+ loss = criterion(sampling_output, sampling_target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
sampling_output = estimator(*sampling_data)
- loss = criterion(sampling_output, sampling_target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -79,7 +86,12 @@ def _parallel_fit_per_epoch(
)
print(
msg.format(
- idx, epoch, batch_idx, loss, correct, subsample_size
+ idx,
+ epoch,
+ batch_idx,
+ loss.item(),
+ correct,
+ subsample_size,
)
)
else:
@@ -87,7 +99,7 @@ def _parallel_fit_per_epoch(
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
- print(msg.format(idx, epoch, batch_idx, loss))
+ print(msg.format(idx, epoch, batch_idx, loss.item()))
return estimator, optimizer
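The closure above is the standard PyTorch pattern for optimizers such as LBFGS that may re-evaluate the objective several times per `step()`. A self-contained sketch of the same pattern outside the ensemble code (the toy data and model below are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy regression problem, purely for illustration.
X, y = torch.randn(64, 10), torch.randn(64, 1)
net = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.LBFGS(net.parameters(), history_size=7, max_iter=10)

def closure():
    # LBFGS may call this several times per .step(); the guards keep the
    # closure safe to call with gradients disabled (e.g. for logging).
    if torch.is_grad_enabled():
        optimizer.zero_grad()
    loss = criterion(net(X), y)
    if loss.requires_grad:
        loss.backward()
    return loss

optimizer.step(closure)  # LBFGS drives the closure internally
with torch.no_grad():
    print("loss: {:.5f}".format(closure().item()))  # re-evaluate once for logging
```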
diff --git a/torchensemble/fusion.py b/torchensemble/fusion.py
index 9a5857c..a573968 100644
--- a/torchensemble/fusion.py
+++ b/torchensemble/fusion.py
@@ -10,12 +10,10 @@
import torch.nn as nn
import torch.nn.functional as F
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["FusionClassifier", "FusionRegressor"]
@@ -109,11 +107,20 @@ def fit(
data, target = io.split_data_target(elem, self.device)
batch_size = data[0].size(0)
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = self._forward(*data)
+ loss = self._criterion(output, target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
output = self._forward(*data)
- loss = self._criterion(output, target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -127,12 +134,16 @@ def fit(
)
self.logger.info(
msg.format(
- epoch, batch_idx, loss, correct, batch_size
+ epoch,
+ batch_idx,
+ loss.item(),
+ correct,
+ batch_size,
)
)
if self.tb_logger:
self.tb_logger.add_scalar(
- "fusion/Train_Loss", loss, total_iters
+ "fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1
@@ -257,20 +268,31 @@ def fit(
data, target = io.split_data_target(elem, self.device)
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = self.forward(*data)
+ loss = self._criterion(output, target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
output = self.forward(*data)
- loss = self._criterion(output, target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
with torch.no_grad():
msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}"
- self.logger.info(msg.format(epoch, batch_idx, loss))
+ self.logger.info(
+ msg.format(epoch, batch_idx, loss.item())
+ )
if self.tb_logger:
self.tb_logger.add_scalar(
- "fusion/Train_Loss", loss, total_iters
+ "fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1
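The same closure change is applied to both `FusionClassifier.fit` and `FusionRegressor.fit`; the logging calls now pass `loss.item()` because the closure returns a tensor rather than a plain number. A tiny sketch of the logging side (the writer below stands in for `self.tb_logger`):

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()       # stand-in for self.tb_logger
loss = torch.tensor(0.12345)   # what the closure returns
# .item() converts the 0-dim tensor into a plain Python float before logging.
writer.add_scalar("fusion/Train_Loss", loss.item(), 0)
writer.close()
```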
diff --git a/torchensemble/tests/test_set_optimizer.py b/torchensemble/tests/test_set_optimizer.py
index 393c08f..1991ede 100644
--- a/torchensemble/tests/test_set_optimizer.py
+++ b/torchensemble/tests/test_set_optimizer.py
@@ -1,7 +1,6 @@
import pytest
-import torchensemble
import torch.nn as nn
-
+import torchensemble
optimizer_list = [
"Adadelta",
@@ -13,6 +12,7 @@
"RMSprop",
"Rprop",
"SGD",
+ "LBFGS",
]
@@ -33,9 +33,14 @@ def forward(self, X):
@pytest.mark.parametrize("optimizer_name", optimizer_list)
def test_set_optimizer_normal(optimizer_name):
model = MLP()
- torchensemble.utils.set_module.set_optimizer(
- model, optimizer_name, lr=1e-3
- )
+ if optimizer_name != "LBFGS":
+ torchensemble.utils.set_module.set_optimizer(
+ model, optimizer_name, lr=1e-3
+ )
+ else:
+ torchensemble.utils.set_module.set_optimizer(
+ model, optimizer_name, history_size=7, max_iter=10
+ )
def test_set_optimizer_Unknown():
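The test exercises `LBFGS` with its own keyword arguments instead of the shared `lr=1e-3`. As a quick reminder of what those arguments mean on the underlying optimizer (constructed directly here, outside `set_optimizer`):

```python
import torch
import torch.nn as nn

params = nn.Linear(2, 1).parameters()
# history_size bounds the stored curvature pairs; max_iter caps the inner
# iterations performed by each call to optimizer.step(closure).
optimizer = torch.optim.LBFGS(params, history_size=7, max_iter=10)
```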
diff --git a/torchensemble/utils/set_module.py b/torchensemble/utils/set_module.py
index 750ffe8..dcbb312 100644
--- a/torchensemble/utils/set_module.py
+++ b/torchensemble/utils/set_module.py
@@ -18,6 +18,7 @@ def set_optimizer(model, optimizer_name, **kwargs):
"RMSprop",
"Rprop",
"SGD",
+ "LBFGS",
]
if optimizer_name not in torch_optim_optimizers:
msg = "Unrecognized optimizer: {}, should be one of {}."
diff --git a/torchensemble/voting.py b/torchensemble/voting.py
index bdb127a..8296f6c 100644
--- a/torchensemble/voting.py
+++ b/torchensemble/voting.py
@@ -5,19 +5,17 @@
"""
+import warnings
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-import warnings
from joblib import Parallel, delayed
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["VotingClassifier", "VotingRegressor"]
@@ -49,11 +47,20 @@ def _parallel_fit_per_epoch(
data, target = io.split_data_target(elem, device)
batch_size = data[0].size(0)
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = estimator(*data)
+ loss = criterion(output, target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
output = estimator(*data)
- loss = criterion(output, target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -69,7 +76,7 @@ def _parallel_fit_per_epoch(
)
print(
msg.format(
- idx, epoch, batch_idx, loss, correct, batch_size
+ idx, epoch, batch_idx, loss.item(), correct, batch_size
)
)
# Regression
@@ -78,7 +85,7 @@ def _parallel_fit_per_epoch(
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
- print(msg.format(idx, epoch, batch_idx, loss))
+ print(msg.format(idx, epoch, batch_idx, loss.item()))
return estimator, optimizer
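Taken together, these changes let an ensemble be trained end to end with LBFGS. A hedged end-to-end sketch; the data, loader, and `MLP` class are placeholders, and the `fit` arguments follow the library's documented signature as far as I know:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchensemble import VotingClassifier

# Placeholder data and base estimator; any nn.Module classifier works.
X, y = torch.randn(256, 20), torch.randint(0, 2, (256,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(20, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)
        )

    def forward(self, x):
        return self.net(x)

model = VotingClassifier(estimator=MLP, n_estimators=3, cuda=False)
model.set_optimizer("LBFGS", history_size=7, max_iter=10)  # now accepted
model.fit(train_loader, epochs=2, save_model=False)  # save_model assumed optional
```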