diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index 0b89684df1..6a4fa6da5e 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -85,6 +85,19 @@ def wrapper(self, *args, **kwargs): return wrapper +def _resend_data_on_reconnection(func): + def wrapper(self, *args, **kwargs): + while True: + try: + response = func(self, *args, **kwargs) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNKNOWN: + self.logger.info(f'Attempting to resend data request to aggregator at {self.uri}') + continue + break + return response + + return wrapper class AggregatorGRPCClient: """Client to the aggregator over gRPC-TLS.""" @@ -258,6 +271,7 @@ def reconnect(self): ) @_atomic_connection + @_resend_data_on_reconnection def get_tasks(self, collaborator_name): """Get tasks from the aggregator.""" self._set_header(collaborator_name) @@ -268,10 +282,12 @@ def get_tasks(self, collaborator_name): return response.tasks, response.round_number, response.sleep_time, response.quit @_atomic_connection + @_resend_data_on_reconnection def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, report, tags, require_lossless): """Get aggregated tensor from the aggregator.""" self._set_header(collaborator_name) + request = aggregator_pb2.GetAggregatedTensorRequest( header=self.header, tensor_name=tensor_name, @@ -287,6 +303,7 @@ def get_aggregated_tensor(self, collaborator_name, tensor_name, round_number, return response.tensor @_atomic_connection + @_resend_data_on_reconnection def send_local_task_results(self, collaborator_name, round_number, task_name, data_size, named_tensors): """Send task results to the aggregator.""" diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index cb19a56d77..b443ee0650 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -205,8 +205,11 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802 context: The gRPC context """ - proto = aggregator_pb2.TaskResults() - proto = utils.datastream_to_proto(proto, request) + try: + proto = aggregator_pb2.TaskResults() + proto = utils.datastream_to_proto(proto, request) + except RuntimeError: + raise RuntimeError('Empty stream message, reestablishing connection from client to resume training...') self.validate_collaborator(proto, context) # all messages get sanity checked @@ -223,7 +226,7 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802 return aggregator_pb2.SendLocalTaskResultsResponse( header=self.get_header(collaborator_name) ) - + def get_server(self): """Return gRPC server.""" self.server = server(ThreadPoolExecutor(max_workers=cpu_count()),