diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_plot.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_plot.png new file mode 100644 index 000000000..634b9d222 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_plot.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_100_0_0_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_100_0_0_stickiness_0.4.png new file mode 100644 index 000000000..24483bde8 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_100_0_0_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20.png new file mode 100644 index 000000000..d7a50ed59 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.4.png new file mode 100644 index 000000000..0937378af Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.9.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.9.png new file mode 100644 index 000000000..f9843ec06 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_50_30_20_stickiness_0.9.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_90_10_0_stickiness_0.4.png 
b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_90_10_0_stickiness_0.4.png new file mode 100644 index 000000000..c15cb4f13 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_two_lines_90_10_0_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/availability_validation.png b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_validation.png new file mode 100644 index 000000000..a14f73248 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/availability_validation.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/corrected_validation.png b/lib/python/examples/fwdllm/expts/run_tc_expts/corrected_validation.png new file mode 100644 index 000000000..95d2d9a2b Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/corrected_validation.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_50_eval_30_unavail_20.png b/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_50_eval_30_unavail_20.png new file mode 100644 index 000000000..6f9d2f61f Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_50_eval_30_unavail_20.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_90_eval_10_unavail_0.png b/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_90_eval_10_unavail_0.png new file mode 100644 index 000000000..7a0b2fe59 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/state_distribution_train_90_eval_10_unavail_0.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition.png new file mode 100644 index 000000000..18db75201 Binary files /dev/null and 
b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_100_0_0_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_100_0_0_stickiness_0.4.png new file mode 100644 index 000000000..da6f0d212 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_100_0_0_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.4.png new file mode 100644 index 000000000..81f496743 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.9.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.9.png new file mode 100644 index 000000000..b2e440989 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_50_30_20_stickiness_0.9.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_90_10_0_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_90_10_0_stickiness_0.4.png new file mode 100644 index 000000000..a492b1fd7 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_stacked_composition_90_10_0_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution.png new file mode 100644 index 000000000..6ffde6aae Binary files /dev/null and 
b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_100_0_0_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_100_0_0_stickiness_0.4.png new file mode 100644 index 000000000..03b82981f Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_100_0_0_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.4.png new file mode 100644 index 000000000..7e5ffcfc0 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.4.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.9.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.9.png new file mode 100644 index 000000000..32e581fc1 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_50_30_20_stickiness_0.9.png differ diff --git a/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_90_10_0_stickiness_0.4.png b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_90_10_0_stickiness_0.4.png new file mode 100644 index 000000000..3b41069d7 Binary files /dev/null and b/lib/python/examples/fwdllm/expts/run_tc_expts/trainer_state_distribution_90_10_0_stickiness_0.4.png differ diff --git a/lib/python/flame/config.py b/lib/python/flame/config.py index 07dbc3702..9dd26fe9b 100644 --- a/lib/python/flame/config.py +++ b/lib/python/flame/config.py @@ -78,6 +78,7 @@ class SelectorType(str, Enum): FEDBUFF = "fedbuff" OORT = "oort" ASYNC_OORT = "async_oort" + ASYNC_RANDOM = "async_random" 
"""AsyncRandomSelector class."""

import logging
import math
import random
import time
from datetime import timedelta
from collections import deque

from flame.config import TrainerAvailState
import numpy as np
from flame.channel import (
    KEY_CH_SELECT_REQUESTER,
    KEY_CH_STATE,
    VAL_CH_STATE_RECV,
    VAL_CH_STATE_SEND,
)
from flame.common.typing import Scalar
from flame.common.util import MLFramework, get_ml_framework_in_use
from flame.end import KEY_END_STATE, VAL_END_STATE_NONE, VAL_END_STATE_RECVD, End
from flame.selector import AbstractSelector, SelectorReturnType

logger = logging.getLogger(__name__)

SEND_TIMEOUT_WAIT_S = 90  # 90 seconds timeout

PROP_END_ID = "end_id"
PROP_STAT_UTILITY = "stat_utility"
PROP_AVL_STATE = "avl_state"


class AsyncRandomSelector(AbstractSelector):
    """An AsyncFL selector that samples trainers uniformly at random.

    The send/recv bookkeeping (selected_ends, all_selected, timeout and
    drop-off tracking) mirrors the FedBuff selector; only the sampling
    policy differs: eligible ends are chosen with uniform random
    sampling instead of a utility score.
    """

    def __init__(self, **kwargs):
        """Initialize instance.

        Required kwargs:
            c: concurrency level (max trainers with outstanding work).
            aggGoal: number of updates to aggregate per round.

        Raises:
            KeyError: if ``c`` or ``aggGoal`` is missing from config.
            NotImplementedError: if the ML framework is not PyTorch.
        """
        super().__init__(**kwargs)

        ml_framework_in_use = get_ml_framework_in_use()
        if ml_framework_in_use != MLFramework.PYTORCH:
            # BUGFIX: error message previously named "FedBalancer".
            raise NotImplementedError(
                "AsyncRandomSelector is currently only implemented in PyTorch;"
            )

        self.round = 0
        self.is_async = True

        try:
            self.c = kwargs["c"]
        except KeyError:
            raise KeyError("c (concurrency level) is not specified in config")

        try:
            self.agg_goal = kwargs["aggGoal"]
        except KeyError:
            raise KeyError("aggGoal is not specified in config")

        # Guard against a nonsensical (negative) aggregation goal.
        if self.agg_goal < 0:
            self.agg_goal = 1

        # #### CHANGES BASED OFF FEDBUFF FOR ASYNCFL
        # Tracking selected ends to ensure selection correctness for
        # each round (a trainer can participate only once per round).
        self.all_selected = dict()
        self.selected_ends = dict()

        # Tracks weight updates received from trainers and makes them
        # available to select again.
        self.ordered_updates_recv_ends = list()

        # Tracks timeouted trainers and number of times it happened to
        # a trainer.
        self.track_trainer_timeouts = dict()

        # Tracks trainers that were selected but left training in
        # between.
        self.track_selected_trainers_which_left = dict()
        self.check_three_state_avl = True

        # Track sliding window statistics for the selector.
        self._selector_stats = {}
        for task in ["train", "eval"]:
            self._selector_stats[task] = {"data": {}, "summary": {}}
            for metric in ["util", "speed", "round"]:
                for window in [50, 100, 200]:
                    key = f"{metric}_last_{window}"
                    self._selector_stats[task]["data"][key] = deque(maxlen=window)

        self._select_run_counter = 0

    def select(
        self,
        ends: dict[str, End],
        channel_props: dict[str, Scalar],
        trainer_unavail_list: list,
        task_to_perform: str = "train",
        **kwargs,
    ) -> SelectorReturnType:
        """Return k number of ends from the given ends.

        NOTE: It incorporates the same send/recv mechanism from
        fedbuff. [From fedbuff selector]: Select ends from the given
        ends to meet concurrency level. This select method chooses
        ends differently depending on what state a channel is in. In
        'send' state, it chooses ends that are not in
        self.selected_ends. In 'recv' state, it chooses all ends from
        self.selected_ends. Essentially, if an end is in
        self.selected_ends, it means that we sent some message already
        to that end. For such an end, we exclude it from send and
        include it for recv in return.
        """
        logger.info("calling async random select")
        # Extract aggregator version and trainer version states for
        # staleness tracking.
        agg_version_state = kwargs.get("agg_version_state")
        trainer_version_states = kwargs.get("trainer_version_states")
        logger.debug(
            f"Aggregator version state (model_version, data_id, iteration_id): {agg_version_state}"
        )
        logger.debug(f"Trainer version states: {trainer_version_states}")

        # BUGFIX: concurrency was previously assigned only when
        # task_to_perform == "train", raising NameError for any other
        # task. It does not depend on the task, so compute it always.
        concurrency = min(len(ends), self.c)

        logger.info(
            f"Task: {task_to_perform}, len(ends): {len(ends)}, c: {self.c}, chosen concurrency: {concurrency}"
        )

        if concurrency == 0:
            logger.debug("ends is empty")
            return {}

        if KEY_CH_STATE not in channel_props:
            raise KeyError(f"channel property doesn't have {KEY_CH_STATE}")

        self.requester = channel_props[KEY_CH_SELECT_REQUESTER]
        if self.requester not in self.selected_ends:
            self.selected_ends[self.requester] = set()

        # Default, availability-unaware way of using ends.
        eligible_ends = ends

        # Make a filter of unavailable ends, update eligible_ends
        # given trainer_unavail_list.
        if trainer_unavail_list != [] and trainer_unavail_list is not None:
            # Updating passed ends and filtering out unavailable ones
            # before passing.
            eligible_ends = {
                end_id: end
                for end_id, end in ends.items()
                if end_id not in trainer_unavail_list
            }
            logger.debug(
                f"Async random select got non-empty trainer_unavail_list, "
                f"populated eligible_ends: {eligible_ends}"
            )

        if channel_props[KEY_CH_STATE] == VAL_CH_STATE_SEND:
            logger.debug(
                f"Inside send state: aggregator version state (model_version, data_id, iteration_id): {agg_version_state}"
            )
            logger.debug(
                f"Inside send state: trainer version states: {trainer_version_states}"
            )
            results = self._handle_send_state(
                ends=eligible_ends,
                concurrency=concurrency,
                channel_props=channel_props,
                trainer_unavail_list=trainer_unavail_list,
                task_to_perform=task_to_perform,
                agg_version_state=agg_version_state,
                trainer_version_states=trainer_version_states,
            )

            # BUGFIX: was `len(results) is not 0` — identity comparison
            # against an int literal (SyntaxWarning; not guaranteed to
            # behave like equality).
            if len(results) != 0:
                self._select_run_counter += 1

        elif channel_props[KEY_CH_STATE] == VAL_CH_STATE_RECV:
            results = self._handle_recv_state(ends, concurrency)

        else:
            state = channel_props[KEY_CH_STATE]
            raise ValueError(f"unknown channel state: {state}")

        logger.debug(
            f"requester: {self.requester}, selected ends: {self.selected_ends}"
        )
        logger.debug(
            f"channel state: {channel_props[KEY_CH_STATE]}, results: {results}"
        )

        return results

    def select_random(self, ends: dict[str, End], num_of_ends: int) -> dict[str, None]:
        """Randomly select num_of_ends ends."""
        selected_random_ends = set(random.sample(list(ends), num_of_ends))
        logger.debug(f"selected_random_ends: {selected_random_ends}")

        return {key: None for key in selected_random_ends}

    def _cleanup_provided_ends(
        self, ends_to_cleanup: dict[str, End], ends: dict[str, End]
    ):
        """Clean up specific ends so they become eligible for sampling
        again - reject stale updates in FwdLLM (async)."""
        selected_ends = self.selected_ends.get(self.requester, set())
        for end_id, _ in ends_to_cleanup.items():
            state = ends[end_id].get_property(KEY_END_STATE)
            logger.info(f"Cleaning end {end_id}, current state: {state}")

            # Reset only if it's in received state.
            if state == VAL_END_STATE_RECVD:
                ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_NONE)
                logger.debug(
                    f"Setting {end_id} state to {VAL_END_STATE_NONE}, "
                    f"and"
                    f" removing from selected_ends and all_selected"
                )

                # Remove from active selection tracking.
                if end_id in selected_ends:
                    selected_ends.remove(end_id)
                    logger.debug(f"Removed {end_id} from selected_ends")

                if end_id in self.all_selected:
                    del self.all_selected[end_id]
                    logger.debug(f"Removed {end_id} from all_selected")

        # Update the mapping back.
        self.selected_ends[self.requester] = selected_ends
        # BUGFIX: log previously dumped the whole `ends` dict instead of
        # the ends that were actually cleaned up.
        logger.info(
            f"Cleanup complete. Freed [{list(ends_to_cleanup.keys())}] end(s) for resampling; state set to {VAL_END_STATE_NONE}."
        )

    # #### CHANGES BASED OFF FEDBUFF FOR ASYNCFL
    def _cleanup_recvd_ends(self, ends: dict[str, End]):
        """Clean up ends whose a message was received, from selected
        ends.

        Note: It sets the end state to none which makes it eligible to
        be sampled again. This can cause problems if sampled in the
        same round. Thus, for aggregator, the _cleanup_recvd_ends
        should be triggered only after aggregation of weights succeeds
        on meeting agg_goal."""
        logger.debug(
            f"clean up recvd ends. selected_ends: {self.selected_ends}, ends: {ends.keys()}"
        )

        selected_ends = self.selected_ends[self.requester]
        logger.debug(
            f"self.requester: {self.requester} and selected_ends: "
            f"{selected_ends} before processing"
        )

        num_ends_to_remove = min(len(self.ordered_updates_recv_ends), self.agg_goal)
        if num_ends_to_remove != 0:
            ends_to_remove = self.ordered_updates_recv_ends[:num_ends_to_remove]
            logger.debug(
                f"Will remove these ends from "
                f"ordered_updates_recv_ends: {ends_to_remove}"
                f" and selected_ends and all_selected"
            )

            # Removing the first agg-goal number of ends to free them
            # to participate in the next round.
            self.ordered_updates_recv_ends = self.ordered_updates_recv_ends[
                num_ends_to_remove:
            ]
            logger.debug(
                f"self.ordered_updates_recv_ends after removing first "
                f"num_ends_to_remove: {num_ends_to_remove} "
                f"elements: {self.ordered_updates_recv_ends}"
            )

            logger.debug(
                f"Ends to remove based on trainer updates received: {ends_to_remove}"
            )

            logger.debug(f"All ends to remove (train ): {ends_to_remove}")

            for end_id in ends_to_remove:
                if end_id not in ends:
                    # Something happened to end of end_id (e.g.,
                    # connection loss): let's remove it from
                    # selected_ends.
                    logger.debug(
                        f"no end id {end_id} in ends, removing "
                        f"from selected_ends and all_selected"
                    )
                    # NOTE: it is not a guarantee that selected_ends
                    # will still contain the end_id. Thats because it
                    # might have got disconnected/ rejoined in the
                    # middle of a round.
                    if end_id in selected_ends:
                        selected_ends.remove(end_id)
                        logger.debug(
                            f"No end id {end_id} in ends, removed from "
                            f"selected_ends: "
                            f"{selected_ends}"
                        )
                    if end_id in self.all_selected:
                        del self.all_selected[end_id]
                        logger.debug(
                            f"No end id {end_id} in ends, removed from "
                            f"self.all_selected: {self.all_selected}"
                        )
                else:
                    state = ends[end_id].get_property(KEY_END_STATE)
                    logger.debug(
                        f"End_id {end_id} found in selected_ends in state: {state}, "
                        f"selected_ends: {selected_ends} and self.all_selected: "
                        f"{self.all_selected}"
                    )
                    if state == VAL_END_STATE_RECVD:
                        ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_NONE)
                        logger.debug(
                            f"Setting {end_id} state to {VAL_END_STATE_NONE}, "
                            f"and"
                            f" removing from selected_ends and all_selected"
                        )
                        if end_id in selected_ends:
                            selected_ends.remove(end_id)
                            logger.debug(
                                f"FOUND end id {end_id} in state: {state}.. "
                                f"removed from "
                                f"selected_ends: {selected_ends}"
                            )
                        if end_id in self.all_selected:
                            del self.all_selected[end_id]
                            logger.debug(
                                f"FOUND end id {end_id} in state: {state}.. "
                                f"removed from "
                                f"self.all_selected: "
                                f"{self.all_selected}"
                            )
                    elif state == VAL_END_STATE_NONE:
                        logger.debug(
                            f"Found end {end_id} in state {VAL_END_STATE_NONE}. Might have "
                            f"left/rejoined. Need to remove it from "
                            f"selected_ends and self.all_selected "
                            f"if it was selected"
                        )
                        if end_id in selected_ends:
                            selected_ends.remove(end_id)
                            logger.debug(
                                f"FOUND end id {end_id} in state: {state}.. "
                                f"removed from "
                                f"selected_ends: {selected_ends}"
                            )
                        if end_id in self.all_selected:
                            del self.all_selected[end_id]
                            logger.debug(
                                f"FOUND end id {end_id} in state: {state}.. "
                                f"removed from "
                                f"self.all_selected: "
                                f"{self.all_selected} too"
                            )
                    else:
                        logger.debug(
                            f"FOUND end id {end_id} in state: {state}. "
                            f"Not doing anything"
                        )
        else:
            logger.debug("No ends to remove so far")

    def _cleanup_removed_ends(self, end_id):
        """Clean up selector state for an end that left the channel."""
        logger.debug(
            f"Going to cleanup selector state for "
            f"end_id {end_id} since it has left the channel"
        )
        if (end_id in self.all_selected) and (
            end_id not in self.ordered_updates_recv_ends
        ):
            # Remove end from all_selected if we havent got an update
            # from it yet. It would have flushed the agg-weights after
            # initiating channel.leave().
            logger.debug(
                f"Removing end_id {end_id} from all_selected"
                f" since no update received before it left the channel."
            )
            selected_ends = self.selected_ends[self.requester]
            if end_id in selected_ends:
                selected_ends.remove(end_id)
                logger.debug(f"Also removing end_id {end_id} from selected_ends")
                self.selected_ends[self.requester] = selected_ends

            # Track trainers that were sent weights but dropped off
            # before sending back an update.
            if end_id in self.track_selected_trainers_which_left:
                self.track_selected_trainers_which_left[end_id] += 1
            else:
                self.track_selected_trainers_which_left[end_id] = 1

            total_trainers_dropped_off = 0
            for k, v in self.track_selected_trainers_which_left.items():
                total_trainers_dropped_off += v

            logger.debug(
                f"Trainer: {end_id} with count "
                f"{self.track_selected_trainers_which_left[end_id]}, left "
                f"before returning update. "
                f"total_trainers_dropped_off: {total_trainers_dropped_off} "
                f"self.track_selected_trainers_which_left: "
                f"{self.track_selected_trainers_which_left}"
            )
            if end_id in self.all_selected.keys():
                del self.all_selected[end_id]
        elif (end_id in self.all_selected) and (
            end_id in self.ordered_updates_recv_ends
        ):
            # Dont remove it if it was in all_selected and we have got
            # an update from it before it did channel.leave(). It has
            # completed its participation for this round.
            logger.debug(
                f"Update was already received from {end_id} before it left "
                f"the channel. Not deleting from all_ends now."
            )
        else:
            logger.warning(
                f"End_id {end_id} remove check from all_selected failed. "
                f"Need to check"
            )

    def _handle_send_state(
        self,
        ends: dict[str, End],
        concurrency: int,
        channel_props: dict[str, Scalar],
        trainer_unavail_list: list = None,
        task_to_perform: str = "train",
        agg_version_state=None,  # (model_version, data_id, iteration_id)
        trainer_version_states: dict[str, tuple[int, int, int]] = None,
    ) -> SelectorReturnType:
        """Pick up to `concurrency` new ends to send weights to."""
        selected_ends = self.selected_ends[self.requester]
        logger.debug(
            f"Inside handle send state: aggregator version state {agg_version_state}"
        )
        logger.debug(
            f"Inside handle send state: trainer version states {trainer_version_states}"
        )
        # Check for invalid selections and remove them.
        for end_id in list(selected_ends):
            if end_id not in ends:
                # Something happened to end of end_id (e.g.,
                # connection loss): let's remove it from selected_ends
                # so that you can fill that spot with another trainer.
                logger.info(
                    f"Removing invalid prior selection! "
                    f"No end id {end_id} in ends, "
                    f"removing from selected_ends. "
                    f"NOT from all_selected right now "
                    f"cause aggregation for that "
                    f"round hasnt completed yet"
                )
                selected_ends.remove(end_id)
                # NOTE: Not removing end_id from all_selected since it
                # might have already participated in the same round
                # (if it is still in all_ends).

        logger.debug(f"Current selected_ends: {selected_ends}")

        extra = max(0, concurrency - len(selected_ends))

        logger.debug(
            f"c: {concurrency}, "
            f"len(selected_ends): {len(selected_ends)}, extra: {extra}, selected_ends: {selected_ends},"
            f"len(ends): {len(ends)}"
        )

        if extra == 0:
            logger.debug(f"extra: {extra}, nothing to select")
            return {}

        round = channel_props["round"] if "round" in channel_props else 0
        logger.debug(f"let's select {extra} ends for round {round}")

        # Invalidate previous all_selected entry if you don't get an
        # update in SEND_TIMEOUT_WAIT_S. The client might have
        # dropped the message with transient unavailability.
        curr_all_selected_ends = list(self.all_selected.keys())
        for end in curr_all_selected_ends:
            current_time_s = time.time()
            if end in self.all_selected.keys():
                # Check again to avoid possible case of race condition
                # when all_selected has been updated from another
                # thread.
                trainer_weight_send_timestamp_s = self.all_selected[end]
                if (
                    trainer_weight_send_timestamp_s
                    < (current_time_s - SEND_TIMEOUT_WAIT_S)
                ) and (end not in self.ordered_updates_recv_ends):
                    # Trainer hasn't returned with an update in
                    # SEND_TIMEOUT_WAIT_S: delete it from
                    # self.all_selected so that it is eligible to be
                    # sampled again.
                    logger.info(
                        f"Removing end {end} from self.all_selected "
                        f"since havent "
                        f"got its update in {SEND_TIMEOUT_WAIT_S}. "
                        f"Last weight send timestamp was: {trainer_weight_send_timestamp_s}"
                    )

                    if end in self.track_trainer_timeouts:
                        self.track_trainer_timeouts[end] += 1
                    else:
                        self.track_trainer_timeouts[end] = 1

                    # Capture total time spent in timeouts.
                    num_of_timeouts_occured = 0
                    for k, v in self.track_trainer_timeouts.items():
                        num_of_timeouts_occured += v

                    total_time_spent_timeouts_s = (
                        num_of_timeouts_occured * SEND_TIMEOUT_WAIT_S
                    )

                    logger.debug(
                        f"Timeout for trainer: {end} with count "
                        f"{self.track_trainer_timeouts[end]}. "
                        f"num_of_timeouts_occured : "
                        f"{num_of_timeouts_occured}, "
                        f"total_time_spent_timeouts_s: "
                        f"{total_time_spent_timeouts_s}, "
                        f"Timeout frequency: {self.track_trainer_timeouts}"
                    )

                    # Delete the end from self.all_selected.
                    if end in self.all_selected.keys():
                        del self.all_selected[end]

        # filtered_ends consists of ends that are not in all_selected
        # and can be picked in this round, i.e. avoids repeating a
        # trainer in the same round.
        filtered_ends = dict()

        # Track the ends that are eligible vs ineligible based on
        # their state.
        count_avl_train = 0
        count_ineligible = 0

        # Check the eligible set first. Out of the ends, how many are
        # not in all_selected? Only those are eligible since the rest
        # have weights already sent to them for either train/eval
        # task.
        count_eligible_set_to_check = [
            end for end in ends if end not in self.all_selected
        ]
        logger.debug(
            f"Before creating filtered_ends. count_eligible_set_to_check: {len(count_eligible_set_to_check)} from total {len(ends)} ends."
        )

        for end_id in ends:
            if end_id not in self.all_selected.keys():
                logger.debug(
                    f"Creating filtered ends. Checking end id {end_id}, avl_state = {ends[end_id].get_property(PROP_AVL_STATE)}"
                )

                # For task_to_perform=train, eligible ends are in
                # states {avl_train, None}. Even if client notify is
                # not enabled, this logic works since the state
                # defaults to None.
                curr_end_id_avl_state = ends[end_id].get_property(PROP_AVL_STATE)
                if task_to_perform == "train" and (
                    curr_end_id_avl_state in (TrainerAvailState.AVL_TRAIN.value, None)
                ):
                    filtered_ends[end_id] = ends[end_id]
                    logger.debug(
                        f"Adding end {end_id} to filtered ends. Three_state_avl_check is True, task_to_perform: {task_to_perform} in state: {ends[end_id].get_property(PROP_AVL_STATE)}"
                    )
                    count_avl_train += 1

                else:
                    logger.debug(
                        f"Not adding end {end_id} to filtered ends since required for task{task_to_perform}, "
                        f"but was in state {curr_end_id_avl_state}. Not eligible."
                    )
                    count_ineligible += 1

        logger.info(
            f"Filtered ends created. count_avl_train: {count_avl_train}, count_ineligible: {count_ineligible}"
        )

        if agg_version_state is not None and trainer_version_states is not None:
            curr_model_version, curr_data_id, curr_iteration_id = agg_version_state
            logger.info(f"Trainer version states: {trainer_version_states}")
            logger.info(
                f"Handle send state: aggregator version state {agg_version_state}"
            )
            # Filter out trainers who already received this same triplet.
            eligible_filtered_ends = {}
            logger.debug(f"Filtered ends: {filtered_ends.items()}")
            for end_id, end in filtered_ends.items():
                prev_state = trainer_version_states.get(end_id)
                logger.debug(f"Prev version state: {prev_state}")

                if prev_state != agg_version_state:
                    logger.debug(f"Not skipping trainer: {end_id}")
                    eligible_filtered_ends[end_id] = end
                else:
                    logger.info(
                        f"Skipping trainer: {end_id} already has same "
                        f"(model_version={curr_model_version}, "
                        f"iteration_id={curr_iteration_id}, data_id={curr_data_id})"
                    )
            filtered_ends = eligible_filtered_ends

        # extra informs about maximum possible available ends that can
        # be picked to meet the concurrency target, but it might count
        # infeasible ends too (ends that have already participated in
        # the round). Maximum feasible comes from filtered_ends, so
        # feasible_extra = min(extra, len(filtered_ends)).
        feasible_extra = min(extra, len(filtered_ends))
        logger.info(
            f"desired extra: {extra}, len(filtered_ends): {len(filtered_ends)}, feasible_extra: {feasible_extra}"
        )

        # Early exit if filtered_ends is empty (can happen when all
        # ends available are less than concurrency requirement).
        if len(filtered_ends) == 0:
            logger.info(
                f"len(filtered_ends): {len(filtered_ends)}, hence returning "
                f"with empty candidates"
            )
            return {}

        # BUGFIX: candidates_dict was previously assigned only inside
        # the task_to_perform == "train" branch but returned
        # unconditionally, raising NameError for other tasks.
        candidates_dict = {}

        # This is only for train as of now, can be extended for eval
        # in the future.
        if task_to_perform == "train":
            if trainer_unavail_list != []:
                logger.info(
                    "### Async random select got non-empty trainer_unavail_list, will "
                    "remove unavail trainers from round"
                )

            self.round = round

            logger.info(
                f"Round: {self.round}, will sample feasible_extra: "
                f"{feasible_extra} from len(filtered_ends): "
                f"{len(filtered_ends)}"
            )
            candidates_dict = self.select_random(
                filtered_ends, num_of_ends=feasible_extra
            )
            # Invoke process_chosen_candidate_dict(). It will
            # appropriately add candidates to selected_ends and
            # all_selected.
            self.process_chosen_candidate_dict(
                candidates_dict=candidates_dict, selected_ends=selected_ends
            )

        logger.info(
            f"handle_send_state returning " f"candidates_dict: {candidates_dict}"
        )

        return candidates_dict

    def _handle_recv_state(
        self, ends: dict[str, End], concurrency: int
    ) -> SelectorReturnType:
        """Return the set of ends the aggregator should wait to hear from."""
        selected_ends = self.selected_ends[self.requester]

        # From the selected ends, remove those that are in recv state
        # already. This is done to avoid waiting on trainers that you
        # have already heard from. If selected ends is empty, get()
        # will proceed and wait on distribute_weights before running
        # again. Thus, it avoids stalling and ensures progress.
        for end_id in list(selected_ends):
            # Trainer might have become unavailable, check if it is
            # still available first.
            if end_id in ends:
                curr_end_state = ends[end_id].get_property(KEY_END_STATE)
                if curr_end_state == VAL_END_STATE_RECVD:
                    selected_ends.remove(end_id)
                    logger.debug(
                        f"Removed end_id {end_id} from selected ends since it "
                        f"was already in {curr_end_state} state"
                    )
            else:
                logger.debug(
                    f"Tried to check state of end {end_id} but it is no "
                    f"longer in self._ends"
                )

        if len(selected_ends) == 0:
            logger.debug(f"len(selected_ends)=0, let's select {concurrency} ends")

            candidates = dict()
            for end_id, end in ends.items():
                curr_end_state = end.get_property(KEY_END_STATE)
                if end_id not in self.all_selected.keys():
                    if curr_end_state != VAL_END_STATE_NONE:
                        logger.info(
                            f"end_id {end_id} not in all_selected and in state: {curr_end_state}, adding "
                            f"to candidates: key {end_id}, val: {end}"
                        )
                        candidates[end_id] = end
                    else:
                        logger.debug(
                            f"end_id {end_id} not in all_selected but in state: {curr_end_state}, not adding "
                            f"to candidates"
                        )

            cc = min(len(candidates), concurrency)
            logger.debug(
                f"Will pick cc: {cc} as min(candidates,concurrency) "
                f"from candidates: {candidates}"
            )
            selected_ends = set(random.sample(list(candidates), cc))

            self.selected_ends[self.requester] = selected_ends
            logger.debug(
                f"self.selected_ends[req]: {self.selected_ends[self.requester]}"
            )

            for selected_end in selected_ends:
                # Add to all_selected. {key: end, val: TS epoch (s)}
                self.all_selected[selected_end] = time.time()
                logger.debug(
                    f"self.all_selected {self.all_selected} after combining with "
                    f"selected_ends {selected_ends}"
                )

        logger.debug(f"handle_recv_state returning selected_ends: {selected_ends}")

        return {key: None for key in selected_ends}

    def reset_end_state_to_none(self, ends: dict[str, End], end_id: str) -> None:
        """Reset the state of end_id from send/recv to none."""
        if end_id in ends.keys():
            curr_end_state = ends[end_id].get_property(KEY_END_STATE)
            ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_NONE)
            new_end_state = ends[end_id].get_property(KEY_END_STATE)
            logger.debug(
                f"Successfully reset state for end "
                f"{end_id} from previous: {curr_end_state} to "
                f"current: {new_end_state}"
            )
        else:
            logger.debug(
                f"Attempted to reset end {end_id} state " f"but it wasnt in ends"
            )

    def remove_from_selected_ends(self, ends: dict[str, End], end_id: str) -> None:
        """Remove an end from selected ends."""
        selected_ends = self.selected_ends[self.requester]
        if end_id in ends.keys():
            if end_id in selected_ends:
                logger.debug(
                    f"Going to remove end_id {end_id} from selected_ends "
                    f"{selected_ends}"
                )
                selected_ends.remove(end_id)
                self.selected_ends[self.requester] = selected_ends
                logger.debug(
                    f"self.selected_ends: {self.selected_ends} after "
                    f"removing end_id: {end_id}"
                )
            else:
                logger.debug(
                    f"Attempted to remove end {end_id} from "
                    f"self.selected_ends {self.selected_ends}, but it wasnt present"
                )
        else:
            logger.debug(
                f"Attempted to remove end {end_id} from "
                f"self.selected_ends {self.selected_ends}, but it wasnt in ends"
            )

    def process_chosen_candidate_dict(
        self,
        candidates_dict: dict[str, None],
        selected_ends: set[str],
    ):
        """Record chosen candidates in selected_ends and all_selected."""
        candidates = list(candidates_dict.keys())
        logger.debug(
            f"Got candidates_dict as {candidates_dict} after " f"select_random"
        )
        logger.debug(f"candidates: {candidates}")

        # Add candidates to selected ends.
        selected_ends = selected_ends.union(candidates)
        self.selected_ends[self.requester] = selected_ends
        logger.debug(
            f"added candidates to selected_ends: {candidates}, selected_ends: "
            f"{selected_ends}, "
            f"self.selected_ends[req]: {self.selected_ends[self.requester]}"
        )

        for candidate_end in candidates:
            # Add to all_selected. {key: end, val: TS epoch (s)}
            self.all_selected[candidate_end] = time.time()
            logger.debug(
                f"self.all_selected {self.all_selected} after combining"
                f" with candidates {candidates}"
            )

        logger.debug("finished processing candidates_dict")
+ """ + states = ['UN_AVL', 'AVL_EVAL', 'AVL_TRAIN'] + target_dist = [p_unavl, p_eval, p_train] + + # Configuration: Average time (seconds) spent in a state before switching + # Adjusting these changes the 'frequency' of churn + avg_stay_duration = 300 + + all_traces = {} + + for i in range(num_trainers): + current_time = 0 + trace = [] + + # Initial state based on target distribution + current_state = random.choices(states, weights=target_dist)[0] + trace.append((current_time, current_state)) + + while current_time < duration_sec: + # 1. Determine how long to stay in current state (Exponential distribution) + stay_duration = int(random.expovariate(1.0 / avg_stay_duration)) + if stay_duration < 1: stay_duration = 1 + + current_time += stay_duration + if current_time >= duration_sec: + break + + # 2. Transition to a NEW state + # To maintain steady state, we sample from the target distribution excluding current state + remaining_states = [s for s in states if s != current_state] + remaining_weights = [target_dist[states.index(s)] for s in remaining_states] + + current_state = random.choices(remaining_states, weights=remaining_weights)[0] + trace.append((current_time, current_state)) + + all_traces[f"trainer_{i}"] = str(trace) + + return all_traces + +# --- Configuration --- +TRAINERS = 100 +MINUTES = 60 +DURATION = MINUTES * 60 # Convert to seconds + +# Targets: 10% UN_AVL, 20% AVL_EVAL, 70% AVL_TRAIN +traces = generate_mobiperf_traces( + num_trainers=TRAINERS, + duration_sec=DURATION, + p_unavl=0.10, + p_eval=0.20, + p_train=0.70 +) + +# Output example +print(json.dumps({"synthetic_trace": traces["trainer_0"]}, indent=4)) \ No newline at end of file diff --git a/scripts/syn_trace_gen_improved.py b/scripts/syn_trace_gen_improved.py new file mode 100644 index 000000000..8df143ea7 --- /dev/null +++ b/scripts/syn_trace_gen_improved.py @@ -0,0 +1,231 @@ +import numpy as np +import json +import os +import glob +import re +import matplotlib.pyplot as plt + +def 
def batch_inject_and_plot(folder_path='.', max_trainers=100, train_p=0.90, eval_p=0.05,
                          unavail_p=0.05, total_minutes=1440, interval=10, stickiness=0.95):
    """Simulate per-trainer availability with a Markov chain and inject traces.

    For up to max_trainers "trainer_*.json" files under folder_path, simulates
    a 3-state availability chain (UN_AVL / AVL_EVAL / AVL_TRAIN) whose target
    mix is (unavail_p, eval_p, train_p), writes the resulting (time_sec, state)
    change events into each file's "hyperparameters" section, and saves a
    two-line availability validation plot.

    Returns the (total_rounds, num_files) state-index history matrix, or None
    when no matching files are found.
    """
    states = ['UN_AVL', 'AVL_EVAL', 'AVL_TRAIN']
    target_dist = np.array([unavail_p, eval_p, train_p])
    total_rounds = total_minutes // interval

    # 1. Transition matrix: off-diagonal mass i->j is proportional to the
    # target density of destination j, scaled by alpha = 1 - stickiness, so
    # the chain is pulled toward the target distribution.  The diagonal
    # (staying put) absorbs the remainder so each row sums to 1.
    n = len(target_dist)
    trans_matrix = np.zeros((n, n))
    alpha = 1.0 - stickiness
    for i in range(n):
        for j in range(n):
            if i != j:
                trans_matrix[i, j] = alpha * target_dist[j]
        trans_matrix[i, i] = 1.0 - np.sum(trans_matrix[i, :])

    # 2. File and key setup
    key_name = f"avl_events_syn_train_{int(train_p*100)}_eval_{int(eval_p*100)}_unavail_{int(unavail_p*100)}"
    search_pattern = os.path.join(folder_path, "trainer_*.json")
    files = glob.glob(search_pattern)
    # Sort numerically by the digits in the filename ("trainer_12.json" -> 12).
    # Raw string fixes the invalid '\D' escape (SyntaxWarning on Python 3.12+).
    files.sort(key=lambda f: int(re.sub(r'\D', '', os.path.basename(f)) or 0))
    files_to_process = files[:max_trainers]

    if not files_to_process:
        print(f"No files found in {folder_path}")
        return

    all_states_history = np.zeros((total_rounds, len(files_to_process)))

    for node_idx, file_path in enumerate(files_to_process):
        with open(file_path, 'r') as f:
            data = json.load(f)

        node_trace = []
        # Initial state drawn from the target distribution.
        curr_state_idx = np.random.choice([0, 1, 2], p=target_dist)
        all_states_history[0, node_idx] = curr_state_idx
        last_state_name = states[curr_state_idx]
        node_trace.append((0, last_state_name))

        for r in range(1, total_rounds):
            curr_state_idx = np.random.choice([0, 1, 2], p=trans_matrix[curr_state_idx])
            all_states_history[r, node_idx] = curr_state_idx
            curr_state_name = states[curr_state_idx]

            # Record only state CHANGES, as (seconds, state) events.
            if curr_state_name != last_state_name:
                node_trace.append((r * interval * 60, curr_state_name))
                last_state_name = curr_state_name

        # Inject the trace; files without a "hyperparameters" section are
        # deliberately left untouched (best-effort).
        if "hyperparameters" in data:
            data["hyperparameters"][key_name] = str(node_trace)
        with open(file_path, 'w') as f:
            json.dump(data, f, indent=4)

    # 3. Two-line validation plot: total availability vs train-only.
    time_axis = np.arange(total_rounds) * interval * 60
    train_pct = np.mean(all_states_history == 2, axis=1) * 100
    # Total available = anything NOT unavailable (state index 0).
    combined_pct = np.mean(all_states_history > 0, axis=1) * 100

    plt.figure(figsize=(14, 6))

    # Actual data lines
    plt.plot(time_axis, combined_pct, label='Total Available (TRAIN + EVAL)',
             color='green', linewidth=1.8, alpha=0.9)
    plt.plot(time_axis, train_pct, label='Actual AVL_TRAIN',
             color='blue', linewidth=1.2, alpha=0.7)

    # Expectation guide lines
    expected_total = (train_p + eval_p) * 100
    expected_train = train_p * 100
    plt.axhline(y=expected_total, color='green', linestyle='--',
                alpha=0.4, label=f'Total Expectation ({int(expected_total)}%)')
    plt.axhline(y=expected_train, color='blue', linestyle=':',
                alpha=0.5, label=f'Train Expectation ({int(expected_train)}%)')

    plt.title(f"System Availability: {int(train_p*100)}/{int(eval_p*100)}/{int(unavail_p*100)} (Stickiness {stickiness})")
    plt.ylabel("% of Total Trainers")
    plt.xlabel("Time (seconds)")
    plt.legend(loc='upper right', bbox_to_anchor=(1.25, 1))
    plt.grid(True, linestyle=':', alpha=0.6)
    plt.ylim(0, 105)

    plt.tight_layout()
    filename = f'availability_two_lines_{int(train_p*100)}_{int(eval_p*100)}_{int(unavail_p*100)}_stickiness_{stickiness}.png'
    plt.savefig(filename)
    plt.show()

    return all_states_history
def plot_trainer_distribution(history_matrix, train_p, eval_p,
                              unavail_p, stickiness, states=['UN_AVL', 'AVL_EVAL', 'AVL_TRAIN']):
    """Plot how residency time is distributed across trainers, per state.

    history_matrix: (total_rounds, num_trainers) array of state indices
    (0=UN_AVL, 1=AVL_EVAL, 2=AVL_TRAIN).  Saves a histogram (with KDE) of
    the per-trainer percentage of simulation time spent in each state.

    Note: the mutable list default for `states` is safe here because it is
    only iterated, never mutated.
    """
    num_trainers = history_matrix.shape[1]

    # 1. Percentage of time each trainer spent in each state
    # (one value per trainer, per state).  The unused `total_rounds`
    # local from the original was dropped.
    trainer_stats = {
        'UN_AVL': np.mean(history_matrix == 0, axis=0) * 100,
        'AVL_EVAL': np.mean(history_matrix == 1, axis=0) * 100,
        'AVL_TRAIN': np.mean(history_matrix == 2, axis=0) * 100
    }

    # 2. Plotting
    plt.figure(figsize=(10, 6))
    colors = {'UN_AVL': 'red', 'AVL_EVAL': 'blue', 'AVL_TRAIN': 'green'}

    for state in states:
        # Histogram with a Kernel Density Estimate (KDE) for smoothness.
        sns.histplot(trainer_stats[state], bins=20, kde=True,
                     color=colors[state], label=state, alpha=0.4)

    plt.title(f"Distribution of State Residency across {num_trainers} Trainers")
    plt.xlabel("Percentage of Total Simulation Time (%)")
    plt.ylabel("Number of Trainers")
    plt.legend()
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'trainer_state_distribution_{int(train_p*100)}_{int(eval_p*100)}_{int(unavail_p*100)}_stickiness_{stickiness}.png')
    plt.show()
def plot_stacked_trainer_composition(history_matrix, train_p, eval_p,
                                     unavail_p, stickiness, max_trainers=100):
    """Plot a stacked bar of per-trainer state residency composition.

    history_matrix: (total_rounds, num_trainers) array of state indices
    (0=UN_AVL, 1=AVL_EVAL, 2=AVL_TRAIN).  Only the first max_trainers
    columns are plotted.  Saves the figure as a PNG named after the
    configured distribution and stickiness.
    """
    # Honor max_trainers (previously accepted but ignored): cap the number
    # of trainer columns plotted.  No-op when the matrix already has at
    # most max_trainers columns, so default usage is unchanged.
    history_matrix = history_matrix[:, :max_trainers]

    # 1. Per-trainer residency percentages (one value per trainer per state).
    unavail_pct = np.mean(history_matrix == 0, axis=0) * 100
    eval_pct = np.mean(history_matrix == 1, axis=0) * 100
    train_pct = np.mean(history_matrix == 2, axis=0) * 100

    num_trainers = history_matrix.shape[1]
    trainer_indices = np.arange(1, num_trainers + 1)

    plt.figure(figsize=(15, 6))

    # 2. Plotting the stacks
    # Bottom: TRAIN
    plt.bar(trainer_indices, train_pct, color='green', label='AVL_TRAIN', alpha=0.8)
    # Middle: EVAL (starts at the top of TRAIN)
    plt.bar(trainer_indices, eval_pct, bottom=train_pct,
            color='blue', label='AVL_EVAL', alpha=0.8)
    # Top: UN_AVL (starts at the top of TRAIN + EVAL)
    plt.bar(trainer_indices, unavail_pct, bottom=train_pct + eval_pct,
            color='red', label='UN_AVL', alpha=0.8)

    # 3. Formatting
    plt.axhline(y=90, color='white', linestyle='--', linewidth=1, alpha=0.5)
    plt.xlabel("Trainer ID")
    plt.ylabel("Percentage of Total Time (%)")
    plt.title(f"Composition of State Residency per Trainer (Total: {num_trainers})")
    plt.xlim(0, num_trainers + 1)
    plt.ylim(0, 100)
    plt.legend(loc='upper right', bbox_to_anchor=(1.12, 1))

    plt.tight_layout()
    plt.savefig(f'trainer_stacked_composition_{int(train_p*100)}_{int(eval_p*100)}_{int(unavail_p*100)}_stickiness_{stickiness}.png')
    plt.show()


# # Setup 1
# train_p=0.90
# eval_p=0.10
# unavail_p=0.0
# stickiness=0.40

# # Setup 2
# train_p=0.50
# eval_p=0.30
# unavail_p=0.20
# stickiness=0.40

# Setup 3
train_p = 1.0
eval_p = 0.0
unavail_p = 0.0
stickiness = 0.40

# # For validation only
# train_p=0.50
# eval_p=0.30
# unavail_p=0.20
# stickiness=0.90


# Run the simulation + injection; returns None when no trainer files exist.
all_states_history = batch_inject_and_plot(
    folder_path='/home/dgarg39/aish_test/flame/lib/python/examples/fwdllm/expts/run_tc_expts/json_scripts/',
    max_trainers=100,
    train_p=train_p,
    eval_p=eval_p,
    unavail_p=unavail_p,
    stickiness=stickiness
)

# Guard against the no-files case (batch_inject_and_plot returns None),
# which previously crashed the plotting calls with a TypeError.
if all_states_history is not None:
    plot_trainer_distribution(all_states_history, train_p=train_p,
                              eval_p=eval_p,
                              unavail_p=unavail_p,
                              stickiness=stickiness)
    plot_stacked_trainer_composition(all_states_history, train_p=train_p,
                                     eval_p=eval_p,
                                     unavail_p=unavail_p,
                                     stickiness=stickiness)