diff --git a/main.py b/main.py index befde56..797f742 100644 --- a/main.py +++ b/main.py @@ -51,6 +51,8 @@ def get_client_proj_state(project_folder: Path) -> dict: project_state = {} project_state_file = project_folder / "state/state.json/" + print("Project state file", project_state_file.resolve()) + if project_state_file.is_file(): project_state = json.load(project_state_file.open()) @@ -240,27 +242,31 @@ def create_fl_client_request(client: Client, proj_folder: Path): ) -def check_fl_client_pvt_data_added(client: Client, proj_folder: Path): +def check_fl_client_pvt_data_added( + fl_proj_folder: Path, + fl_client_name: str, +): """Check if the private data is added to the client""" - fl_clients = get_all_directories(proj_folder / "fl_clients") - for fl_client in fl_clients: - fl_client_running_folder = client.api_data("fl_client/running", fl_client.name) - fl_proj_folder = fl_client_running_folder / proj_folder.name - proj_state = get_client_proj_state(fl_proj_folder) - participant_added_data = proj_state.get("dataset_added") + print("Checking for private data added to the client") - # Skip if the state file is not present - if participant_added_data is None: - return + proj_state = get_client_proj_state(fl_proj_folder) - participants_metrics_file = get_participants_metric_file(client, proj_folder) - update_json( - participants_metrics_file, - fl_client.name, - ParticipantStateCols.ADDED_PRIVATE_DATA, - participant_added_data, - ) + participant_added_data = proj_state.get("dataset_added") + + # Skip if the state file is not present + if participant_added_data is None: + return + + print(f"{fl_client_name} added the private data: {participant_added_data}") + + participants_metrics_file = get_participants_metric_file(client, fl_proj_folder) + update_json( + participants_metrics_file, + fl_client_name, + ParticipantStateCols.ADDED_PRIVATE_DATA, + participant_added_data, + ) def check_fl_client_model_training_progress(client: Client, proj_folder: Path): @@ -356,6 +362,12 @@ def check_proj_requests(client: Client, proj_folder: Path): True, ) + # Check if the private data is added to the client + check_fl_client_pvt_data_added( + fl_client_running_folder, + fl_client.name, + ) + if project_unapproved_clients: raise StateNotReady( f"Project {proj_folder.name} is not approved by the clients {project_unapproved_clients}" @@ -598,8 +610,6 @@ def _advance_fl_project(client: Client, proj_folder: Path) -> None: check_proj_requests(client, proj_folder) - check_fl_client_pvt_data_added(client, proj_folder) - advance_fl_round(client, proj_folder) except StateNotReady as e: