From e532d88bc4924e136810825a095253aad1a0416f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 21 Nov 2024 19:02:16 +0530 Subject: [PATCH 1/7] cleanup fl aggregator --- main.py | 499 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 316 insertions(+), 183 deletions(-) diff --git a/main.py b/main.py index d35cdbe..bd82edd 100644 --- a/main.py +++ b/main.py @@ -51,14 +51,61 @@ def get_client_proj_state(project_folder: Path) -> dict: project_state = {} project_state_file = project_folder / "state/state.json/" - if project_state_file.is_file(): - project_state = json.load(project_state_file.open()) + project_state = read_json(project_state_file) return project_state -def init_fl_aggregator_app(client: Client) -> None: +def read_json(json_file: Path) -> dict: + """ + Reads the json file and returns the contents + """ + with open(json_file, "r") as f: + data = json.load(f) + + return data + + +def write_json(json_file: Path, data: dict) -> None: + """ + Writes the data to the json file + """ + with open(json_file, "w") as f: + json.dump(data, f, indent=4) + + +def validate_launch_config(fl_config: Path) -> bool: + """ + Validates the `fl_config.json` file + """ + + try: + fl_config = read_json(fl_config) + except json.JSONDecodeError: + raise ValueError(f"Invalid JSON format in {fl_config.resolve()}") + + required_keys = [ + "project_name", + "aggregator", + "participants", + "model_arch", + "model_weight", + "model_class_name", + "rounds", + "epoch", + "test_dataset", + "learning_rate", + ] + + for key in required_keys: + if key not in fl_config: + raise ValueError(f"Required key {key} is missing in fl_config.json") + + return True + + +def init_aggregator(client: Client) -> None: """ Creates the `fl_aggregator` app in the `api_data` folder with the following structure: @@ -82,9 +129,35 @@ def init_fl_aggregator_app(client: Client) -> None: app_pvt_dir.mkdir(parents=True, exist_ok=True) -def initialize_fl_project(client: Client, fl_config_json_path: Path) -> None: +def create_metrics_dashboard( + client: Client, fl_config: dict, participants: list, proj_name: str +) -> None: + """Create the metrics dashboard for the project""" + + # Copy the metrics and dashboard files to the project's public folder + metrics_folder = client.my_datasite / "public/fl/{proj_name}/" + metrics_folder.mkdir(parents=True, exist_ok=True) + shutil.copy("./dashboard/index.html", metrics_folder) + shutil.copy("./dashboard/syftbox-sdk.js", metrics_folder) + shutil.copy("./dashboard/index.js", metrics_folder) + + # Create a new participants.json file in the metrics folder + participant_metrics_file = metrics_folder / "participants.json" + create_participant_json_file( + participants, fl_config["rounds"], output_path=participant_metrics_file + ) + + # Copy the accuracy_metrics.json file to the project's metrics folder + shutil.copy("./dashboard/accuracy_metrics.json", metrics_folder) + + print( + f"Dashboard created for the project: {proj_name} at {metrics_folder.resolve()}" + ) + + +def init_project_directory(client: Client, fl_config_json_path: Path) -> None: """ - Initializes the FL project by reading the `fl_config.json` file + Initializes the FL project from the `fl_config.json` file If the project with same name already exists in the `running` folder then it skips creating the project @@ -105,8 +178,9 @@ def initialize_fl_project(client: Client, fl_config_json_path: Path) -> None: └── state.json └── done """ - with open(fl_config_json_path, "r") as f: - fl_config: dict = json.load(f) + + # Read the fl_config.json file + fl_config = read_json(fl_config_json_path) proj_name = str(fl_config["project_name"]) participants = fl_config["participants"] @@ -115,55 +189,41 @@ def initialize_fl_project(client: Client, fl_config_json_path: Path) -> None: running_folder = fl_aggregator / "running" proj_folder = running_folder / proj_name + # If the project already exists, then skip creating the project if proj_folder.is_dir(): - print(f"FL project {proj_name} already exists") + print(f"FL project {proj_name} already exists at: {proj_folder.resolve()}") return - else: - print(f"Creating new FL project {proj_name}") - proj_folder.mkdir(parents=True, exist_ok=True) - fl_clients_folder = proj_folder / "fl_clients" - agg_weights_folder = proj_folder / "agg_weights" - fl_clients_folder.mkdir(parents=True, exist_ok=True) - agg_weights_folder.mkdir(parents=True, exist_ok=True) - - # create the folders for the participants - for participant in participants: - participant_folder = fl_clients_folder / participant - participant_folder.mkdir(parents=True, exist_ok=True) - # TODO: create a custom syft permission for the clients in the `fl_clients` folder - add_public_write_permission(client, participant_folder) - - # Move the config file to the project's running folder - shutil.move(fl_config_json_path, proj_folder) - - # move the model architecture to the project's running folder - model_arch_src = fl_aggregator / "launch" / fl_config["model_arch"] - shutil.move(model_arch_src, proj_folder) - - # copy the global model weights to the project's agg_weights folder as `agg_model_round_0.pt` - # and move the global model weights to the project's running folder - model_weights_src = fl_aggregator / "launch" / fl_config["model_weight"] - shutil.copy(model_weights_src, agg_weights_folder / "agg_model_round_0.pt") - shutil.move(model_weights_src, proj_folder) - - # Copy the metrics dashboard files to the project's public folder - metrics_folder = Path(client.my_datasite) / "public" / "fl" / proj_name - metrics_folder.mkdir(parents=True, exist_ok=True) - shutil.copy("./dashboard/index.html", metrics_folder) - shutil.copy("./dashboard/syftbox-sdk.js", metrics_folder) - shutil.copy("./dashboard/index.js", metrics_folder) - - # Create a new participants.json file in the metrics folder - participant_metrics_file = metrics_folder / "participants.json" - create_participant_json_file( - participants, fl_config["rounds"], output_path=participant_metrics_file - ) - # Copy the accuracy_metrics.json file to the project's metrics folder - shutil.copy("./dashboard/accuracy_metrics.json", metrics_folder) + # Create the project folder + print(f"Creating new FL project {proj_name} at {proj_folder.resolve()}") + proj_folder.mkdir(parents=True, exist_ok=True) + fl_clients_folder = proj_folder / "fl_clients" + agg_weights_folder = proj_folder / "agg_weights" + fl_clients_folder.mkdir(parents=True, exist_ok=True) + agg_weights_folder.mkdir(parents=True, exist_ok=True) + + # create the folders for the participants + for participant in participants: + participant_folder = fl_clients_folder / participant + participant_folder.mkdir(parents=True, exist_ok=True) + + # Give participant write access to the project folder + add_public_write_permission(client, participant_folder) + + # Move the config file to the project's running folder + shutil.move(fl_config_json_path, proj_folder) + + # move the model architecture to the project's running folder + model_arch_src = fl_aggregator / "launch" / fl_config["model_arch"] + shutil.move(model_arch_src, proj_folder) - # TODO: create a state.json file to keep track of the project state - # if needed while running the FL rounds + # copy the global model weights to the project's agg_weights folder as `agg_model_round_0.pt` + # and move the global model weights to the project's running folder + model_weights_src = fl_aggregator / "launch" / fl_config["model_weight"] + shutil.copy(model_weights_src, agg_weights_folder / "agg_model_round_0.pt") + shutil.move(model_weights_src, proj_folder) + + create_metrics_dashboard(client, fl_config, participants, proj_name) def launch_fl_project(client: Client) -> None: @@ -187,14 +247,22 @@ def launch_fl_project(client: Client) -> None: ├── global_model_weights.pt (dragged and dropped by the FL user) ├── mnist_test_dataset.pt """ - launch_folder = client.api_data("fl_aggregator") / "launch" - fl_config_json_path = launch_folder / "fl_config.json" + fl_config_json_path = client.api_data("fl_aggregator/launch/fl_config.json/") + if not fl_config_json_path.is_file(): - print(f"`fl_config.json` not found in the {launch_folder} folder. Skipping...") - return + raise StateNotReady( + f"No launch config found at path: {fl_config_json_path.resolve()}. Skipping current run !!!" + ) - initialize_fl_project(client, fl_config_json_path) + # Validate the fl_config.json file + try: + validate_launch_config(fl_config=fl_config_json_path) + except ValueError as e: + raise StateNotReady("Invalid launch config: " + str(e)) + + # If the config is valid, then create the project + init_project_directory(client, fl_config_json_path) def get_network_participants(client: Client): @@ -222,7 +290,7 @@ def create_fl_client_request(client: Client, proj_folder: Path): """ fl_clients = get_all_directories(proj_folder / "fl_clients") network_participants = get_network_participants(client) - + for fl_client in fl_clients: if fl_client.name not in network_participants: print(f"Client {fl_client.name} is not part of the network") @@ -250,7 +318,6 @@ def check_fl_client_pvt_data_added( ): """Check if the private data is added to the client""" - proj_state = get_client_proj_state(fl_proj_folder) participant_added_data = proj_state.get("dataset_added") @@ -269,8 +336,34 @@ def check_fl_client_pvt_data_added( ) -def check_fl_client_model_training_progress(client: Client, proj_folder: Path): - """Check if model training progress for the client""" +def check_pvt_data_added_by_peer( + peer_name: str, + peer_client_path: Path, + project_name: str, + participant_metrics_file: Path, +): + """Check if the private data is added by the client""" + + fl_proj_folder = peer_client_path / "running" / project_name + proj_state = get_client_proj_state(fl_proj_folder) + + participant_added_data = proj_state.get("dataset_added") + + # Skip if the state file is not present + if participant_added_data is None: + print(f"Private data not added by the client {peer_name}") + return + + update_json( + participant_metrics_file, + peer_name, + ParticipantStateCols.ADDED_PRIVATE_DATA, + participant_added_data, + ) + + +def track_model_train_progress_for_peers(client: Client, proj_folder: Path): + """Track the model training progress for the peer""" fl_clients = get_all_directories(proj_folder / "fl_clients") for fl_client in fl_clients: fl_client_running_folder = client.api_data("fl_client/running", fl_client.name) @@ -291,86 +384,6 @@ def check_fl_client_model_training_progress(client: Client, proj_folder: Path): ) -def check_fl_client_installed(client: Client, proj_folder: Path): - """ - Checks if the client has installed the `fl_client` app - """ - fl_clients = get_all_directories(proj_folder / "fl_clients") - for fl_client in fl_clients: - fl_client_app_path = ( - client.datasites / fl_client.name / "api_data" / "fl_client" - ) - fl_client_request_folder = fl_client_app_path / "request" - fl_client_request_syftperm = fl_client_request_folder / "_.syftperm" - - installed_fl_client_app = True - if not fl_client_request_syftperm.is_file(): - print(f"FL client {fl_client.name} has not installed the app yet") - installed_fl_client_app = False - - participants_metrics_file = get_participants_metric_file(client, proj_folder) - # As they have installed, update the participants.json file with state - update_json( - participants_metrics_file, - fl_client.name, - ParticipantStateCols.FL_CLIENT_INSTALLED, - installed_fl_client_app, - ) - - -def check_proj_requests(client: Client, proj_folder: Path): - """ - Step 1: Checks if the project requests are sent to the clients - Step 2: Checks if all the clients have approved the project - - Note: The clients approve the project when they move from the `request` folder to the `running` folder - - """ - fl_clients = get_all_directories(proj_folder / "fl_clients") - project_unapproved_clients = [] - for fl_client in fl_clients: - fl_client_app_path = ( - client.datasites / fl_client.name / "api_data" / "fl_client" - ) - fl_client_request_folder = fl_client_app_path / "request" / proj_folder.name - fl_client_running_folder = fl_client_app_path / "running" / proj_folder.name - - # If the project is not present in the running folder and the request folder - # create a request folder for the client - if ( - not fl_client_running_folder.is_dir() - and not fl_client_request_folder.is_dir() - ): - print( - f"Request sent to {fl_client.name} for the project {proj_folder.name}" - ) - - if not fl_client_running_folder.is_dir(): - project_unapproved_clients.append(fl_client.name) - else: - # If the project is present in the running folder, update the participants.json file with state - participants_metrics_file = get_participants_metric_file( - client, proj_folder - ) - update_json( - participants_metrics_file, - fl_client.name, - ParticipantStateCols.PROJECT_APPROVED, - True, - ) - - # Check if the private data is added to the client - check_fl_client_pvt_data_added( - fl_client_running_folder, - fl_client.name, - ) - - if project_unapproved_clients: - raise StateNotReady( - f"Project {proj_folder.name} is not approved by the clients {project_unapproved_clients}" - ) - - def load_model_class(model_path: Path, model_class_name: str) -> type: spec = importlib.util.spec_from_file_location(model_path.stem, model_path) model_arch = importlib.util.module_from_spec(spec) @@ -485,15 +498,85 @@ def save_model_accuracy_metrics( # Schema of json files # [ {round: 1, accuracy: 0.98}, {round: 2, accuracy: 0.99} ] # Append the accuracy and round to the json file - with open(metrics_file, "r") as f: - metrics = json.load(f) + metrics = read_json(metrics_file) metrics.append({"round": current_round, "accuracy": accuracy}) - with open(metrics_file, "w") as f: - json.dump(metrics, f) + write_json(metrics_file, metrics) + + +def check_aggregator_added_pvt_data(client: Client, proj_folder: Path): + fl_config = read_json(proj_folder / "fl_config.json") + test_dataset_dir = get_app_private_data(client, "fl_aggregator") + test_dataset_path = test_dataset_dir / fl_config["test_dataset"] + + if not test_dataset_path.exists(): + raise StateNotReady( + f"Test dataset for model evaluation not found, please add the test dataset to: {test_dataset_path.resolve()}" + ) + + +def check_fl_client_app_installed( + peer_name: str, + peer_client_path: Path, + participant_metrics_file: Path, +) -> None: + client_request_folder = peer_client_path / "request" + client_request_syftperm = client_request_folder / "_.syftperm" + installed_fl_client_app = True + if not client_request_syftperm.is_file(): + print(f"FL client {peer_name} has not installed the app yet") + installed_fl_client_app = False -def advance_fl_round(client: Client, proj_folder: Path): + # As they have installed, update the participants.json file with state + update_json( + participant_metrics_file, + peer_name, + ParticipantStateCols.FL_CLIENT_INSTALLED, + installed_fl_client_app, + ) + + +def check_proj_requests_status( + peer_client_path: Path, + peer_name: str, + project_name: str, + participant_metrics_file: Path, +) -> None: + request_folder = peer_client_path / "request" / project_name + running_folder = peer_client_path / "running" / project_name + + if not running_folder.is_dir() and not request_folder.is_dir(): + print(f"Request sent to {peer_name} for the project {project_name}.") + + if running_folder.is_dir(): + update_json( + participant_metrics_file, + peer_name, + ParticipantStateCols.PROJECT_APPROVED, + True, + ) + return True + + return False + + +def share_agg_model_to_peers( + client: Client, + proj_folder: Path, + agg_model_output_path: Path, + participants: list, +): + """Shares the aggregated model to all the participants.""" + for participant in participants: + client_app_path = client.datasites / participant / "api_data" / "fl_client" + client_agg_weights_folder = ( + client_app_path / "running" / proj_folder.name / "agg_weights" + ) + shutil.copy(agg_model_output_path, client_agg_weights_folder) + + +def aggregate_and_evaluate(client: Client, proj_folder: Path): """ 1. Wait for the trained model from the clients 3. Aggregate the trained model and place it in the `agg_weights` folder @@ -503,8 +586,7 @@ def advance_fl_round(client: Client, proj_folder: Path): agg_weights_folder = proj_folder / "agg_weights" current_round = len(list(agg_weights_folder.iterdir())) - with open(proj_folder / "fl_config.json", "r") as f: - fl_config: dict = json.load(f) + fl_config = read_json(proj_folder / "fl_config.json") total_rounds = fl_config["rounds"] if current_round >= total_rounds + 1: @@ -514,15 +596,7 @@ def advance_fl_round(client: Client, proj_folder: Path): participants = fl_config["participants"] - test_dataset_dir = get_app_private_data(client, "fl_aggregator") - test_dataset_path = test_dataset_dir / fl_config["test_dataset"] - - if not test_dataset_path.exists(): - raise StateNotReady( - f"Test dataset not found, please add the test dataset to : {test_dataset_path.resolve()}" - ) - - check_fl_client_model_training_progress(client, proj_folder) + track_model_train_progress_for_peers(client, proj_folder) if current_round == 1: for participant in participants: @@ -569,6 +643,10 @@ def advance_fl_round(client: Client, proj_folder: Path): fl_config, proj_folder, trained_model_paths, current_round ) + # Test dataset for model evaluation + test_dataset_dir = get_app_private_data(client, "fl_aggregator") + test_dataset_path = test_dataset_dir / fl_config["test_dataset"] + # Evaluate the aggregate model model_class = load_model_class( proj_folder / fl_config["model_arch"], fl_config["model_class_name"] @@ -577,15 +655,67 @@ def advance_fl_round(client: Client, proj_folder: Path): model.load_state_dict(torch.load(str(agg_model_output_path), weights_only=True)) accuracy = evaluate_agg_model(model, test_dataset_path) print(f"Accuracy of the aggregated model for round {current_round}: {accuracy}") + + # Save the model accuracy metrics save_model_accuracy_metrics(client, proj_folder, current_round, accuracy) # Send the aggregated model to all the clients - for participant in participants: - client_app_path = client.datasites / participant / "api_data" / "fl_client" - client_agg_weights_folder = ( - client_app_path / "running" / proj_folder.name / "agg_weights" + share_agg_model_to_peers(client, proj_folder, agg_model_output_path, participants) + + +def check_model_aggregation_prerequisites(client: Client, proj_folder: Path) -> None: + """Check if the prerequisites are met before starting model aggregation + + 1. Check if the fl client app is installed for all the peers + 2. Check if the project requests are sent to the peers + 3. Check if all the peers have approved the project + 4. Check if the private data is added by the peers + 5. Check if the test dataset is added by the aggregator + """ + + fl_clients = get_all_directories(proj_folder / "fl_clients") + participant_metrics_file = get_participants_metric_file(client, proj_folder) + peers_with_pending_requests = [] + + for fl_client in fl_clients: + fl_client_app_path = client.datasites / f"{fl_client.name}/api_data/fl_client" + + # Check if the fl client app is installed for given participant + check_fl_client_app_installed( + peer_name=fl_client.name, + peer_client_path=fl_client_app_path, + participant_metrics_file=participant_metrics_file, ) - shutil.copy(agg_model_output_path, client_agg_weights_folder) + + # Check if project request is sent to the client + # and if the client has approved the project + project_approved = check_proj_requests_status( + peer_client_path=fl_client_app_path, + peer_name=fl_client.name, + project_name=proj_folder.name, + participant_metrics_file=participant_metrics_file, + ) + + # If the project is not approved by the client, add it to the list + if not project_approved: + peers_with_pending_requests.append(fl_client.name) + + # Check if the private data is added by the participant + check_pvt_data_added_by_peer( + peer_client_path=fl_client_app_path, + project_name=proj_folder.name, + peer_name=fl_client.name, + participant_metrics_file=participant_metrics_file, + ) + + if peers_with_pending_requests: + raise StateNotReady( + "Project requests are pending for the clients: " + + str(peers_with_pending_requests) + ) + + # Check if the test dataset is added by the aggregator + check_aggregator_added_pvt_data(client, proj_folder) def _advance_fl_project(client: Client, proj_folder: Path) -> None: @@ -600,26 +730,21 @@ def _advance_fl_project(client: Client, proj_folder: Path) -> None: 7. repeat d until all the rounds are complete """ - try: - create_fl_client_request(client, proj_folder) - - check_fl_client_installed(client, proj_folder) + # Create the request folder for the fl clients + create_fl_client_request(client, proj_folder) - check_proj_requests(client, proj_folder) + # Check if the prerequisites are met before starting model aggregation + check_model_aggregation_prerequisites(client, proj_folder) - advance_fl_round(client, proj_folder) - - except StateNotReady as e: - print(e) - return + aggregate_and_evaluate(client, proj_folder) def advance_fl_projects(client: Client) -> None: """ Iterates over the `running` folder and tries to advance the FL projects """ - fl_aggregator_app = client.api_data("fl_aggregator") - running_folder = fl_aggregator_app / "running" + running_folder = client.api_data("fl_aggregator/running") + for proj_folder in running_folder.iterdir(): if proj_folder.is_dir(): proj_name = proj_folder.name @@ -627,17 +752,25 @@ def advance_fl_projects(client: Client) -> None: _advance_fl_project(client, proj_folder) -if __name__ == "__main__": +def start_app(): + """Main function to run the FL Aggregator App""" client = Client.load() # Step 1: Init the FL Aggregator App - init_fl_aggregator_app(client) + init_aggregator(client) - # Step 2: Launch the FL Project - # Iterates over the `launch` folder and creates a new FL project - # if the `fl_config.json` is found in the `launch` folder - launch_fl_project(client) + try: + # Step 2: Launch the FL Project + # Iterates over the `launch` folder and creates a new FL project + # if `fl_config.json` exists in the `launch` folder + launch_fl_project(client) + + # Step 3: Advance the FL Projects. + # Iterates over the running folder and tries to advance the FL project + advance_fl_projects(client) + except StateNotReady as e: + print(e) - # Step 3: Advance the FL Projects. - # Iterates over the running folder and tries to advance the FL project - advance_fl_projects(client) + +if __name__ == "__main__": + start_app() From 04abcd140a6df94b9f668cf90a53ade2d8a0fcb2 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 11:19:15 +0530 Subject: [PATCH 2/7] fix project name folder for metrics tracking --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index bd82edd..f324386 100644 --- a/main.py +++ b/main.py @@ -135,7 +135,7 @@ def create_metrics_dashboard( """Create the metrics dashboard for the project""" # Copy the metrics and dashboard files to the project's public folder - metrics_folder = client.my_datasite / "public/fl/{proj_name}/" + metrics_folder = client.my_datasite / f"public/fl/{proj_name}/" metrics_folder.mkdir(parents=True, exist_ok=True) shutil.copy("./dashboard/index.html", metrics_folder) shutil.copy("./dashboard/syftbox-sdk.js", metrics_folder) From 0857ee85bfc3977358b3238bb3100a99bd60b6e3 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 11:52:31 +0530 Subject: [PATCH 3/7] reuse read and save json from utils --- main.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index f324386..de39b9c 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,13 @@ from torch import nn import torch from torch.utils.data import DataLoader, TensorDataset -from utils import create_participant_json_file, update_json, ParticipantStateCols +from utils import ( + create_participant_json_file, + update_json, + ParticipantStateCols, + read_json, + save_json, +) # TODO: add a syftignore to ignore mnist test dataset from syncing @@ -57,24 +63,6 @@ def get_client_proj_state(project_folder: Path) -> dict: return project_state -def read_json(json_file: Path) -> dict: - """ - Reads the json file and returns the contents - """ - with open(json_file, "r") as f: - data = json.load(f) - - return data - - -def write_json(json_file: Path, data: dict) -> None: - """ - Writes the data to the json file - """ - with open(json_file, "w") as f: - json.dump(data, f, indent=4) - - def validate_launch_config(fl_config: Path) -> bool: """ Validates the `fl_config.json` file @@ -251,9 +239,10 @@ def launch_fl_project(client: Client) -> None: fl_config_json_path = client.api_data("fl_aggregator/launch/fl_config.json/") if not fl_config_json_path.is_file(): - raise StateNotReady( - f"No launch config found at path: {fl_config_json_path.resolve()}. Skipping current run !!!" + print( + f"No launch config found at path: {fl_config_json_path.resolve()}. Skipping !!!" ) + return # Validate the fl_config.json file try: @@ -501,7 +490,7 @@ def save_model_accuracy_metrics( metrics = read_json(metrics_file) metrics.append({"round": current_round, "accuracy": accuracy}) - write_json(metrics_file, metrics) + save_json(metrics, metrics_file) def check_aggregator_added_pvt_data(client: Client, proj_folder: Path): From afdbef13abed1a74b4dd4b3fcec29f5a2753b3ee Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 12:31:03 +0530 Subject: [PATCH 4/7] move few methods from main to utils --- main.py | 108 ++++++++----------------------------------------------- utils.py | 64 +++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 94 deletions(-) diff --git a/main.py b/main.py index de39b9c..6b737d7 100644 --- a/main.py +++ b/main.py @@ -13,10 +13,13 @@ ParticipantStateCols, read_json, save_json, + is_dir_empty, + load_model_class, + get_all_directories, + get_network_participants, + validate_launch_config, ) -# TODO: add a syftignore to ignore mnist test dataset from syncing - # Exception name to indicate the state cannot advance # as there are some pre-requisites that are not met @@ -24,10 +27,6 @@ class StateNotReady(Exception): pass -# TODO: Currently setting the permissions with public write -# change the permission model later to be more secure -# NOTE: we mainly want the aggregator to have write access to -# fl_aggregator/running/fl_project_name/fl_clients/* def add_public_write_permission(client: Client, path: Path) -> None: """ Adds public write permission to the given path @@ -36,13 +35,6 @@ def add_public_write_permission(client: Client, path: Path) -> None: permission.ensure(path) -def get_all_directories(path: Path) -> list: - """ - Returns the list of directories present in the given path - """ - return [x for x in path.iterdir() if x.is_dir()] - - def get_app_private_data(client: Client, app_name: str) -> Path: """ Returns the private data directory of the app @@ -63,34 +55,11 @@ def get_client_proj_state(project_folder: Path) -> dict: return project_state -def validate_launch_config(fl_config: Path) -> bool: +def get_participants_metric_file(client: Client, proj_folder: Path): """ - Validates the `fl_config.json` file + Returns the path to the participant metrics file """ - - try: - fl_config = read_json(fl_config) - except json.JSONDecodeError: - raise ValueError(f"Invalid JSON format in {fl_config.resolve()}") - - required_keys = [ - "project_name", - "aggregator", - "participants", - "model_arch", - "model_weight", - "model_class_name", - "rounds", - "epoch", - "test_dataset", - "learning_rate", - ] - - for key in required_keys: - if key not in fl_config: - raise ValueError(f"Required key {key} is missing in fl_config.json") - - return True + return client.my_datasite / "public" / "fl" / proj_folder.name / "participants.json" def init_aggregator(client: Client) -> None: @@ -177,8 +146,9 @@ def init_project_directory(client: Client, fl_config_json_path: Path) -> None: running_folder = fl_aggregator / "running" proj_folder = running_folder / proj_name - # If the project already exists, then skip creating the project - if proj_folder.is_dir(): + # If the project already exists and is not empty + # then skip creating the project + if proj_folder.is_dir() and not is_dir_empty(proj_folder): print(f"FL project {proj_name} already exists at: {proj_folder.resolve()}") return @@ -254,25 +224,6 @@ def launch_fl_project(client: Client) -> None: init_project_directory(client, fl_config_json_path) -def get_network_participants(client: Client): - exclude_dir = ["apps", ".syft"] - entries = client.datasites.iterdir() - - users = [] - for entry in entries: - if entry.is_dir() and entry not in exclude_dir: - users.append(entry.name) - - return users - - -def get_participants_metric_file(client: Client, proj_folder: Path): - """ - Returns the path to the participant metrics file - """ - return client.my_datasite / "public" / "fl" / proj_folder.name / "participants.json" - - def create_fl_client_request(client: Client, proj_folder: Path): """ Create the request folder for the fl clients @@ -301,30 +252,6 @@ def create_fl_client_request(client: Client, proj_folder: Path): ) -def check_fl_client_pvt_data_added( - fl_proj_folder: Path, - fl_client_name: str, -): - """Check if the private data is added to the client""" - - proj_state = get_client_proj_state(fl_proj_folder) - - participant_added_data = proj_state.get("dataset_added") - - # Skip if the state file is not present - if participant_added_data is None: - print(f"Private data not added to the client {fl_client_name}") - return - - participants_metrics_file = get_participants_metric_file(client, fl_proj_folder) - update_json( - participants_metrics_file, - fl_client_name, - ParticipantStateCols.ADDED_PRIVATE_DATA, - participant_added_data, - ) - - def check_pvt_data_added_by_peer( peer_name: str, peer_client_path: Path, @@ -373,15 +300,6 @@ def track_model_train_progress_for_peers(client: Client, proj_folder: Path): ) -def load_model_class(model_path: Path, model_class_name: str) -> type: - spec = importlib.util.spec_from_file_location(model_path.stem, model_path) - model_arch = importlib.util.module_from_spec(spec) - spec.loader.exec_module(model_arch) - model_class = getattr(model_arch, model_class_name) - - return model_class - - def aggregate_model(fl_config, proj_folder, trained_model_paths, current_round) -> Path: print("Aggregating the trained models") print(f"Trained model paths: {trained_model_paths}") @@ -538,7 +456,9 @@ def check_proj_requests_status( if not running_folder.is_dir() and not request_folder.is_dir(): print(f"Request sent to {peer_name} for the project {project_name}.") - if running_folder.is_dir(): + # Check if project is approved by the client + # If the running folder is not empty, then the project is a valid project + if running_folder.is_dir() and not is_dir_empty(running_folder): update_json( participant_metrics_file, peer_name, diff --git a/utils.py b/utils.py index 2318ec8..f68b938 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,8 @@ import json from pathlib import Path from enum import Enum +import importlib.util +from syftbox.lib import Client class ParticipantStateCols(Enum): @@ -12,6 +14,10 @@ class ParticipantStateCols(Enum): MODEL_TRAINING_PROGRESS = "Training Progress" +def is_dir_empty(directory: Path): + return not any(directory.iterdir()) + + def read_json(data_path: Path): with open(data_path) as fp: data = json.load(fp) @@ -56,3 +62,61 @@ def update_json( participant[column_name.value] = column_val save_json(participant_history, data_path) + + +def load_model_class(model_path: Path, model_class_name: str) -> type: + spec = importlib.util.spec_from_file_location(model_path.stem, model_path) + model_arch = importlib.util.module_from_spec(spec) + spec.loader.exec_module(model_arch) + model_class = getattr(model_arch, model_class_name) + + return model_class + + +def get_all_directories(path: Path) -> list: + """ + Returns the list of directories present in the given path + """ + return [x for x in path.iterdir() if x.is_dir()] + + +def get_network_participants(client: Client): + exclude_dir = ["apps", ".syft"] + entries = client.datasites.iterdir() + + users = [] + for entry in entries: + if entry.is_dir() and entry not in exclude_dir: + users.append(entry.name) + + return users + + +def validate_launch_config(fl_config: Path) -> bool: + """ + Validates the `fl_config.json` file + """ + + try: + fl_config = read_json(fl_config) + except json.JSONDecodeError: + raise ValueError(f"Invalid JSON format in {fl_config.resolve()}") + + required_keys = [ + "project_name", + "aggregator", + "participants", + "model_arch", + "model_weight", + "model_class_name", + "rounds", + "epoch", + "test_dataset", + "learning_rate", + ] + + for key in required_keys: + if key not in fl_config: + raise ValueError(f"Required key {key} is missing in fl_config.json") + + return True From 31fabbc50dffcdd04b2c5070619c26570be98a96 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 12:39:49 +0530 Subject: [PATCH 5/7] isort + docstrings --- main.py | 41 ++++++++++++++++++++++++++--------------- utils.py | 5 +++-- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/main.py b/main.py index 6b737d7..0311586 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,21 @@ -import importlib.util -from syftbox.lib import Client -from syftbox.lib import SyftPermission -from pathlib import Path -import json import shutil -from torch import nn +from pathlib import Path + import torch +from syftbox.lib import Client, SyftPermission +from torch import nn from torch.utils.data import DataLoader, TensorDataset + from utils import ( - create_participant_json_file, - update_json, ParticipantStateCols, - read_json, - save_json, - is_dir_empty, - load_model_class, + create_participant_json_file, get_all_directories, get_network_participants, + is_dir_empty, + load_model_class, + read_json, + save_json, + update_json, validate_launch_config, ) @@ -226,8 +225,11 @@ def launch_fl_project(client: Client) -> None: def create_fl_client_request(client: Client, proj_folder: Path): """ - Create the request folder for the fl clients + Create the request folder for the fl clients. + Creates a request folder for each client in the project's fl_clients folder + and copies the fl_config.json and model_arch.py to the request folder. """ + fl_clients = get_all_directories(proj_folder / "fl_clients") network_participants = get_network_participants(client) @@ -258,7 +260,7 @@ def check_pvt_data_added_by_peer( project_name: str, participant_metrics_file: Path, ): - """Check if the private data is added by the client""" + """Check if the private data is added by the client for model training.""" fl_proj_folder = peer_client_path / "running" / project_name proj_state = get_client_proj_state(fl_proj_folder) @@ -279,7 +281,7 @@ def check_pvt_data_added_by_peer( def track_model_train_progress_for_peers(client: Client, proj_folder: Path): - """Track the model training progress for the peer""" + """Track the model training progress for the peer.""" fl_clients = get_all_directories(proj_folder / "fl_clients") for fl_client in fl_clients: fl_client_running_folder = client.api_data("fl_client/running", fl_client.name) @@ -301,6 +303,7 @@ def track_model_train_progress_for_peers(client: Client, proj_folder: Path): def aggregate_model(fl_config, proj_folder, trained_model_paths, current_round) -> Path: + """Aggregate the trained models from the clients and save the aggregated model""" print("Aggregating the trained models") print(f"Trained model paths: {trained_model_paths}") global_model_class = load_model_class( @@ -359,6 +362,7 @@ def shift_project_to_done_folder( def evaluate_agg_model(agg_model: nn.Module, dataset_path: Path) -> float: + """Evaluate the aggregated model using the test dataset. We use accuracy as the evaluation metric.""" agg_model.eval() # load the saved mnist subset @@ -412,6 +416,10 @@ def save_model_accuracy_metrics( def check_aggregator_added_pvt_data(client: Client, proj_folder: Path): + """Check if the aggregator has added the test dataset for model evaluation. + + Test dataset location: `api_data/fl_aggregator/private/.pt` + """ fl_config = read_json(proj_folder / "fl_config.json") test_dataset_dir = get_app_private_data(client, "fl_aggregator") test_dataset_path = test_dataset_dir / fl_config["test_dataset"] @@ -427,6 +435,8 @@ def check_fl_client_app_installed( peer_client_path: Path, participant_metrics_file: Path, ) -> None: + """Check if the FL client app is installed for the given participant.""" + client_request_folder = peer_client_path / "request" client_request_syftperm = client_request_folder / "_.syftperm" @@ -450,6 +460,7 @@ def check_proj_requests_status( project_name: str, participant_metrics_file: Path, ) -> None: + """Check if the project requests are sent to the clients and if the clients have approved the project.""" request_folder = peer_client_path / "request" / project_name running_folder = peer_client_path / "running" / project_name diff --git a/utils.py b/utils.py index f68b938..4de709c 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,8 @@ +import importlib.util import json -from pathlib import Path from enum import Enum -import importlib.util +from pathlib import Path + from syftbox.lib import Client From 398619c8b3819aaf128d66a11daad6888611d5ed Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 13:19:00 +0530 Subject: [PATCH 6/7] check for empty dirs to invalid project approval --- main.py | 10 +++++++--- utils.py | 6 ++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 0311586..ef1c254 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ create_participant_json_file, get_all_directories, get_network_participants, - is_dir_empty, + has_empty_dirs, load_model_class, read_json, save_json, @@ -99,6 +99,10 @@ def create_metrics_dashboard( # Create a new participants.json file in the metrics folder participant_metrics_file = metrics_folder / "participants.json" + + # Remove the existing participants.json file if it exists + participant_metrics_file.unlink(missing_ok=True) + create_participant_json_file( participants, fl_config["rounds"], output_path=participant_metrics_file ) @@ -147,7 +151,7 @@ def init_project_directory(client: Client, fl_config_json_path: Path) -> None: # If the project already exists and is not empty # then skip creating the project - if proj_folder.is_dir() and not is_dir_empty(proj_folder): + if proj_folder.is_dir() and not has_empty_dirs(proj_folder): print(f"FL project {proj_name} already exists at: {proj_folder.resolve()}") return @@ -469,7 +473,7 @@ def check_proj_requests_status( # Check if project is approved by the client # If the running folder is not empty, then the project is a valid project - if running_folder.is_dir() and not is_dir_empty(running_folder): + if running_folder.is_dir() and not has_empty_dirs(running_folder): update_json( participant_metrics_file, peer_name, diff --git a/utils.py b/utils.py index 4de709c..9432829 100644 --- a/utils.py +++ b/utils.py @@ -15,6 +15,12 @@ class ParticipantStateCols(Enum): MODEL_TRAINING_PROGRESS = "Training Progress" +def has_empty_dirs(directory: Path): + return any( + subdir.is_dir() and is_dir_empty(subdir) for subdir in directory.iterdir() + ) + + def is_dir_empty(directory: Path): return not any(directory.iterdir()) From 612b3ff54c067199cf93cda4b513795bed66beaf Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 22 Nov 2024 19:01:52 +0530 Subject: [PATCH 7/7] fix typing --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index ef1c254..4c83b27 100644 --- a/main.py +++ b/main.py @@ -209,7 +209,7 @@ def launch_fl_project(client: Client) -> None: ├── mnist_test_dataset.pt """ - fl_config_json_path = client.api_data("fl_aggregator/launch/fl_config.json/") + fl_config_json_path = client.api_data("fl_aggregator/launch/fl_config.json") if not fl_config_json_path.is_file(): print( @@ -463,7 +463,7 @@ def check_proj_requests_status( peer_name: str, project_name: str, participant_metrics_file: Path, -) -> None: +) -> bool: """Check if the project requests are sent to the clients and if the clients have approved the project.""" request_folder = peer_client_path / "request" / project_name running_folder = peer_client_path / "running" / project_name