diff --git a/pytorch_impl/applications/LEARN/demo.py b/pytorch_impl/applications/LEARN/demo.py
index 063a4fb..9cd11b2 100644
--- a/pytorch_impl/applications/LEARN/demo.py
+++ b/pytorch_impl/applications/LEARN/demo.py
@@ -109,6 +109,7 @@ def node(
     non_iid,
     q,
     port,
+    sync,
 ):
     logger.debug("**** SETUP AT NODE %s ***", rank)
     logger.debug("Number of nodes: %d", n)
@@ -123,6 +124,9 @@ def node(
     logger.debug("Assume Non-iid data? %s", non_iid)
     logger.debug("------------------------------------")
 
+    # Tell the parent (which waits on `sync`) that this node got this far.
+    sync.put(True)
+
     gar = aggregators.gars.get(gar)
 
     torch.manual_seed(1234)  # For reproducibility
@@ -244,6 +248,7 @@ def node(
 class Trainer:
     TIMEOUT_PROGRESS_SEC = 1 * 60
     TIMEOUT_TERMINATE_SEC = 10
+    MAX_ATTEMPTS_STARTUP = 5
 
     def __init__(self, n, f, gar, port):
         if n < 1 or n > 10:
@@ -272,35 +277,64 @@ def train(self):
         self.status = {rank: -1 for rank in range(self.n)}
 
         q = mp.Queue()
-
-        ps = []
-        for rank in range(self.n):
-            logger.info("Starting process with rank %d", rank)
-            p = mp.Process(
-                target=node,
-                kwargs=dict(
-                    rank=rank,
-                    is_byzantine=(rank < self.f),
-                    world_size=self.n,
-                    batch=batch_size,
-                    model=model,
-                    dataset=dataset,
-                    loss="binary-cross-entropy",
-                    nb_epochs=nb_epochs,
-                    n=self.n,
-                    f=self.f,
-                    gar=self.gar,
-                    optimizer="rmsprop",
-                    opt_args={"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005},
-                    non_iid=False,
-                    q=q,
-                    port=self.port,
-                ),
-            )
-            p.start()
-            ps.append(p)
-
-        logger.info("Waiting for results")
+
+        nb_attempt = 0
+        while True:
+            logger.info("Starting processes (attempt # %d)", nb_attempt + 1)
+            # Fresh queue per attempt: a queue used by a killed process may be
+            # corrupted, and could still receive a stale token from it.
+            sync = mp.Queue()
+            ps = []
+            for rank in range(self.n):
+                logger.info("Starting process with rank %d", rank)
+                p = mp.Process(
+                    target=node,
+                    kwargs=dict(
+                        rank=rank,
+                        is_byzantine=(rank < self.f),
+                        world_size=self.n,
+                        batch=batch_size,
+                        model=model,
+                        dataset=dataset,
+                        loss="binary-cross-entropy",
+                        nb_epochs=nb_epochs,
+                        n=self.n,
+                        f=self.f,
+                        gar=self.gar,
+                        optimizer="rmsprop",
+                        opt_args={"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005},
+                        non_iid=False,
+                        q=q,
+                        port=self.port,
+                        sync=sync,
+                    ),
+                )
+                p.start()
+                ps.append(p)
+
+            # Sometimes a process fails to start properly for an unknown
+            # reason, even though `p.is_alive` and other signs look normal.
+            # This is an attempt to detect the issue early and retry.
+            try:
+                sync.get(timeout=10)
+                for _ in range(len(ps) - 1):
+                    sync.get(timeout=1)
+                break
+            except queue.Empty as exc:
+                # Try to cleanup; join so the killed processes are reaped
+                # before new ones are started (presumably on the same port).
+                for p in ps:
+                    p.kill()
+                for p in ps:
+                    p.join(timeout=self.TIMEOUT_TERMINATE_SEC)
+
+                nb_attempt += 1
+                if nb_attempt == self.MAX_ATTEMPTS_STARTUP:
+                    raise Exception("Timeout while syncing processes") from exc
+
+                logger.error("Timeout while syncing processes, restarting")
+
+        logger.info("All processes synchronized -- waiting for results")
 
         try:
             acc = []