diff --git a/.gitignore b/.gitignore
index e13041af29..4398b46be1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ results/*
 **MNIST/
 **cert/
 .history/
+.DS_Store
diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py
index 301818cd83..008962bbe9 100644
--- a/openfl/component/collaborator/collaborator.py
+++ b/openfl/component/collaborator/collaborator.py
@@ -157,38 +157,71 @@ def ping(self):
         """Ping the Aggregator."""
         self.client.ping()
 
-    def run(self):
-        """Run the collaborator."""
-        # Experiment begin
-        self.callbacks.on_experiment_begin()
+    def run(self) -> None:
+        """
+        Run the collaborator main loop.
+
+        Handles experiment lifecycle, round execution, and error logging.
+        """
+        try:
+            self.callbacks.on_experiment_begin()
+            self._execute_collaborator_rounds()
+            self.callbacks.on_experiment_end()
+            logger.info("Received shutdown signal. Exiting...")
+        except Exception as experiment_error:
+            logger.critical(
+                f"Critical error in collaborator execution. Error: {experiment_error}",
+                exc_info=True,
+            )
+            self.callbacks.on_experiment_end({"error": str(experiment_error)})
+            logger.critical("Collaborator is shutting down due to critical error.")
+            raise RuntimeError("Collaborator execution failed") from experiment_error
+
+    def _execute_collaborator_rounds(self) -> None:
+        """
+        Execute rounds until a shutdown signal is received.
 
+        Each round consists of receiving tasks, executing them, and reporting results.
+        If any task fails, the round is aborted and the error is logged.
+        """
         while True:
             tasks, round_num, sleep_time, time_to_quit = self.client.get_tasks()
-
             if time_to_quit:
                 break
-
             if not tasks:
                 sleep(sleep_time)
                 continue
+            try:
+                logger.info("Round: %d Received Tasks: %s", round_num, tasks)
+                self.callbacks.on_round_begin(round_num)
+                logs = self._execute_round_tasks(tasks, round_num)
+                self.tensor_db.clean_up(self.db_store_rounds)
+                self.callbacks.on_round_end(round_num, logs)
+            except Exception as round_error:
+                logger.error(
+                    f"Error during round {round_num} execution. Error: {round_error}", exc_info=True
+                )
+                sleep(sleep_time or 10)
 
-            # Round begin
-            logger.info("Round: %d Received Tasks: %s", round_num, tasks)
-            self.callbacks.on_round_begin(round_num)
+    def _execute_round_tasks(self, tasks: list, round_number: int) -> dict:
+        """
+        Execute all tasks in a round.
 
-            # Run tasks
-            logs = {}
-            for task in tasks:
-                metrics = self.do_task(task, round_num)
-                logs.update(metrics)
+        Args:
+            tasks: List of tasks to execute.
+            round_number: Current round number.
 
-            # Round end
-            self.tensor_db.clean_up(self.db_store_rounds)
-            self.callbacks.on_round_end(round_num, logs)
+        Returns:
+            Dictionary of logs/metrics from task execution.
 
-        # Experiment end
-        self.callbacks.on_experiment_end()
-        logger.info("Received shutdown signal. Exiting...")
+        Raises:
+            Exception: Propagated if any task fails; the caller aborts the round.
+        """
+        logs = {}
+        for task in tasks:
+            metrics = self.do_task(task, round_number)
+            logs.update(metrics)
+        return logs
 
     def do_task(self, task, round_number) -> dict:
         """Perform the specified task.
@@ -270,97 +303,103 @@ def do_task(self, task, round_number) -> dict:
 
         return metrics
 
-    def get_data_for_tensorkey(self, tensor_key):
-        """Resolve the tensor corresponding to the requested tensorkey.
+    def get_data_for_tensorkey(self, tensor_key) -> object:
+        """
+        Resolve and return the tensor for the requested TensorKey.
+
+        This function checks the local cache, previous rounds, and the aggregator as needed.
         Args:
-            tensor_key (namedtuple): Tensorkey that will be resolved locally or
-                remotely. May be the product of other tensors.
+            tensor_key: The TensorKey to resolve.
 
         Returns:
-            nparray: The decompressed tensor associated with the requested
-                tensor key.
+            The decompressed tensor associated with the requested tensor key.
+
+        Raises:
+            Exception: If the tensor cannot be retrieved or reconstructed.
         """
-        # try to get from the store
         tensor_name, origin, round_number, report, tags = tensor_key
         logger.debug("Attempting to retrieve tensor %s from local store", tensor_key)
-        nparray = self.tensor_db.get_tensor_from_cache(tensor_key)
-
-        # if None and origin is our client, request it from the client
-        if nparray is None:
-            if origin == self.collaborator_name:
-                logger.info(
-                    f"Attempting to find locally stored {tensor_name} tensor from prior round..."
-                )
-                prior_round = round_number - 1
-                while prior_round >= 0:
-                    nparray = self.tensor_db.get_tensor_from_cache(
-                        TensorKey(tensor_name, origin, prior_round, report, tags)
-                    )
-                    if nparray is not None:
-                        logger.debug(
-                            f"Found tensor {tensor_name} in local TensorDB for round {prior_round}"
-                        )
-                        return nparray
-                    prior_round -= 1
-                logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
-            # Determine whether there are additional compression related
-            # dependencies.
-            # Typically, dependencies are only relevant to model layers
-            tensor_dependencies = self.tensor_codec.find_dependencies(
-                tensor_key, self.use_delta_updates
-            )
-            logger.debug(
-                "Unable to get tensor from local store..."
-                "attempting to retrieve from client len tensor_dependencies"
-                f" tensor_key {tensor_key}"
-            )
-            if len(tensor_dependencies) > 0:
-                # Resolve dependencies
-                # tensor_dependencies[0] corresponds to the prior version
-                # of the model.
-                # If it exists locally, should pull the remote delta because
-                # this is the least costly path
-                prior_model_layer = self.tensor_db.get_tensor_from_cache(tensor_dependencies[0])
-                if prior_model_layer is not None:
-                    uncompressed_delta = self.get_aggregated_tensor_from_aggregator(
-                        tensor_dependencies[1]
-                    )
-                    new_model_tk, nparray = self.tensor_codec.apply_delta(
-                        tensor_dependencies[1],
-                        uncompressed_delta,
-                        prior_model_layer,
-                        creates_model=True,
-                    )
-                    self.tensor_db.cache_tensor({new_model_tk: nparray})
-                else:
-                    logger.info(
-                        "Could not find previous model layer.Fetching latest layer from aggregator"
-                    )
-                    # The original model tensor should be fetched from aggregator
-                    nparray = self.get_aggregated_tensor_from_aggregator(
-                        tensor_key, require_lossless=True
-                    )
-            elif "model" in tags:
-                # Pulling the model for the first time
-                nparray = self.get_aggregated_tensor_from_aggregator(
-                    tensor_key, require_lossless=True
-                )
-            else:
-                # we should try fetching the tensor from aggregator
-                tensor_name, origin, round_number, report, tags = tensor_key
-                tags = (self.collaborator_name,) + tags
-                tensor_key = (tensor_name, origin, round_number, report, tags)
-                logger.info(
-                    "Could not find previous model layer."
-                    f"Fetching latest layer from aggregator {tensor_key}"
-                )
-                nparray = self.get_aggregated_tensor_from_aggregator(
-                    tensor_key, require_lossless=True
-                )
-        else:
-            logger.debug("Found tensor %s in local TensorDB", tensor_key)
-
+        try:
+            nparray = self.tensor_db.get_tensor_from_cache(tensor_key)
+            if nparray is None:
+                if origin == self.collaborator_name:
+                    logger.info(
+                        f"Attempting to find locally stored {tensor_name} "
+                        f"tensor from prior round..."
+                    )
+                    prior_round = round_number - 1
+                    while prior_round >= 0:
+                        nparray = self.tensor_db.get_tensor_from_cache(
+                            TensorKey(tensor_name, origin, prior_round, report, tags)
+                        )
+                        if nparray is not None:
+                            logger.debug(
+                                f"Found tensor {tensor_name} in local TensorDB "
+                                f"for round {prior_round}"
+                            )
+                            return nparray
+                        prior_round -= 1
+                    logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
+                # Determine whether there are additional compression related
+                # dependencies.
+                # Typically, dependencies are only relevant to model layers
+                tensor_dependencies = self.tensor_codec.find_dependencies(
+                    tensor_key, self.use_delta_updates
+                )
+                logger.debug(
+                    "Unable to get tensor from local store..."
+                    "attempting to retrieve from client len tensor_dependencies"
+                    f" tensor_key {tensor_key}"
+                )
+                if len(tensor_dependencies) > 0:
+                    # Resolve dependencies
+                    # tensor_dependencies[0] corresponds to the prior version
+                    # of the model.
+                    # If it exists locally, should pull the remote delta because
+                    # this is the least costly path
+                    prior_model_layer = self.tensor_db.get_tensor_from_cache(tensor_dependencies[0])
+                    if prior_model_layer is not None:
+                        uncompressed_delta = self.get_aggregated_tensor_from_aggregator(
+                            tensor_dependencies[1]
+                        )
+                        new_model_tk, nparray = self.tensor_codec.apply_delta(
+                            tensor_dependencies[1],
+                            uncompressed_delta,
+                            prior_model_layer,
+                            creates_model=True,
+                        )
+                        self.tensor_db.cache_tensor({new_model_tk: nparray})
+                    else:
+                        logger.info(
+                            "Could not find previous model layer. "
+                            "Fetching latest layer from aggregator"
+                        )
+                        nparray = self.get_aggregated_tensor_from_aggregator(
+                            tensor_key, require_lossless=True
+                        )
+                elif "model" in tags:
+                    nparray = self.get_aggregated_tensor_from_aggregator(
+                        tensor_key, require_lossless=True
+                    )
+                else:
+                    tensor_name, origin, round_number, report, tags = tensor_key
+                    tags = (self.collaborator_name,) + tags
+                    tensor_key = (tensor_name, origin, round_number, report, tags)
+                    logger.info(
+                        "Could not find tensor locally. "
+                        f"Fetching tensor from aggregator: {tensor_key}"
+                    )
+                    nparray = self.get_aggregated_tensor_from_aggregator(
+                        tensor_key, require_lossless=True
+                    )
+            else:
+                logger.debug("Found tensor %s in local TensorDB", tensor_key)
+        except Exception as get_tensor_error:
+            logger.error(
+                f"Error retrieving tensor {tensor_key}. Error: {get_tensor_error}", exc_info=True
+            )
+            raise
         return nparray
 
     def get_aggregated_tensor_from_aggregator(self, tensor_key, require_lossless=False):
diff --git a/openfl/interface/collaborator.py b/openfl/interface/collaborator.py
index fe774a216b..7318152475 100644
--- a/openfl/interface/collaborator.py
+++ b/openfl/interface/collaborator.py
@@ -62,26 +62,34 @@ def collaborator(context):
     help="The certified common name of the collaborator.",
 )
 def start_(plan, collaborator_name, data_config):
-    """Starts a collaborator service."""
+    """
+    Starts a collaborator service.
+
+    Args:
+        plan: Path to the FL plan YAML file.
+        collaborator_name: The certified common name of the collaborator.
+        data_config: Path to the dataset shard configuration file.
+ """ if plan and is_directory_traversal(plan): echo("Federated learning plan path is out of the openfl workspace scope.") sys.exit(1) if data_config and is_directory_traversal(data_config): echo("The data set/shard configuration file path is out of the openfl workspace scope.") sys.exit(1) - - plan_obj = Plan.parse( - plan_config_path=Path(plan).absolute(), - data_config_path=Path(data_config).absolute(), - ) - - # TODO: Need to restructure data loader config file loader - logger.info(f"Data paths: {plan_obj.cols_data_paths}") - echo(f"Data = {plan_obj.cols_data_paths}") - logger.info("🧿 Starting a Collaborator Service.") - - collaborator = plan_obj.get_collaborator(collaborator_name) - collaborator.run() + try: + plan_obj = Plan.parse( + plan_config_path=Path(plan).absolute(), + data_config_path=Path(data_config).absolute(), + ) + logger.info(f"Data paths: {plan_obj.cols_data_paths}") + echo(f"Data = {plan_obj.cols_data_paths}") + logger.info("🧿 Starting a Collaborator Service.") + collaborator = plan_obj.get_collaborator(collaborator_name) + collaborator.run() + except Exception as e: + logger.critical(f"Critical error starting or running collaborator: {e}", exc_info=True) + echo(style(f"Collaborator failed with error: {e}", fg="red")) + sys.exit(1) @collaborator.command(name="ping")