Skip to content

Commit

Permalink
Merge pull request #1 from OpenMined/shubham/track-state
Browse files Browse the repository at this point in the history
Track Proj state from Fl Client
  • Loading branch information
rasswanth-s authored Nov 18, 2024
2 parents ef4b909 + 52975fc commit f3f95a4
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 41 deletions.
16 changes: 0 additions & 16 deletions dashboard/accuracy_metrics.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,5 @@
{
"accuracy": 0,
"round": 0
},
{
"accuracy": 0.2,
"round": 1
},
{
"accuracy": 0.2,
"round": 2
},
{
"accuracy": 0.2,
"round": 3
},
{
"accuracy": 0.4,
"round": 4
}
]
90 changes: 81 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ def get_app_private_data(client: Client, app_name: str) -> Path:
return client.workspace.data_dir / "private" / app_name


def get_client_proj_state(project_folder: Path) -> dict:
    """
    Load the project state stored in `state/state.json` under the project folder.

    Args:
        project_folder: Folder expected to contain `state/state.json`.

    Returns:
        The parsed JSON content of the state file, or an empty dict when the
        state file does not exist.
    """
    project_state_file = project_folder / "state" / "state.json"

    print("Project state file", project_state_file.resolve())

    # A missing state file is not an error: report an empty state instead.
    if not project_state_file.is_file():
        return {}

    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with project_state_file.open() as fp:
        return json.load(fp)


def init_fl_aggregator_app(client: Client) -> None:
"""
Creates the `fl_aggregator` app in the `api_data` folder
Expand Down Expand Up @@ -161,7 +176,7 @@ def launch_fl_project(client: Client) -> None:
Example:
- Manually Copy the `fl_config.json`, `model_arch.py`, `global_model_weights.pt`
- Manually Copy the `fl_config.json`, `model_arch.py`, `global_model_weights.pt`
and `mnist_test_dataset.pt` to the `launch` folder
api_data
└── fl_aggregator
Expand Down Expand Up @@ -222,7 +237,58 @@ def create_fl_client_request(client: Client, proj_folder: Path):
# Copy the fl_config.json, model_arch.py to the request folder
shutil.copy(proj_folder / "fl_config.json", fl_client_request_folder)
shutil.copy(proj_folder / "model_arch.py", fl_client_request_folder)
print(f"Sending request to {fl_client.name} for the project {proj_folder.name}")
print(
f"Sending request to {fl_client.name} for the project {proj_folder.name}"
)


def check_fl_client_pvt_data_added(
    fl_proj_folder: Path,
    fl_client_name: str,
):
    """
    Record in the participants metrics file whether a client added its private data.

    Reads the `dataset_added` entry from the client's project state and, when
    it is present, stores it under the "Added Private Data" column for
    `fl_client_name`. Does nothing when the entry is absent.
    """
    print("Checking for private data added to the client")

    participant_added_data = get_client_proj_state(fl_proj_folder).get(
        "dataset_added"
    )

    # Only update the metrics when the state reports a value for the dataset.
    if participant_added_data is not None:
        print(f"{fl_client_name} added the private data: {participant_added_data}")

        # NOTE(review): `client` is not a parameter of this function, so this
        # relies on a module-level `client` being defined at call time - confirm.
        update_json(
            get_participants_metric_file(client, fl_proj_folder),
            fl_client_name,
            ParticipantStateCols.ADDED_PRIVATE_DATA,
            participant_added_data,
        )


def check_fl_client_model_training_progress(client: Client, proj_folder: Path):
    """
    Record each client's model training progress in the participants metrics file.

    For every directory under `proj_folder / "fl_clients"`, reads the
    `model_train_progress` entry from that client's running project state and
    stores it under the "Training Progress" column for that client.

    Args:
        client: Client used to locate each participant's `fl_client/running`
            api data folder and the participants metrics file.
        proj_folder: Aggregator-side project folder.
    """
    fl_clients = get_all_directories(proj_folder / "fl_clients")
    for fl_client in fl_clients:
        fl_client_running_folder = client.api_data("fl_client/running", fl_client.name)
        fl_proj_folder = fl_client_running_folder / proj_folder.name
        proj_state = get_client_proj_state(fl_proj_folder)
        model_train_progress = proj_state.get("model_train_progress")

        # Skip only this client when it has not reported progress yet; a
        # `return` here would stop checking every remaining client.
        if model_train_progress is None:
            continue

        participants_metrics_file = get_participants_metric_file(client, proj_folder)
        update_json(
            participants_metrics_file,
            fl_client.name,
            ParticipantStateCols.MODEL_TRAINING_PROGRESS,
            model_train_progress,
        )


def check_fl_client_installed(client: Client, proj_folder: Path):
Expand All @@ -236,7 +302,7 @@ def check_fl_client_installed(client: Client, proj_folder: Path):
raise StateNotReady(f"Client {fl_client.name} is not part of the network")

fl_client_app_path = (
client.datasites / fl_client.name / "api_data" / "fl_client"
client.datasites / fl_client.name / "api_data" / "fl_client"
)
fl_client_request_folder = fl_client_app_path / "request"
fl_client_request_syftperm = fl_client_request_folder / "_.syftperm"
Expand Down Expand Up @@ -296,6 +362,12 @@ def check_proj_requests(client: Client, proj_folder: Path):
True,
)

# Check if the private data is added to the client
check_fl_client_pvt_data_added(
fl_client_running_folder,
fl_client.name,
)

if project_unapproved_clients:
raise StateNotReady(
f"Project {proj_folder.name} is not approved by the clients {project_unapproved_clients}"
Expand Down Expand Up @@ -449,15 +521,15 @@ def advance_fl_round(client: Client, proj_folder: Path):
test_dataset_path = test_dataset_dir / fl_config["test_dataset"]

if not test_dataset_path.exists():
StateNotReady(
raise StateNotReady(
f"Test dataset not found, please add the test dataset to : {test_dataset_path.resolve()}"
)

check_fl_client_model_training_progress(client, proj_folder)

if current_round == 1:
for participant in participants:
client_app_path = (
client.datasites / participant / "api_data" / "fl_client"
)
client_app_path = client.datasites / participant / "api_data" / "fl_client"
client_agg_weights_folder = (
client_app_path / "running" / proj_folder.name / "agg_weights"
)
Expand Down Expand Up @@ -533,9 +605,9 @@ def _advance_fl_project(client: Client, proj_folder: Path) -> None:

try:
create_fl_client_request(client, proj_folder)

check_fl_client_installed(client, proj_folder)

check_proj_requests(client, proj_folder)

advance_fl_round(client, proj_folder)
Expand Down
45 changes: 29 additions & 16 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
from pathlib import Path
from enum import Enum


class ParticipantStateCols(Enum):
    """Column names used as keys in each participant's metrics entry."""

    # Each member is defined exactly once: reusing a member name inside an
    # Enum body raises TypeError at class-creation time.
    EMAIL = "Email"
    FL_CLIENT_INSTALLED = "Fl Client Installed"
    PROJECT_APPROVED = "Project Approved"
    ADDED_PRIVATE_DATA = "Added Private Data"
    ROUND = "Round (current/total)"
    MODEL_TRAINING_PROGRESS = "Training Progress"


def read_json(data_path: Path):
    """
    Read and parse the JSON file at `data_path`.

    Returns:
        The decoded JSON content.
    """
    # Parse exactly once: a second json.load on the same handle would hit an
    # exhausted stream and fail.
    with open(data_path) as fp:
        return json.load(fp)


Expand All @@ -20,26 +23,36 @@ def save_json(data: dict, data_path: Path):
json.dump(data, fp, indent=4)


def create_participant_json_file(
    participants: list, total_rounds: int, output_path: Path
):
    """
    Write the initial participants metrics file.

    One entry is created per participant with every state column set to its
    starting value, and the resulting list is saved as JSON to `output_path`.

    Args:
        participants: Participant emails, one entry is written for each.
        total_rounds: Total number of rounds, shown as `0/<total_rounds>`.
        output_path: Destination of the JSON file.
    """
    data = [
        {
            ParticipantStateCols.EMAIL.value: participant,
            ParticipantStateCols.FL_CLIENT_INSTALLED.value: False,
            ParticipantStateCols.PROJECT_APPROVED.value: False,
            ParticipantStateCols.ADDED_PRIVATE_DATA.value: False,
            ParticipantStateCols.ROUND.value: f"0/{total_rounds}",
            ParticipantStateCols.MODEL_TRAINING_PROGRESS.value: "N/A",
        }
        for participant in participants
    ]

    save_json(data=data, data_path=output_path)

def update_json(
    data_path: Path,
    participant_email: str,
    column_name: ParticipantStateCols,
    column_val: object,
):
    """
    Set one column for one participant in the participants metrics file.

    Args:
        data_path: Path of the participants metrics JSON file.
        participant_email: Email identifying the entry to update.
        column_name: Column to set; anything that is not a
            `ParticipantStateCols` member is ignored and the file is left
            untouched.
        column_val: JSON-serialisable value to store (callers pass booleans
            as well as strings).
    """
    # isinstance() rather than `in ParticipantStateCols`: before Python 3.12 the
    # containment check raises TypeError when the operand is not an Enum member.
    if not isinstance(column_name, ParticipantStateCols):
        return

    participant_history = read_json(data_path=data_path)
    for participant in participant_history:
        if participant[ParticipantStateCols.EMAIL.value] == participant_email:
            participant[column_name.value] = column_val

    save_json(participant_history, data_path)

0 comments on commit f3f95a4

Please sign in to comment.