From fb8e544409a2b30058414cf7bfba080a3a972c61 Mon Sep 17 00:00:00 2001 From: joneswong Date: Thu, 8 Dec 2022 18:01:08 +0800 Subject: [PATCH 1/4] added fedentsgd --- .../contrib/trainer/local_entropy.py | 143 ++++++++++++++++++ .../contrib/trainer/local_entropy_trainer.py | 11 ++ federatedscope/core/configs/cfg_training.py | 5 + .../fedentsgd_on_cifar10.yaml | 48 ++++++ 4 files changed, 207 insertions(+) create mode 100644 federatedscope/contrib/trainer/local_entropy.py create mode 100644 federatedscope/contrib/trainer/local_entropy_trainer.py create mode 100644 scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml diff --git a/federatedscope/contrib/trainer/local_entropy.py b/federatedscope/contrib/trainer/local_entropy.py new file mode 100644 index 000000000..b61d70790 --- /dev/null +++ b/federatedscope/contrib/trainer/local_entropy.py @@ -0,0 +1,143 @@ +'''The implementation of ASAM and SAM are borrowed from + https://github.com/debcaldarola/fedsam + Caldarola, D., Caputo, B., & Ciccone, M. + Improving Generalization in Federated Learning by Seeking Flat Minima, + European Conference on Computer Vision (ECCV) 2022. +''' +import math +from collections import defaultdict + +import torch + +from federatedscope.core.trainers import BaseTrainer +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer + + +def copy_params(src): + tgt = dict() + #for name, t in src.state_dict().items(): + for name, t in src.named_parameters(): + if t.requires_grad: + tgt[name] = t.detach().clone() + return tgt + + +def prox_term(cur, last): + loss = .0 + for name, tensor in last.items(): + loss += 0.5 * torch.sum((cur[name] - tensor)**2) + return loss + + +def add_noise(model, sigma): + for p in model.parameters(): + if p.requires_grad: + p.data += sigma * torch.randn(size=p.shape, device=p.device) + + +def moving_avg(cur, new, alpha): + for k, v in cur.items(): + v.data = (1-alpha) * v + alpha * new[k] + + +class LocalEntropyTrainer(BaseTrainer): + def __init__(self, model, data, device, **kwargs): + # NN modules + self.model = model + # FS `ClientData` or your own data + self.data = data + # Device name + self.device = device + # configs + self.kwargs = kwargs + self.config = kwargs['config'] + self.optim_config = self.config.train.optimizer + self.local_entropy_config = self.config.trainer.local_entropy + + def train(self): + # Criterion & Optimizer + criterion = torch.nn.CrossEntropyLoss().to(self.device) + optimizer = get_optimizer(self.model, **self.optim_config) + + # _hook_on_fit_start_init + self.model.to(self.device) + current_global_model = copy_params(self.model) + mu = copy_params(self.model) + self.model.train() + + num_samples, total_loss = self.run_epoch(optimizer, criterion, current_global_model, mu) + for name, param in self.model.named_parameters(): + if name in mu: + param.data = mu[name] + + # _hook_on_fit_end + return num_samples, self.model.cpu().state_dict(), \ + {'loss_total': total_loss, 'avg_loss': total_loss/float( + num_samples)} + + def run_epoch(self, optimizer, criterion, current_global_model, mu): + running_loss = 0.0 + num_samples = 0 + # for inputs, targets in self.trainloader: + for inputs, targets in self.data['train']: + inputs = inputs.to(self.device) + targets = targets.to(self.device) + + # Descent Step + outputs = self.model(inputs) + ce_loss = criterion(outputs, targets) + loss = ce_loss + self.local_entropy_config.gamma * prox_term(self.model.state_dict(), current_global_model) + loss.backward() + optimizer.step() + + # add noise for langevine 
dynamics + add_noise(self.model, math.sqrt(self.optim_config.lr) * self.local_entropy_config.eps) + + # acc local updates + moving_avg(mu, self.model.state_dict(), self.local_entropy_config.alpha) + + with torch.no_grad(): + running_loss += targets.shape[0] * ce_loss.item() + + num_samples += targets.shape[0] + + return num_samples, running_loss + + def evaluate(self, target_data_split_name='test'): + if target_data_split_name != 'test': + return {} + + with torch.no_grad(): + criterion = torch.nn.CrossEntropyLoss().to(self.device) + + self.model.to(self.device) + self.model.eval() + total_loss = num_samples = num_corrects = 0 + # _hook_on_batch_start_init + for x, y in self.data[target_data_split_name]: + # _hook_on_batch_forward + x, y = x.to(self.device), y.to(self.device) + pred = self.model(x) + loss = criterion(pred, y) + cor = torch.sum(torch.argmax(pred, dim=-1).eq(y)) + + # _hook_on_batch_end + total_loss += loss.item() * y.shape[0] + num_samples += y.shape[0] + num_corrects += cor.item() + + # _hook_on_fit_end + return { + f'{target_data_split_name}_acc': float(num_corrects) / + float(num_samples), + f'{target_data_split_name}_loss': total_loss, + f'{target_data_split_name}_total': num_samples, + f'{target_data_split_name}_avg_loss': total_loss / + float(num_samples) + } + + def update(self, model_parameters, strict=False): + self.model.load_state_dict(model_parameters, strict) + + def get_model_para(self): + return self.model.cpu().state_dict() diff --git a/federatedscope/contrib/trainer/local_entropy_trainer.py b/federatedscope/contrib/trainer/local_entropy_trainer.py new file mode 100644 index 000000000..0f2cbb3c9 --- /dev/null +++ b/federatedscope/contrib/trainer/local_entropy_trainer.py @@ -0,0 +1,11 @@ +from federatedscope.register import register_trainer +from federatedscope.core.trainers import BaseTrainer + + +def call_local_entropy_trainer(trainer_type): + if trainer_type == 'local_entropy_trainer': + from federatedscope.contrib.trainer.local_entropy import LocalEntropyTrainer + return LocalEntropyTrainer + + +register_trainer('local_entropy_trainer', call_local_entropy_trainer) diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index 7e9394ab6..188c2b485 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -15,6 +15,11 @@ def extend_training_cfg(cfg): cfg.trainer.sam.rho = 1.0 cfg.trainer.sam.eta = .0 + cfg.trainer.local_entropy = CN() + cfg.trainer.local_entropy.gamma = 1e-3 + cfg.trainer.local_entropy.eps = 1e-4 + cfg.trainer.local_entropy.alpha = 0.9 + # ---------------------------------------------------------------------- # # Training related options # ---------------------------------------------------------------------- # diff --git a/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml b/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml new file mode 100644 index 000000000..00e943685 --- /dev/null +++ b/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml @@ -0,0 +1,48 @@ +use_gpu: True +device: 0 +early_stop: + patience: 0 +federate: + mode: standalone + total_round_num: 10000 + client_num: 100 + sample_client_num: 5 + make_global_eval: True + merge_test_data: True +fedopt: + use: True + optimizer: + lr: 0.001 + weight_decay: 0.0 + momentum: 0.0 +data: + root: data/ + type: 'CIFAR10@torchvision' + splits: [1.0,0.0,0.0] + num_workers: 0 + transform: [['RandomCrop', {'size': 32, 'padding': 4}], ['RandomHorizontalFlip'], ['ToTensor'], ['Normalize', 
{'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]}]] + test_transform: [['ToTensor'], ['Normalize', {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.2010]}]] + args: [{'download': True}] + splitter: 'fedsam_cifar10_splitter' + splitter_args: [{'alpha': 0.05}] +dataloader: + batch_size: 64 +model: + type: fedsam_conv2 + out_channels: 10 + dropout: 0.0 +criterion: + type: CrossEntropyLoss +trainer: + type: local_entropy_trainer +train: + batch_or_epoch: 'epoch' + optimizer: + lr: 0.01 + weight_decay: 0.0004 + momentum: 0.0 +eval: + freq: 100 + metrics: ['acc', 'correct'] + best_res_update_round_wise_key: test_loss + count_flops: False From 8a8cc2532a7cc434b531ac1ac2059866e4fb9886 Mon Sep 17 00:00:00 2001 From: joneswong Date: Thu, 8 Dec 2022 19:38:38 +0800 Subject: [PATCH 2/4] eval entropy-sgd on cifar-10 --- .../contrib/trainer/local_entropy.py | 4 +++- federatedscope/core/configs/cfg_training.py | 6 +++--- .../fedentsgd_on_cifar10.yaml | 10 +++++++--- .../fedsam_exp_scripts/hpo_for_fedentsgd.sh | 10 ++++++++++ .../run_fedentsgd_on_cifar10.sh | 18 ++++++++++++++++++ 5 files changed, 41 insertions(+), 7 deletions(-) create mode 100644 scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh create mode 100644 scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh diff --git a/federatedscope/contrib/trainer/local_entropy.py b/federatedscope/contrib/trainer/local_entropy.py index b61d70790..4f6418cc2 100644 --- a/federatedscope/contrib/trainer/local_entropy.py +++ b/federatedscope/contrib/trainer/local_entropy.py @@ -78,6 +78,7 @@ def train(self): def run_epoch(self, optimizer, criterion, current_global_model, mu): running_loss = 0.0 num_samples = 0 + thermal = self.local_entropy_config.gamma # for inputs, targets in self.trainloader: for inputs, targets in self.data['train']: inputs = inputs.to(self.device) @@ -86,7 +87,7 @@ def run_epoch(self, optimizer, criterion, current_global_model, mu): # Descent Step outputs = self.model(inputs) ce_loss = criterion(outputs, targets) - loss = ce_loss + self.local_entropy_config.gamma * prox_term(self.model.state_dict(), current_global_model) + loss = ce_loss + thermal * prox_term(self.model.state_dict(), current_global_model) loss.backward() optimizer.step() @@ -100,6 +101,7 @@ def run_epoch(self, optimizer, criterion, current_global_model, mu): running_loss += targets.shape[0] * ce_loss.item() num_samples += targets.shape[0] + thermal *= 1.001 return num_samples, running_loss diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index 188c2b485..fb27a54b4 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -16,9 +16,9 @@ def extend_training_cfg(cfg): cfg.trainer.sam.eta = .0 cfg.trainer.local_entropy = CN() - cfg.trainer.local_entropy.gamma = 1e-3 - cfg.trainer.local_entropy.eps = 1e-4 - cfg.trainer.local_entropy.alpha = 0.9 + cfg.trainer.local_entropy.gamma = 1e-4 + cfg.trainer.local_entropy.eps = 1e-3 + cfg.trainer.local_entropy.alpha = 0.75 # ---------------------------------------------------------------------- # # Training related options diff --git a/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml b/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml index 00e943685..9570bffd4 100644 --- a/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml +++ b/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml @@ -12,7 +12,7 @@ federate: fedopt: use: True optimizer: - lr: 0.001 + lr: 0.0001 weight_decay: 0.0 momentum: 0.0 
data: @@ -35,11 +35,15 @@ criterion: type: CrossEntropyLoss trainer: type: local_entropy_trainer + local_entropy: + gamma: 0.0001 + eps: 0.001 + alpha: 0.75 train: batch_or_epoch: 'epoch' optimizer: - lr: 0.01 - weight_decay: 0.0004 + lr: 0.1 + weight_decay: 0.0 momentum: 0.0 eval: freq: 100 diff --git a/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh b/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh new file mode 100644 index 000000000..0c8dbbcec --- /dev/null +++ b/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh @@ -0,0 +1,10 @@ +set -e + +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 0 1e-4 1e-4 0.1 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 1 1e-4 1e-4 1.0 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 2 1e-4 1e-3 0.1 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 3 1e-4 1e-3 1.0 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 4 1e-3 1e-4 0.1 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 5 1e-3 1e-4 1.0 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 6 1e-3 1e-3 0.1 >/dev/null 2>/dev/null & +bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 7 1e-3 1e-3 1.0 >/dev/null 2>/dev/null & diff --git a/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh b/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh new file mode 100644 index 000000000..7498a3097 --- /dev/null +++ b/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh @@ -0,0 +1,18 @@ +set -e + +lda_alpha=$1 +cudaid=$2 +gamma=$3 +eps=$4 +lr=$5 + +echo $lda_alpha +echo $cudaid +echo $gamma +echo $eps +echo $lr + +for (( i=0; i<5; i++ )) +do + python federatedscope/main.py --cfg scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml seed $i device $cudaid data.splitter_args "[{'alpha': ${lda_alpha}}]" trainer.local_entropy.gamma $gamma fedopt.optimizer.lr $gamma trainer.local_entropy.eps $eps train.optimizer.lr $lr expname fedentsgd_${lda_alpha}_${gamma}_${eps}_${lr}_${i} +done From 801fa32551be17a95591d73f7a0b459bfb052b89 Mon Sep 17 00:00:00 2001 From: joneswong Date: Thu, 8 Dec 2022 19:44:42 +0800 Subject: [PATCH 3/4] fix intendent --- federatedscope/contrib/trainer/local_entropy.py | 17 +++++++++++------ .../contrib/trainer/local_entropy_trainer.py | 3 ++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/federatedscope/contrib/trainer/local_entropy.py b/federatedscope/contrib/trainer/local_entropy.py index 4f6418cc2..25f88f520 100644 --- a/federatedscope/contrib/trainer/local_entropy.py +++ b/federatedscope/contrib/trainer/local_entropy.py @@ -15,7 +15,6 @@ def copy_params(src): tgt = dict() - #for name, t in src.state_dict().items(): for name, t in src.named_parameters(): if t.requires_grad: tgt[name] = t.detach().clone() @@ -37,7 +36,7 @@ def add_noise(model, sigma): def moving_avg(cur, new, alpha): for k, v in cur.items(): - v.data = (1-alpha) * v + alpha * new[k] + v.data = (1 - alpha) * v + alpha * new[k] class LocalEntropyTrainer(BaseTrainer): @@ -65,7 +64,8 @@ def train(self): mu = copy_params(self.model) self.model.train() - num_samples, total_loss = self.run_epoch(optimizer, criterion, current_global_model, mu) + num_samples, total_loss = self.run_epoch(optimizer, criterion, + current_global_model, mu) for name, param in self.model.named_parameters(): if name in mu: param.data = mu[name] @@ -87,15 +87,20 @@ def 
run_epoch(self, optimizer, criterion, current_global_model, mu): # Descent Step outputs = self.model(inputs) ce_loss = criterion(outputs, targets) - loss = ce_loss + thermal * prox_term(self.model.state_dict(), current_global_model) + loss = ce_loss + thermal * prox_term(self.model.state_dict(), + current_global_model) loss.backward() optimizer.step() # add noise for langevine dynamics - add_noise(self.model, math.sqrt(self.optim_config.lr) * self.local_entropy_config.eps) + add_noise( + self.model, + math.sqrt(self.optim_config.lr) * + self.local_entropy_config.eps) # acc local updates - moving_avg(mu, self.model.state_dict(), self.local_entropy_config.alpha) + moving_avg(mu, self.model.state_dict(), + self.local_entropy_config.alpha) with torch.no_grad(): running_loss += targets.shape[0] * ce_loss.item() diff --git a/federatedscope/contrib/trainer/local_entropy_trainer.py b/federatedscope/contrib/trainer/local_entropy_trainer.py index 0f2cbb3c9..f76181f36 100644 --- a/federatedscope/contrib/trainer/local_entropy_trainer.py +++ b/federatedscope/contrib/trainer/local_entropy_trainer.py @@ -4,7 +4,8 @@ def call_local_entropy_trainer(trainer_type): if trainer_type == 'local_entropy_trainer': - from federatedscope.contrib.trainer.local_entropy import LocalEntropyTrainer + from federatedscope.contrib.trainer.local_entropy \ + import LocalEntropyTrainer return LocalEntropyTrainer From 7bd2e7756136e4cdd4eb3f797474fbe76352ab10 Mon Sep 17 00:00:00 2001 From: joneswong Date: Thu, 8 Dec 2022 20:23:59 +0800 Subject: [PATCH 4/4] removed unrelated copyright --- federatedscope/contrib/trainer/local_entropy.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/federatedscope/contrib/trainer/local_entropy.py b/federatedscope/contrib/trainer/local_entropy.py index 25f88f520..57a75eab8 100644 --- a/federatedscope/contrib/trainer/local_entropy.py +++ b/federatedscope/contrib/trainer/local_entropy.py @@ -1,9 +1,3 @@ -'''The implementation of ASAM and SAM are borrowed from - https://github.com/debcaldarola/fedsam - Caldarola, D., Caputo, B., & Ciccone, M. - Improving Generalization in Federated Learning by Seeking Flat Minima, - European Conference on Computer Vision (ECCV) 2022. -''' import math from collections import defaultdict
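
For readers new to this method: each call to run_epoch above performs an Entropy-SGD ("local entropy") inner loop. Per minibatch it takes an SGD step on the cross-entropy loss plus gamma times a proximal term toward the weights received from the server, injects Gaussian noise scaled by sqrt(lr) * eps (stochastic gradient Langevin dynamics), folds the resulting iterate into an exponential moving average mu with coefficient alpha, and anneals gamma by a factor of 1.001; at the end of local training, mu (not the last noisy iterate) is loaded back into the model and returned. The snippet below is a minimal plain-PyTorch sketch of one such step, not FederatedScope code: the helper name local_entropy_step and its signature are illustrative, an explicit optimizer.zero_grad() is added for clarity, and the proximal term is computed from named_parameters().

import math
import torch


def local_entropy_step(model, batch, criterion, optimizer,
                       global_params, mu, gamma, eps, alpha, lr):
    """One inner-loop step of the local-entropy (Entropy-SGD) update.

    global_params -- dict of tensors received from the server (w_global)
    mu            -- running exponential average of the local iterates
    gamma         -- proximal weight (annealed by the caller, e.g. *= 1.001)
    eps           -- scale of the Langevin noise
    alpha         -- moving-average coefficient
    lr            -- learning rate of `optimizer`, used to scale the noise
    """
    inputs, targets = batch

    # 1) SGD step on data loss + gamma * 0.5 * ||w - w_global||^2.
    optimizer.zero_grad()
    ce_loss = criterion(model(inputs), targets)
    prox = sum(0.5 * torch.sum((p - global_params[n]) ** 2)
               for n, p in model.named_parameters() if n in global_params)
    (ce_loss + gamma * prox).backward()
    optimizer.step()

    with torch.no_grad():
        # 2) Langevin dynamics: add Gaussian noise scaled by sqrt(lr) * eps.
        for p in model.parameters():
            if p.requires_grad:
                p.add_(math.sqrt(lr) * eps * torch.randn_like(p))

        # 3) mu <- (1 - alpha) * mu + alpha * w; mu is what the client
        #    finally keeps and sends back, not the noisy iterate itself.
        params = dict(model.named_parameters())
        for name, avg in mu.items():
            avg.mul_(1.0 - alpha).add_(alpha * params[name])

    return ce_loss.item()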
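
Hyperparameter notes, read off the scripts above: run_fedentsgd_on_cifar10.sh expects five positional arguments — the LDA splitter alpha, the CUDA device id, gamma, eps, and the client learning rate — so a single trial from the grid in hpo_for_fedentsgd.sh looks like

bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 0 1e-4 1e-4 0.1

The run script also overrides fedopt.optimizer.lr with gamma. In Entropy-SGD the outer update moves the weights along gamma * (w - mu), so tying the server learning rate to gamma appears intentional: the server step size is then swept jointly with the proximal weight rather than tuned separately.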