From 771de0fa9fa38a1264a34c74f4f65a10c17825ed Mon Sep 17 00:00:00 2001 From: futabato <01futabato10@gmail.com> Date: Fri, 1 Nov 2024 11:39:22 +0900 Subject: [PATCH] :bug: Fix a bug about batch_size. --- src/federatedlearning/client/training.py | 4 ++-- src/federatedlearning/server/inferencing.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/federatedlearning/client/training.py b/src/federatedlearning/client/training.py index 2762d90..80c1614 100644 --- a/src/federatedlearning/client/training.py +++ b/src/federatedlearning/client/training.py @@ -58,12 +58,12 @@ def train_val_test( ) validloader: DataLoader = DataLoader( DatasetSplit(dataset, idxs_val), - batch_size=int(len(idxs_val) / 10), + batch_size=self.cfg.train.local_batch_size, shuffle=False, ) testloader: DataLoader = DataLoader( DatasetSplit(dataset, idxs_test), - batch_size=int(len(idxs_test) / 10), + batch_size=self.cfg.train.local_batch_size, shuffle=False, ) return trainloader, validloader, testloader diff --git a/src/federatedlearning/server/inferencing.py b/src/federatedlearning/server/inferencing.py index 8ca2ae3..3258c9e 100644 --- a/src/federatedlearning/server/inferencing.py +++ b/src/federatedlearning/server/inferencing.py @@ -39,7 +39,7 @@ def inference( criterion = nn.NLLLoss().to(device) # Create DataLoader for the testing set - testloader = DataLoader(test_dataset, batch_size=128, shuffle=False) + testloader = DataLoader(test_dataset, batch_size=cfg.train.local_batch_size, shuffle=False) # Loop through the dataset using DataLoader images: torch.Tensor