FedSL: Federated Split Learning for Collaborative Healthcare Analytics on Resource-Constrained Wearable IoMT Devices
Note: All open-source code and simulation data in this repository accompany the following paper:
Title: FedSL: Federated Split Learning for Collaborative Healthcare Analytics on Resource-Constrained Wearable IoMT Devices
Authors: Wanli Ni, Huiqing Ao, Hui Tian, Yonina C. Eldar, and Dusit Niyato
In Figure 1, we show the basic setting of split learning (SL). In this paradigm, a neural network is split at a cut layer into a device-side model and a server-side model.
Figure 1. An illustration of the basic idea of SL using a single device and a server.
We denote the forward propagation functions at the device and server as
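After forward propagation, the server backpropagates the gradients from the output layer of the server-side model down to the cut layer and sends the cut-layer gradients to the device, which continues backpropagation to the input layer of the device-side model. Both sides then update their model weights.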
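To make this exchange concrete, here is a minimal NumPy sketch of split learning between one device and the server. It is not the repository's code: the toy data, the one-layer device-side model (linear + ReLU), the one-layer server-side model (linear + sigmoid with binary cross-entropy), and the helper names `device_forward`, `server_step`, and `device_backward` are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data held on the device: 32 samples, 4 features, binary labels.
X = rng.normal(size=(32, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(float).reshape(-1, 1)

# Device-side sub-network (layers up to the cut layer): linear + ReLU.
W_dev = rng.normal(scale=0.5, size=(4, 8))
# Server-side sub-network (layers after the cut layer): linear + sigmoid.
W_srv = rng.normal(scale=0.5, size=(8, 1))


def device_forward(W, X):
    """Device computes the smashed data (cut-layer activations)."""
    pre = X @ W
    return pre, np.maximum(pre, 0.0)


def server_step(W, smashed, y, lr):
    """Server finishes the forward pass, computes the loss, updates its
    weights, and returns the gradient w.r.t. the smashed data."""
    p = 1.0 / (1.0 + np.exp(-(smashed @ W)))
    eps = 1e-12
    loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    d_logits = (p - y) / len(y)        # dL/dlogits for sigmoid + BCE
    grad_W = smashed.T @ d_logits      # gradient of the server-side weights
    grad_smashed = d_logits @ W.T      # cut-layer gradient (uses pre-update W)
    return W - lr * grad_W, grad_smashed, loss


def device_backward(W, X, pre, grad_smashed, lr):
    """Device continues backpropagation from the cut layer to its input layer."""
    grad_pre = grad_smashed * (pre > 0)  # ReLU derivative
    return W - lr * (X.T @ grad_pre)


losses = []
for step in range(200):
    pre, smashed = device_forward(W_dev, X)                          # device -> server: smashed data
    W_srv, grad_smashed, loss = server_step(W_srv, smashed, y, 0.1)  # server -> device: gradients
    W_dev = device_backward(W_dev, X, pre, grad_smashed, 0.1)
    losses.append(loss)

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Only the smashed data, labels, and cut-layer gradients cross the device/server boundary in this sketch; the raw inputs `X` never leave `device_forward` and `device_backward`.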
In this work, we propose a federated split learning (FedSL) framework to achieve distributed machine learning on multiple resource-constrained devices, as shown in Figure 2. On the device side, each client trains only a small, shallow sub-network. The edge server, in contrast, not only undertakes the computation of the high-level sub-network but also aggregates and synchronizes the low-level sub-networks of all devices.
Figure 2. An illustration of the proposed FedSL framework.
In Figure 3, we illustrate the computation and communication processes of the proposed FedSL framework. Specifically, each training round of FedSL has two main stages: 1) multi-user split learning, and 2) federated sub-network averaging. The data exchanged between the devices and the server differ between the two stages. In the first stage, devices send the ground-truth labels and the smashed data (the outputs of the cut layer) to the server to complete forward propagation, and the server returns the gradients to the devices to complete backward propagation. In the second stage, devices upload their sub-network parameters to the server, which aggregates them to synchronize the model parameters.
Figure 3. One training round of the proposed FedSL framework.
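The second stage, federated sub-network averaging, can be sketched as follows. This is a minimal sketch assuming FedAvg-style weights proportional to each device's number of local samples; the README does not state the weighting, and the function name `average_device_subnetworks` and the dict-of-arrays parameter layout are hypothetical.

```python
import numpy as np


def average_device_subnetworks(device_params, num_samples):
    """Sample-weighted average of device-side sub-network parameters.

    device_params: one dict per device, mapping a layer name to a NumPy
    array; all devices share the same device-side architecture.
    num_samples: number of local training samples on each device.
    """
    total = float(sum(num_samples))
    averaged = {}
    for name in device_params[0]:
        averaged[name] = sum(
            (n / total) * params[name]
            for params, n in zip(device_params, num_samples)
        )
    return averaged


# Example: three devices, each with a one-layer device-side sub-network.
devices = [
    {"conv1": np.full((2, 2), 1.0)},
    {"conv1": np.full((2, 2), 2.0)},
    {"conv1": np.full((2, 2), 4.0)},
]
global_dev = average_device_subnetworks(devices, num_samples=[100, 100, 200])
print(global_dev["conv1"])  # every entry is 0.25*1 + 0.25*2 + 0.5*4 = 2.75
```

After averaging, the server would broadcast `global_dev` back to every device so that all devices start the next round from the same device-side sub-network.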
As outlined in Fig. 4, FedSL allows multiple IoMT devices to train a shared model in parallel. The edge server first sends the device-side model to all devices. Each device runs forward propagation on its local data with this model and sends the resulting smashed data and labels to the server. The server completes forward propagation, updates the server-side model, and sends the cut-layer gradients back to each device so that it can update its local device-side model. The updated device-side models are then aggregated by the server. This process repeats until the model reaches the desired accuracy.
Figure 4. The parallel model training scheme in the FedSL framework.
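A toy version of one parallel round is sketched below in NumPy. Everything here is an illustrative assumption rather than the repository's implementation: the two-layer model (a linear + ReLU device-side sub-network and a linear + sigmoid server-side sub-network), the synthetic data, one local step per round, and the choice to update a single server-side model with the mean of the per-device gradients.

```python
import numpy as np

rng = np.random.default_rng(1)
K, D_IN, H = 5, 4, 8  # 5 IoMT devices, as in the simulation settings

# Hypothetical toy data: each device holds 20 local samples with binary labels.
data = []
for _ in range(K):
    X = rng.normal(size=(20, D_IN))
    y = (X[:, 0] - X[:, 2] > 0).astype(float).reshape(-1, 1)
    data.append((X, y))

W_dev = rng.normal(scale=0.5, size=(D_IN, H))  # device-side sub-network
W_srv = rng.normal(scale=0.5, size=(H, 1))     # server-side sub-network


def bce(p, y):
    eps = 1e-12
    return float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))


def parallel_round(W_dev, W_srv, data, lr=0.1):
    """One parallel FedSL round (sketch): every device starts from the same
    broadcast device-side model, the server processes all devices against the
    same server-side weights, then the device-side copies are averaged."""
    new_dev, srv_grads, losses = [], [], []
    for X, y in data:                          # conceptually concurrent
        pre = X @ W_dev
        smashed = np.maximum(pre, 0.0)         # device -> server
        p = 1.0 / (1.0 + np.exp(-(smashed @ W_srv)))
        losses.append(bce(p, y))
        d_logits = (p - y) / len(y)
        srv_grads.append(smashed.T @ d_logits)
        grad_smashed = d_logits @ W_srv.T      # server -> device
        grad_dev = X.T @ (grad_smashed * (pre > 0))
        new_dev.append(W_dev - lr * grad_dev)  # local device-side update
    W_srv = W_srv - lr * np.mean(srv_grads, axis=0)
    W_dev = np.mean(new_dev, axis=0)           # federated sub-network averaging
    return W_dev, W_srv, float(np.mean(losses))


history = []
for _ in range(300):
    W_dev, W_srv, loss = parallel_round(W_dev, W_srv, data)
    history.append(loss)
print(f"mean loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

Because all devices start from the same weights and hold equally sized datasets, a plain mean of the updated device-side models is used here; with unequal dataset sizes, a sample-weighted average would be the natural choice.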
As shown in Fig. 5, in sequential FedSL training, the devices take turns training with the edge server, one device at a time. The procedure for each device is similar to that of parallel FedSL, but there is no model aggregation after each update. Devices that train later build on the knowledge gained by earlier participants. One round is complete once the last device finishes training. This sequential approach offers flexible training execution and allows devices to join the training process dynamically.
Figure 5. The sequential model training scheme in the FedSL framework.
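The sequential scheme can be sketched with the same kind of toy model. This is again a minimal NumPy illustration under assumed settings (synthetic IID data, five local split-learning steps per device, a linear + ReLU device-side model and a linear + sigmoid server-side model), not the repository's `SSL_*.py` scripts; the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
K, D_IN, H = 5, 4, 8  # 5 IoMT devices, as in the simulation settings

# Hypothetical toy data: each device holds 20 local samples with binary labels.
data = []
for _ in range(K):
    X = rng.normal(size=(20, D_IN))
    y = (X[:, 1] + X[:, 3] > 0).astype(float).reshape(-1, 1)
    data.append((X, y))

W_dev = rng.normal(scale=0.5, size=(D_IN, H))  # device-side sub-network
W_srv = rng.normal(scale=0.5, size=(H, 1))     # server-side sub-network


def forward(W_dev, W_srv, X):
    pre = X @ W_dev
    smashed = np.maximum(pre, 0.0)             # cut-layer output
    p = 1.0 / (1.0 + np.exp(-(smashed @ W_srv)))
    return pre, smashed, p


def bce(p, y):
    eps = 1e-12
    return float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))


def split_step(W_dev, W_srv, X, y, lr=0.1):
    """One split-learning update between a single device and the server."""
    pre, smashed, p = forward(W_dev, W_srv, X)
    d_logits = (p - y) / len(y)
    grad_srv = smashed.T @ d_logits             # server-side gradient
    grad_smashed = d_logits @ W_srv.T           # sent back to the device
    grad_dev = X.T @ (grad_smashed * (pre > 0)) # device-side gradient
    return W_dev - lr * grad_dev, W_srv - lr * grad_srv


def sequential_round(W_dev, W_srv, data, local_steps=5):
    """Devices train one after another; each starts from the device-side
    model left by its predecessor, and no averaging is performed."""
    for X, y in data:
        for _ in range(local_steps):
            W_dev, W_srv = split_step(W_dev, W_srv, X, y)
        # W_dev is handed over to the next device here
    return W_dev, W_srv


def global_loss(W_dev, W_srv, data):
    X_all = np.vstack([Xk for Xk, _ in data])
    y_all = np.vstack([yk for _, yk in data])
    return bce(forward(W_dev, W_srv, X_all)[2], y_all)


before = global_loss(W_dev, W_srv, data)
for _ in range(50):
    W_dev, W_srv = sequential_round(W_dev, W_srv, data)
after = global_loss(W_dev, W_srv, data)
print(f"loss over all devices: {before:.3f} -> {after:.3f}")
```

The only state carried between devices is the pair `(W_dev, W_srv)`, which is why a device can join or drop out between turns without any synchronization step.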
In Table 1, we compare the per-client and the total communication overhead across all clients in IoT networks. Assume
Scheme | Training approach | Communication overhead per client | Total communication overhead |
---|---|---|---|
SL | Sequential | | |
FL | Parallel | | |
FedSL | Parallel or sequential | | |
Table 1. Comparison of communication overhead among SL, FL, and FedSL.
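To reason about a comparison like Table 1, one can count the values exchanged per training round under simple, explicitly assumed conditions: K clients hold p samples in total (p/K each), the cut layer outputs q values per sample, the full model has W parameters, and a fraction eta of them sits on the device. The sketch below (function name hypothetical) assumes that SL uploads smashed data, downloads gradients, and receives the device-side model once from the previous client; that FL downloads and uploads the full model; and that FedSL adds uploading and downloading the device-side model for averaging. These expressions are illustrative assumptions, not necessarily those used in Table 1.

```python
def comm_overhead(scheme, p, K, q, W, eta):
    """Return (per-client, total) number of values exchanged in one round.

    p: total samples across all clients; K: number of clients;
    q: size of the smashed data (cut-layer output) per sample;
    W: parameters in the full model; eta: fraction of W on the device.
    All counting rules below are simplifying assumptions.
    """
    split_traffic = 2 * (p / K) * q        # upload smashed data + download gradients
    if scheme == "SL":                     # device-side model handed over once
        per_client = split_traffic + eta * W
    elif scheme == "FL":                   # download + upload the full model
        per_client = 2 * W
    elif scheme == "FedSL":                # split traffic + up/download device-side model
        per_client = split_traffic + 2 * eta * W
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return per_client, K * per_client


# Hypothetical numbers, for illustration only.
for scheme in ("SL", "FL", "FedSL"):
    per_client, total = comm_overhead(scheme, p=1000, K=5, q=500, W=1_000_000, eta=0.1)
    print(f"{scheme:6s} per client: {per_client:>12,.0f}  total: {total:>12,.0f}")
```

Under these assumptions, the split-based schemes trade model traffic for per-sample traffic: they win when the device holds few samples or the smashed data are small relative to the model.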
In Table 1 and Table 2, we present a comprehensive comparison among SL, FL, and the proposed FedSL. These two tables highlight the key features and advantages of each approach.
Scheme | # of users | Model aggregation | Applicable to low-end devices | Distributed computing | Sharing raw data |
---|---|---|---|---|---|
SL | One | No | Yes | Yes | No |
FL | Multiple | Yes | No | Yes | No |
FedSL | Multiple | Yes | Yes | Yes | No |
Table 2. A comprehensive comparison among SL, FL, and FedSL.
Taking ResNet-18 as an example, if we split the model at the third layer, the device needs around 28.85M floating-point operations (FLOPs) for local computation in one training round. Since most wearables can handle up to 1.2 GFLOPs per second, this amounts to roughly 24 ms of local computation per round at peak throughput. Therefore, our proposed FedSL can leverage the computing power of the edge server to effectively support medical image analysis on resource-constrained IoMT devices.
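The arithmetic behind this estimate can be checked directly. The sketch below assumes the wearable sustains its quoted peak rate for the whole workload, which real devices rarely do, so the result is a lower bound on the local compute time; the function name is hypothetical.

```python
def local_compute_time(flops_per_round, device_flops_per_s):
    """Lower-bound compute time (seconds), assuming peak throughput."""
    return flops_per_round / device_flops_per_s


t = local_compute_time(28.85e6, 1.2e9)  # device-side FLOPs with the cut at layer 3
print(f"{t * 1e3:.1f} ms")  # 24.0 ms
```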
Parameters | Description | Value |
---|---|---|
Neural network | ResNet-18 (number of output channels or neurons per layer) | [64 64 64 64 64 128 128 128 128 256 256 256 256 512 512 512 512 2 or 4] |
Dataset 1 | Chest X-Ray | Two categories (PNEUMONIA: 4273, NORMAL: 1583) |
Dataset 2 | Optical Coherence Tomography (OCT) | Four categories (NORMAL: 51390, CNV: 37455, DME: 11598, DRUSEN: 8866) |
Distribution 1 | Independent and identically distributed (IID) | Data distributions between devices are the same |
Distribution 2 | Non-IID | Data distributions between devices are different |
Number of IoMT devices | | 5 |
Index of the cut layer | | 3 |
Maximum number of training rounds | | 400 |
Learning rate | | 0.0001 |
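The IID and non-IID distributions in the table can be produced in many ways. The sketch below shows one common recipe: IID by shuffling and splitting evenly, and non-IID by sorting samples by label and dealing out label-sorted shards (the shard-based split popularized by early federated learning work). It is an assumption for illustration and may differ from what `chest_xray_non_iid_data_processing.py` and `OCT_non_iid_data_processing.py` actually do; the function names and the label encoding (1 = PNEUMONIA, 0 = NORMAL) are hypothetical.

```python
import numpy as np


def partition_iid(labels, num_devices, seed=0):
    """Shuffle all sample indices and split them evenly, so every device
    sees approximately the same label distribution."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(labels))
    return np.array_split(idx, num_devices)


def partition_non_iid(labels, num_devices, shards_per_device=2, seed=0):
    """Sort indices by label, cut them into shards, and give each device a
    few random shards, so each device is skewed toward a subset of classes."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")
    shards = np.array_split(order, num_devices * shards_per_device)
    shard_ids = rng.permutation(len(shards))
    parts = []
    for d in range(num_devices):
        mine = shard_ids[d * shards_per_device:(d + 1) * shards_per_device]
        parts.append(np.concatenate([shards[s] for s in mine]))
    return parts


# Example with the Chest X-Ray class counts (4273 PNEUMONIA, 1583 NORMAL) and 5 devices.
labels = np.array([1] * 4273 + [0] * 1583)
iid = partition_iid(labels, num_devices=5)
non_iid = partition_non_iid(labels, num_devices=5)
for name, parts in [("IID", iid), ("non-IID", non_iid)]:
    print(name, [round(float(labels[p].mean()), 2) for p in parts])
```

Each printed list gives the fraction of PNEUMONIA samples per device: the IID split stays close to the global ratio, while the non-IID split leaves some devices with only one class.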
No. | Name | Description |
---|---|---|
Scheme 1 | Centralized learning (CL) | All data samples are sent to the server for model training. |
Scheme 2 | Federated learning (FL) | Devices train their own models without sharing their datasets. |
Scheme 3 | Sequential FedSL | In each time slot, only a single device trains the model with the server using split learning. |
Scheme 4 | Parallel FedSL | All devices train their models simultaneously, and the server computes the different models in parallel. |
Performance | Chest X-Ray images with the IID setting | Chest X-Ray images with the non-IID setting |
---|---|---|
Accuracy | | |
Loss | | |
Figure 6. Learning performance on the chest X-ray dataset with IID and non-IID settings.
Performance | OCT images with the IID setting | OCT images with the non-IID setting |
---|---|---|
Accuracy | | |
Loss | | |
Figure 7. Learning performance on the OCT dataset with IID and non-IID settings.
code/
├── Chest_XRay/
│   ├── pre_process_chest_xray_dataset.py
│   ├── IID/
│   │   ├── chest_xray_iid_data_processing.py
│   │   ├── CL_chest_xray_iid.py
│   │   ├── FL_chest_xray_iid.py
│   │   ├── SSL_chest_xray_iid.py
│   │   └── FSL_chest_xray_iid.py
│   └── Non-IID/
│       ├── chest_xray_non_iid_data_processing.py
│       ├── CL_chest_xray_non_iid.py
│       ├── FL_chest_xray_non_iid.py
│       ├── SSL_chest_xray_non_iid.py
│       └── FSL_chest_xray_non_iid.py
└── OCT/
    ├── pre_process_OCT_dataset.py
    ├── IID/
    │   ├── OCT_iid_data_processing.py
    │   ├── CL_OCT_iid.py
    │   ├── FL_OCT_iid.py
    │   ├── SSL_OCT_iid.py
    │   └── FSL_OCT_iid.py
    └── Non-IID/
        ├── OCT_non_iid_data_processing.py
        ├── CL_OCT_non_iid.py
        ├── FL_OCT_non_iid.py
        ├── SSL_OCT_non_iid.py
        └── FSL_OCT_non_iid.py
simulation_results/
├── performance_on_the_chest_xray_dataset/
│   ├── draw_chest_xray_accuracy/
│   │   ├── draw_chest_xray_accuracy_iid.m
│   │   └── draw_chest_xray_accuracy_non_iid.m
│   └── draw_chest_xray_loss/
│       ├── draw_chest_xray_loss_iid.m
│       └── draw_chest_xray_loss_non_iid.m
└── performance_on_the_OCT_dataset/
    ├── draw_OCT_accuracy/
    │   ├── draw_OCT_accuracy_iid.m
    │   └── draw_OCT_accuracy_non_iid.m
    └── draw_OCT_loss/
        ├── draw_OCT_loss_iid.m
        └── draw_OCT_loss_non_iid.m