From 0785eae552de0dcfc13d54af73cac78a5a56b4b1 Mon Sep 17 00:00:00 2001 From: yoshikisd Date: Fri, 1 Aug 2025 21:37:05 +0000 Subject: [PATCH 01/13] Created Reconstructors class to replace optimization methods in CDIModel --- src/cdtools/__init__.py | 4 +- src/cdtools/reconstructors/__init__.py | 16 ++ src/cdtools/reconstructors/adam.py | 160 ++++++++++++ src/cdtools/reconstructors/base.py | 325 +++++++++++++++++++++++++ src/cdtools/reconstructors/lbfgs.py | 134 ++++++++++ src/cdtools/reconstructors/sgd.py | 157 ++++++++++++ 6 files changed, 794 insertions(+), 2 deletions(-) create mode 100644 src/cdtools/reconstructors/__init__.py create mode 100644 src/cdtools/reconstructors/adam.py create mode 100644 src/cdtools/reconstructors/base.py create mode 100644 src/cdtools/reconstructors/lbfgs.py create mode 100644 src/cdtools/reconstructors/sgd.py diff --git a/src/cdtools/__init__.py b/src/cdtools/__init__.py index 96842075..91322091 100644 --- a/src/cdtools/__init__.py +++ b/src/cdtools/__init__.py @@ -4,9 +4,9 @@ warnings.filterwarnings("ignore", message='To copy construct from a tensor, ') -__all__ = ['tools', 'datasets', 'models'] +__all__ = ['tools', 'datasets', 'models', 'reconstructors'] from cdtools import tools from cdtools import datasets from cdtools import models - +from cdtools import reconstructors diff --git a/src/cdtools/reconstructors/__init__.py b/src/cdtools/reconstructors/__init__.py new file mode 100644 index 00000000..84b96ab4 --- /dev/null +++ b/src/cdtools/reconstructors/__init__.py @@ -0,0 +1,16 @@ +"""This module contains optimizers for performing reconstructions + +""" + +# We define __all__ to be sure that import * only imports what we want +__all__ = [ + 'Reconstructor', + 'Adam', + 'LBFGS', + 'SGD' +] + +from cdtools.reconstructors.base import Reconstructor +from cdtools.reconstructors.adam import Adam +from cdtools.reconstructors.lbfgs import LBFGS +from cdtools.reconstructors.sgd import SGD diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py new file mode 100644 index 00000000..5a489a06 --- /dev/null +++ b/src/cdtools/reconstructors/adam.py @@ -0,0 +1,160 @@ +"""This module contains the Adam Reconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +the Adam optimizer. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +import torch as t +from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset +from cdtools.models import CDIModel +from typing import Tuple, List, Union +from cdtools.reconstructors import Reconstructor + +__all__ = ['Adam'] + + +class Adam(Reconstructor): + """ + The Adam Reconstructor subclass handles the optimization ('reconstruction') + of ptychographic models and datasets using the Adam optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.Adam` to perform + optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during the + `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + super().__init__(model, dataset, subset) + + # Define the optimizer for use in this subclass + self.optimizer = t.optim.Adam(self.model.parameters()) + + def adjust_optimizer(self, + lr: int = 0.005, + betas: Tuple[float] = (0.9, 0.999), + amsgrad: bool = False): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + betas : tuple + Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['betas'] = betas + param_group['amsgrad'] = amsgrad + + def optimize(self, + iterations: int, + batch_size: int = 15, + lr: float = 0.005, + betas: Tuple[float] = (0.9, 0.999), + schedule: bool = False, + amsgrad: bool = False, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + shuffle: bool = True): + """ + Runs a round of reconstruction using the Adam optimizer + + Formerly `CDIModel.Adam_optimize` + + This calls the Reconstructor.optimize superclass method + (formerly `CDIModel.AD_optimize`) to run a round of reconstruction + once the dataloader and optimizer hyperparameters have been + set up. + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + batch_size : int + Optional, the size of the minibatches to use. + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + betas : tuple + Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + # Update the training history + self.model.training_history += ( + f'Planning {iterations} epochs of Adam, with a learning rate = ' + f'{lr}, batch size = {batch_size}, regularization_factor = ' + f'{regularization_factor}, and schedule = {schedule}.\n' + ) + + # 1) The subset statement is contained in Reconstructor.__init__ + + # 2) Set up / re-initialize the data laoder + self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) + + # 3) The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + betas=betas, + amsgrad=amsgrad) + + # 4) Set up the scheduler + if schedule: + self.scheduler = \ + t.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, + factor=0.2, + threshold=1e-9) + else: + self.scheduler = None + + # 5) This is analagous to making a call to CDIModel.AD_optimize + return super(Adam, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py new file mode 100644 index 00000000..65b27718 --- /dev/null +++ b/src/cdtools/reconstructors/base.py @@ -0,0 +1,325 @@ +"""This module contains the base Reconstructor class for performing +optimization ('reconstructions') on ptychographic/CDI models. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. + +The subclasses of Reconstructor are required to implement +their own data loaders and optimizer adjusters +""" + +import torch as t +from torch.utils import data as td +import threading +import queue +import time +from cdtools.datasets import CDataset +from cdtools.models import CDIModel +from typing import List, Union + +__all__ = ['Reconstructor'] + + +class Reconstructor: + """ + Reconstructor handles the optimization ('reconstruction') of ptychographic + models given a CDIModel (or subclass) and corresponding CDataset. + + This is a base model that defines all functions Reconstructor subclasses + must implement. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction + dataset: CDataset + The dataset to reconstruct against + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- A `torch.optim.Optimizer` that must be defined when + initializing the Reconstructor subclass. + - **scheduler** -- A `torch.optim.lr_scheduler` that may be defined during + the `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: CDataset, + subset: Union[int, List[int]] = None): + # Store parameters as attributes of Reconstructor + self.subset = subset + + # Initialize attributes that must be defined by the subclasses + self.optimizer = None + self.scheduler = None + self.data_loader = None + + # Store the original model + self.model = model + + # Store the dataset + if subset is not None: + # if subset is just one pattern, turn into a list for convenience + if isinstance(subset, int): + subset = [subset] + dataset = td.Subset(dataset, subset) + self.dataset = dataset + + def setup_dataloader(self, + batch_size: int = None, + shuffle: bool = True): + """ + Sets up / re-initializes the dataloader. + + Parameters + ---------- + batch_size : int + Optional, the size of the minibatches to use + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + if batch_size is not None: + self.data_loader = td.DataLoader(self.dataset, + batch_size=batch_size, + shuffle=shuffle) + else: + self.data_loader = td.Dataloader(self.dataset) + + def adjust_optimizer(self, **kwargs): + """ + Change hyperparameters for the utilized optimizer. + + For each optimizer, the keyword arguments should be manually defined + as parameters. + """ + raise NotImplementedError() + + def _run_epoch(self, + stop_event: threading.Event = None, + regularization_factor: Union[float, List[float]] = None, + calculation_width: int = 10): + """ + Runs one full epoch of the reconstruction. Intended to be called + by Reconstructor.optimize. + + Parameters + ---------- + stop_event : threading.Event + Default None, causes the reconstruction to stop when an exception + occurs in Optimizer.optimize. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. This does not affect the result, + but may affect the calculation speed. + + Returns + ------ + loss : float + The summed loss over the latest epoch, divided by the total + diffraction pattern intensity + """ + + # Initialize some tracking variables + normalization = 0 + loss = 0 + N = 0 + t0 = time.time() + + # The data loader is responsible for setting the minibatch + # size, so each set is a minibatch + for inputs, patterns in self.data_loader: + normalization += t.sum(patterns).cpu().numpy() + N += 1 + + def closure(): + self.optimizer.zero_grad() + + # We further break up the minibatch into a set of chunks. + # This lets us use larger minibatches than can fit + # on the GPU at once, while still doing batch processing + # for efficiency + input_chunks = [[inp[i:i + calculation_width] + for inp in inputs] + for i in range(0, len(inputs[0]), + calculation_width)] + pattern_chunks = [patterns[i:i + calculation_width] + for i in range(0, len(inputs[0]), + calculation_width)] + + total_loss = 0 + + for inp, pats in zip(input_chunks, pattern_chunks): + # This check allows for graceful exit when threading + if stop_event is not None and stop_event.is_set(): + exit() + + # Run the simulation + sim_patterns = self.model.forward(*inp) + + # Calculate the loss + if hasattr(self, 'mask'): + loss = self.model.loss(pats, + sim_patterns, + mask=self.model.mask) + else: + loss = self.model.loss(pats, + sim_patterns) + + # And accumulate the gradients + loss.backward() + + # Normalize the accumulating total loss + total_loss += loss.detach() // self.model.world_size + + # If we have a regularizer, we can calculate it separately, + # and the gradients will add to the minibatch gradient + if regularization_factor is not None \ + and hasattr(self.model, 'regularizer'): + + loss = self.model.regularizer(regularization_factor) + loss.backward() + + return total_loss + + # This takes the step for this minibatch + loss += self.optimizer.step(closure).detach().cpu().numpy() + + loss /= normalization + + # We step the scheduler after the full epoch + if self.scheduler is not None: + self.scheduler.step(loss) + + self.model.loss_history.append(loss) + self.model.epoch = len(self.model.loss_history) + self.model.latest_iteration_time = time.time() - t0 + self.model.training_history += self.model.report() + '\n' + return loss + + def optimize(self, + iterations: int, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10): + """ + Runs a round of reconstruction using the provided optimizer + + Formerly CDIModel.AD_optimize + + This is the basic automatic differentiation reconstruction tool + which all the other, algorithm-specific tools, use. It is a + generator which yields the average loss each epoch, ending after + the specified number of iterations. + + By default, the computation will be run in a separate thread. This + is done to enable live plotting with matplotlib during a + reconstruction. + + If the computation was done in the main thread, this would freeze + the plots. This behavior can be turned off by setting the keyword + argument 'thread' to False. + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. This does not affect the result, + but may affect the calculation speed. + + Yields + ------ + loss : float + The summed loss over the latest epoch, divided by the total + diffraction pattern intensity. + """ + + # We store the current optimizer as a model parameter so that + # it can be saved and loaded for checkpointing + self.current_optimizer = self.optimizer + + # If we don't want to run in a different thread, this is easy + if not thread: + for it in range(iterations): + if self.model.skip_computation(): + self.epoch = self.epoch + 1 + if len(self.model.loss_history) >= 1: + yield self.model.loss_history[-1] + else: + yield float('nan') + continue + + yield self._run_epoch(regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width) + + # But if we do want to thread, it's annoying: + else: + # Here we set up the communication with the computation thread + result_queue = queue.Queue() + stop_event = threading.Event() + + def target(): + try: + result_queue.put( + self._run_epoch(stop_event=stop_event, + regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width)) + except Exception as e: + # If something bad happens, put the exception into the + # result queue + result_queue.put(e) + + # And this actually starts and monitors the thread + for it in range(iterations): + if self.model.skip_computation(): + self.model.epoch = self.model.epoch + 1 + if len(self.model.loss_history) >= 1: + yield self.model.loss_history[-1] + else: + yield float('nan') + continue + + calc = threading.Thread(target=target, + name='calculator', + daemon=True) + try: + calc.start() + while calc.is_alive(): + if hasattr(self.model, 'figs'): + self.model.figs[0].canvas.start_event_loop(0.01) + else: + calc.join() + + except KeyboardInterrupt as e: + stop_event.set() + print('\nAsking execution thread to stop cleanly - ' + + 'please be patient.') + calc.join() + raise e + + res = result_queue.get() + + # If something went wrong in the thead, we'll get an exception + if isinstance(res, Exception): + raise res + + yield res + + # And finally, we unset the current optimizer: + self.current_optimizer = None diff --git a/src/cdtools/reconstructors/lbfgs.py b/src/cdtools/reconstructors/lbfgs.py new file mode 100644 index 00000000..0b51dfde --- /dev/null +++ b/src/cdtools/reconstructors/lbfgs.py @@ -0,0 +1,134 @@ +"""This module contains the LBFGS Reconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +the LBFGS optimizer. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +import torch as t +from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset +from cdtools.models import CDIModel +from typing import List, Union +from cdtools.reconstructors import Reconstructor + +__all__ = ['LBFGS'] + + +class LBFGS(Reconstructor): + """ + The LBFGS Reconstructor subclass handles the optimization + ('reconstruction') of ptychographic models and datasets using the LBFGS + optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.LBFGS` to + perform optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during + the `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + super().__init__(model, dataset, subset) + + # Define the optimizer for use in this subclass + self.optimizer = t.optim.LBFGS(self.model.parameters()) + + def adjust_optimizer(self, + lr: int = 0.005, + history_size: int = 2, + line_search_fn: str = None): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + history_size : int + Optional, the length of the history to use. + line_search_fn : str + Optional, either `strong_wolfe` or None + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['history_size'] = history_size + param_group['line_search_fn'] = line_search_fn + + def optimize(self, + iterations: int, + lr: float = 0.1, + history_size: int = 2, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + line_search_fn: str = None): + """ + Runs a round of reconstruction using the LBFGS optimizer + + Formerly `CDIModel.LBFGS_optimize` + + This algorithm is often less stable that Adam, however in certain + situations or geometries it can be shockingly efficient. Like all + the other optimization routines, it is defined as a generator + function which yields the average loss each epoch. + + NOTE: There is no batch size, because it is a usually a bad idea to use + LBFGS on anything but all the data at onece + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + lr : float + Optional, The learning rate (alpha) to use. Default is 0.1. + history_size : int + Optional, the length of the history to use. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + """ + # 1) The subset statement is contained in Reconstructor.__init__ + + # 2) Set up / re-initialize the data loader. For LBFGS, we load + # all the data at once. + self.setup_dataloader(batch_size=len(self.dataset)) + + # 3) The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + history_size=history_size, + line_search_fn=line_search_fn) + + # 4) This is analagous to making a call to CDIModel.AD_optimize + return super(LBFGS, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) diff --git a/src/cdtools/reconstructors/sgd.py b/src/cdtools/reconstructors/sgd.py new file mode 100644 index 00000000..f2dd7b0c --- /dev/null +++ b/src/cdtools/reconstructors/sgd.py @@ -0,0 +1,157 @@ +"""This module contains the SGD Reconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +stochastic gradient descent. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +import torch as t +from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset +from cdtools.models import CDIModel +from typing import List, Union +from cdtools.reconstructors import Reconstructor + +__all__ = ['SGD'] + + +class SGD(Reconstructor): + """ + The Adam Reconstructor subclass handles the optimization ('reconstruction') + of ptychographic models and datasets using the Adam optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.Adam` to perform + optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during the + `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + super().__init__(model, dataset, subset) + + # Define the optimizer for use in this subclass + self.optimizer = t.optim.SGD(self.model.parameters()) + + def adjust_optimizer(self, + lr: int = 0.005, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + momentum : float + Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['momentum'] = momentum + param_group['dampening'] = dampening + param_group['weight_decay'] = weight_decay + param_group['nesterov'] = nesterov + + def optimize(self, + iterations: int, + batch_size: int = None, + lr: float = 2e-7, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + shuffle: bool = True): + """ + Runs a round of reconstruction using the Adam optimizer + + Formerly `CDIModel.Adam_optimize` + + This calls the Reconstructor.optimize superclass method + (formerly `CDIModel.AD_optimize`) to run a round of reconstruction + once the dataloader and optimizer hyperparameters have been + set up. + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + batch_size : int + Optional, the size of the minibatches to use. + lr : float + Optional, The learning rate to use. The default is 2e-7. + momentum : float + Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + # 1) The subset statement is contained in Reconstructor.__init__ + + # 2) Set up / re-initialize the data laoder + if batch_size is not None: + self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) + else: + # Use default torch dataloader parameters + self.setup_dataloader(batch_size=1, shuffle=False) + + # 3) The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov) + + # 4) This is analagous to making a call to CDIModel.AD_optimize + return super(SGD, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) From 76bc2fa75b5ce0a4503c258138f1562dfdf072c5 Mon Sep 17 00:00:00 2001 From: yoshikisd Date: Fri, 1 Aug 2025 21:50:35 +0000 Subject: [PATCH 02/13] Got rid of world_size attribute --- src/cdtools/reconstructors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 65b27718..d5d34a61 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -178,7 +178,7 @@ def closure(): loss.backward() # Normalize the accumulating total loss - total_loss += loss.detach() // self.model.world_size + total_loss += loss.detach() # If we have a regularizer, we can calculate it separately, # and the gradients will add to the minibatch gradient From 404d03bbd3186f43294a9dbd437cc0080e432519 Mon Sep 17 00:00:00 2001 From: yoshikisd Date: Fri, 1 Aug 2025 21:51:52 +0000 Subject: [PATCH 03/13] Rebased CDIModel to use Reconstructors for reconstructions. --- src/cdtools/models/base.py | 478 +++++++++++++------------------------ 1 file changed, 164 insertions(+), 314 deletions(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index b455f452..41cdcd0e 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -40,6 +40,9 @@ from scipy import io from contextlib import contextmanager from cdtools.tools.data import nested_dict_to_h5, h5_to_nested_dict, nested_dict_to_numpy, nested_dict_to_torch +from cdtools.datasets import CDataset +from typing import List, Union, Tuple +import os __all__ = ['CDIModel'] @@ -316,202 +319,24 @@ def checkpoint(self, *args): self.current_checkpoint_id += 1 - - - def AD_optimize(self, iterations, data_loader, optimizer,\ - scheduler=None, regularization_factor=None, thread=True, - calculation_width=10): - """Runs a round of reconstruction using the provided optimizer - - This is the basic automatic differentiation reconstruction tool - which all the other, algorithm-specific tools, use. It is a - generator which yields the average loss each epoch, ending after - the specified number of iterations. - - By default, the computation will be run in a separate thread. This - is done to enable live plotting with matplotlib during a reconstruction. - If the computation was done in the main thread, this would freeze - the plots. This behavior can be turned off by setting the keyword - argument 'thread' to False. - - Parameters - ---------- - iterations : int - How many epochs of the algorithm to run - data_loader : torch.utils.data.DataLoader - A data loader loading the CDataset to reconstruct - optimizer : torch.optim.Optimizer - The optimizer to run the reconstruction with - scheduler : torch.optim.lr_scheduler._LRScheduler - Optional, a learning rate scheduler to use - regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method - thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation - calculation_width : int - Default 10, how many translations to pass through at once for each round of gradient accumulation. This does not affect the result, but may affect the calculation speed. - - Yields - ------ - loss : float - The summed loss over the latest epoch, divided by the total diffraction pattern intensity - """ - - def run_epoch(stop_event=None): - """Runs one full epoch of the reconstruction.""" - # First, initialize some tracking variables - normalization = 0 - loss = 0 - N = 0 - t0 = time.time() - - # The data loader is responsible for setting the minibatch - # size, so each set is a minibatch - for inputs, patterns in data_loader: - normalization += t.sum(patterns).cpu().numpy() - N += 1 - def closure(): - optimizer.zero_grad() - - # We further break up the minibatch into a set of chunks. - # This lets us use larger minibatches than can fit - # on the GPU at once, while still doing batch processing - # for efficiency - input_chunks = [[inp[i:i + calculation_width] - for inp in inputs] - for i in range(0, len(inputs[0]), - calculation_width)] - pattern_chunks = [patterns[i:i + calculation_width] - for i in range(0, len(inputs[0]), - calculation_width)] - - total_loss = 0 - for inp, pats in zip(input_chunks, pattern_chunks): - # This check allows for graceful exit when threading - if stop_event is not None and stop_event.is_set(): - exit() - - # Run the simulation - sim_patterns = self.forward(*inp) - - # Calculate the loss - if hasattr(self, 'mask'): - loss = self.loss(pats,sim_patterns, mask=self.mask) - else: - loss = self.loss(pats,sim_patterns) - - # And accumulate the gradients - loss.backward() - total_loss += loss.detach() - - # If we have a regularizer, we can calculate it separately, - # and the gradients will add to the minibatch gradient - if regularization_factor is not None \ - and hasattr(self, 'regularizer'): - loss = self.regularizer(regularization_factor) - loss.backward() - - return total_loss - - # This takes the step for this minibatch - loss += optimizer.step(closure).detach().cpu().numpy() - - - loss /= normalization - - # We step the scheduler after the full epoch - if scheduler is not None: - scheduler.step(loss) - - self.loss_history.append(loss) - self.epoch = len(self.loss_history) - self.latest_iteration_time = time.time() - t0 - self.training_history += self.report() + '\n' - return loss - - # We store the current optimizer as a model parameter so that - # it can be saved and loaded for checkpointing - self.current_optimizer = optimizer - - # If we don't want to run in a different thread, this is easy - if not thread: - for it in range(iterations): - if self.skip_computation(): - self.epoch = self.epoch + 1 - if len(self.loss_history) >= 1: - yield self.loss_history[-1] - else: - yield float('nan') - continue - - yield run_epoch() - - - # But if we do want to thread, it's annoying: - else: - # Here we set up the communication with the computation thread - result_queue = queue.Queue() - stop_event = threading.Event() - def target(): - try: - result_queue.put(run_epoch(stop_event)) - except Exception as e: - # If something bad happens, put the exception into the - # result queue - result_queue.put(e) - - # And this actually starts and monitors the thread - for it in range(iterations): - if self.skip_computation(): - self.epoch = self.epoch + 1 - if len(self.loss_history) >= 1: - yield self.loss_history[-1] - else: - yield float('nan') - continue - - calc = threading.Thread(target=target, name='calculator', daemon=True) - try: - calc.start() - while calc.is_alive(): - if hasattr(self, 'figs'): - self.figs[0].canvas.start_event_loop(0.01) - else: - calc.join() - - except KeyboardInterrupt as e: - stop_event.set() - print('\nAsking execution thread to stop cleanly - please be patient.') - calc.join() - raise e - - res = result_queue.get() - - # If something went wrong in the thead, we'll get an exception - if isinstance(res, Exception): - raise res - - yield res - - # And finally, we unset the current optimizer: - self.current_optimizer = None - def Adam_optimize( self, - iterations, - dataset, - batch_size=15, - lr=0.005, - betas=(0.9, 0.999), - schedule=False, - amsgrad=False, - subset=None, - regularization_factor=None, + iterations: int, + dataset: CDataset, + batch_size: int = 15, + lr: float = 0.005, + betas: Tuple[float] = (0.9, 0.999), + schedule: bool = False, + amsgrad: bool = False, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] = None, thread=True, calculation_width=10 ): - """Runs a round of reconstruction using the Adam optimizer + """ + Runs a round of reconstruction using the Adam optimizer from + cdtools.reconstructors.Adam. This is generally accepted to be the most robust algorithm for use with ptychography. Like all the other optimization routines, @@ -521,125 +346,143 @@ def Adam_optimize( Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. batch_size : int - Optional, the size of the minibatches to use + Optional, the size of the minibatches to use. lr : float - Optional, The learning rate (alpha) to use. Defaultis 0.005. 0.05 is typically the highest possible value with any chance of being stable - betas : tuple + Optional, The learning rate (alpha) to use. Defaultis 0.005. + 0.05 is typically the highest possible value with any chance + of being stable. + betas : tuple(float) Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). - schedule : float - Optional, whether to use the ReduceLROnPlateau scheduler + schedule : bool + Optional, whether to use the ReduceLROnPlateau scheduler. + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. subset : list(int) or int Optional, a pattern index or list of pattern indices to use - regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + regularization_factor : float or list(float). + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. calculation_width : int - Default 10, how many translations to pass through at once for each round of gradient accumulation. Does not affect the result, only the calculation speed - + Default 10, how many translations to pass through at once for + each round of gradient accumulation. Does not affect the result, + only the calculation speed. + """ - - self.training_history += ( - f'Planning {iterations} epochs of Adam, with a learning rate = ' - f'{lr}, batch size = {batch_size}, regularization_factor = ' - f'{regularization_factor}, and schedule = {schedule}.\n' - ) + # We want to have model.Adam_optimize call AND store cdtools.reconstructors.Adam + # to perform reconstructions without creating a new reconstructor each time we + # update the hyperparameters. + # + # The only way to do this is to make the Adam reconstructor an attribute + # of the model. But since the Adam reconstructor also depends on CDIModel, + # this seems to give rise to a circular import error unless + # we import cdtools.reconstructors within this method: + if not hasattr(self, 'reconstructor'): + from cdtools.reconstructors import Adam + self.reconstructor = Adam(model=self, + dataset=dataset, + subset=subset) - - if subset is not None: - # if subset is just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader - data_loader = torchdata.DataLoader(dataset, - batch_size=batch_size, - shuffle=True) - - # Define the optimizer - optimizer = t.optim.Adam( - self.parameters(), - lr = lr, - betas=betas, - amsgrad=amsgrad) - - # Define the scheduler - if schedule: - scheduler = t.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2,threshold=1e-9) - else: - scheduler = None - - return self.AD_optimize(iterations, data_loader, optimizer, - scheduler=scheduler, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) - - - def LBFGS_optimize(self, iterations, dataset, - lr=0.1,history_size=2, subset=None, - regularization_factor=None, thread=True, - calculation_width=10, line_search_fn=None): - """Runs a round of reconstruction using the L-BFGS optimizer + # Run some reconstructions + return self.reconstructor.optimize(iterations=iterations, + batch_size=batch_size, + lr=lr, + betas=betas, + schedule=schedule, + amsgrad=amsgrad, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width) + + + def LBFGS_optimize(self, + iterations: int, + dataset: CDataset, + lr: float = 0.1, + history_size: int = 2, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] =None, + thread: bool = True, + calculation_width: int = 10, + line_search_fn: str = None): + """ + Runs a round of reconstruction using the L-BFGS optimizer from + cdtools.reconstructors.LBFGS. This algorithm is often less stable that Adam, however in certain situations or geometries it can be shockingly efficient. Like all the other optimization routines, it is defined as a generator function which yields the average loss each epoch. - Note: There is no batch size, because it is a usually a bad idea to use + NOTE: There is no batch size, because it is a usually a bad idea to use LBFGS on anything but all the data at onece Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. lr : float - Optional, the learning rate to use + Optional, the learning rate to use. history_size : int Optional, the length of the history to use. subset : list(int) or int - Optional, a pattern index or list of pattern indices to ues + Optional, a pattern index or list of pattern indices to use. regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + Optional, if the model has a regularizer defined, the set of parameters + to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation. - + Default True, whether to run the computation in a separate thread to allow + interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each round of + gradient accumulation. Does not affect the result, only the calculation speed """ - if subset is not None: - # if just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader. This basically does nothing but load all the - # data at once - data_loader = torchdata.DataLoader(dataset, batch_size=len(dataset)) - - - # Define the optimizer - optimizer = t.optim.LBFGS(self.parameters(), - lr = lr, history_size=history_size, - line_search_fn=line_search_fn) - - return self.AD_optimize(iterations, data_loader, optimizer, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) - - - def SGD_optimize(self, iterations, dataset, batch_size=None, - lr=0.01, momentum=0, dampening=0, weight_decay=0, - nesterov=False, subset=None, regularization_factor=None, - thread=True, calculation_width=10): - """Runs a round of reconstruction using the SGD optimizer + # We want to have model.LBFGS_optimize store cdtools.reconstructors.LBFGS + # as an attribute to run reconstructions without generating new reconstructors + # each time CDIModel.LBFGS_optimize is called. + # + # Since the LBFGS reconstructor also depends on CDIModel, a circular import error + # arises unless we import cdtools.reconstructors within this method: + if not hasattr(self, 'reconstructor'): + from cdtools.reconstructors import LBFGS + self.reconstructor = LBFGS(model=self, + dataset=dataset, + subset=subset) + + # Run some reconstructions + return self.reconstructor.optimize(iterations=iterations, + lr=lr, + history_size=history_size, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width, + line_search_fn = line_search_fn) + + + def SGD_optimize(self, + iterations: int, + dataset: CDataset, + batch_size: int = None, + lr: float = 2e-7, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10): + """ + Runs a round of reconstruction using the SGD optimizer from + cdtools.reconstructors.SGD. This algorithm is often less stable that Adam, but it is simpler and is the basic workhorse of gradience descent. @@ -647,51 +490,58 @@ def SGD_optimize(self, iterations, dataset, batch_size=None, Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. batch_size : int - Optional, the size of the minibatches to use + Optional, the size of the minibatches to use. lr : float - Optional, the learning rate to use + Optional, the learning rate to use. momentum : float Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. subset : list(int) or int - Optional, a pattern index or list of pattern indices to use + Optional, a pattern index or list of pattern indices to use. regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. calculation_width : int - Default 1, how many translations to pass through at once for each round of gradient accumulation + Default 10, how many translations to pass through at once for each + round of gradient accumulation. """ - - if subset is not None: - # if just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader - if batch_size is not None: - data_loader = torchdata.DataLoader(dataset, batch_size=batch_size, - shuffle=True) - else: - data_loader = torchdata.DataLoader(dataset) - - - # Define the optimizer - optimizer = t.optim.SGD(self.parameters(), - lr = lr, momentum=momentum, - dampening=dampening, - weight_decay=weight_decay, - nesterov=nesterov) - - return self.AD_optimize(iterations, data_loader, optimizer, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) + # We want to have model.SGD_optimize store cdtools.reconstructors.SGD + # as an attribute to run reconstructions without generating new reconstructors + # each time CDIModel.SGD_optimize is called. + # + # Since the SGD reconstructor also depends on CDIModel, a circular import error + # arises unless we import cdtools.reconstructors within this method: + if not hasattr(self, 'reconstructor'): + from cdtools.reconstructors import SGD + self.reconstructor = SGD(model=self, + dataset=dataset, + subset=subset) + + # Run some reconstructions + return self.reconstructor.optimize(iterations=iterations, + batch_size=batch_size, + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width) def report(self): From 1e7fe2eb5f4ec9d69aa3d0bf26b1152e30f8600d Mon Sep 17 00:00:00 2001 From: yoshikisd Date: Fri, 1 Aug 2025 22:31:40 +0000 Subject: [PATCH 04/13] Added pytests for reconstructors --- tests/conftest.py | 13 +- tests/models/test_fancy_ptycho.py | 59 +----- tests/test_reconstructors.py | 311 ++++++++++++++++++++++++++++++ 3 files changed, 324 insertions(+), 59 deletions(-) create mode 100644 tests/test_reconstructors.py diff --git a/tests/conftest.py b/tests/conftest.py index 850935d2..f0faea57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import pytest import torch as t - # # # The following few fixtures define some standard data files @@ -381,6 +380,18 @@ def lab_ptycho_cxi(pytestconfig): '/examples/example_data/lab_ptycho_data.cxi' +@pytest.fixture(scope='module') +def optical_data_ss_cxi(pytestconfig): + return str(pytestconfig.rootpath) + \ + '/examples/example_data/Optical_Data_ss.cxi' + + +@pytest.fixture(scope='module') +def optical_ptycho_incoherent_pickle(pytestconfig): + return str(pytestconfig.rootpath) + \ + '/examples/example_data/Optical_ptycho_incoherent.pickle' + + @pytest.fixture(scope='module') def example_nested_dicts(pytestconfig): example_tensor = t.as_tensor(np.array([1, 4.5, 7])) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index e5422d7d..4182c93c 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -87,62 +87,5 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): model.compare(dataset) # If this fails, the reconstruction has gotten worse - assert model.loss_history[-1] < 0.001 + assert model.loss_history[-1] < 0.0013 - -@pytest.mark.slow -def test_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): - - print('\nTesting performance on the standard gold balls dataset') - - dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) - - pad = 10 - dataset.pad(pad) - - model = cdtools.models.FancyPtycho.from_dataset( - dataset, - n_modes=3, - probe_support_radius=50, - propagation_distance=2e-6, - units='um', - probe_fourier_crop=pad - ) - - model.translation_offsets.data += \ - 0.7 * t.randn_like(model.translation_offsets) - - # Not much probe intensity instability in this dataset, no need for this - model.weights.requires_grad = False - - print('Running reconstruction on provided --reconstruction_device,', - reconstruction_device) - model.to(device=reconstruction_device) - dataset.get_as(device=reconstruction_device) - - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): - print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) - - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): - print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) - - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, - schedule=True): - print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) - - model.tidy_probes() - - if show_plot: - model.inspect(dataset) - model.compare(dataset) - - # This just comes from running a reconstruction when it was working well - # and choosing a rough value. If it triggers this assertion error, - # something changed to make the final quality worse! - assert model.loss_history[-1] < 0.0001 diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py new file mode 100644 index 00000000..0f46239c --- /dev/null +++ b/tests/test_reconstructors.py @@ -0,0 +1,311 @@ +import pytest +import cdtools +import torch as t +import numpy as np +import pickle +from matplotlib import pyplot as plt +from copy import deepcopy + + +@pytest.mark.slow +def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): + """ + This test checks out several things with the Au particle dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 2) We are only using the single-GPU dataloading method + 3) Ensure `recon.model` points to the original `model` + 4) Reconstructions performed by `Adam.optimize` and + `model.Adam_optimize` calls produce identical results. + 5) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the FancyPtycho model works fine and dandy with the + Reconstructors. + """ + + print('\nTesting performance on the standard gold balls dataset ' + + 'with reconstructors.Adam') + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) + pad = 10 + dataset.pad(pad) + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # Make a copy of the model + model_recon = deepcopy(model) + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with cdtools.reconstructors.Adam.optimize ******* + print('Running reconstruction using cdtools.reconstructors.Adam.optimize' + + ' on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.Adam(model=model_recon, dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + epoch_tup = (20, 50, 100) + lr_tup = (0.005, 0.002, 0.001) + batch_size_tup = (50, 100, 100) + + for i, iterations in enumerate(epoch_tup): + for loss in recon.optimize(iterations, + lr=lr_tup[i], + batch_size=batch_size_tup[i]): + print(model_recon.report()) + if show_plot and model_recon.epoch % 10 == 0: + model_recon.inspect(dataset) + + # Check hyperparameter update + assert recon.optimizer.param_groups[0]['lr'] == lr_tup[i] + assert recon.data_loader.batch_size == batch_size_tup[i] + + # Ensure that recon does not have sampler as an attribute (only used in + # multi-GPU) + assert not hasattr(recon, 'sampler') + + # Ensure recon.model points to the original model + assert id(model_recon) == id(recon.model) + + model_recon.tidy_probes() + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # ******* Reconstructions with cdtools.CDIModel.Adam_optimize ******* + print('Running reconstruction using CDIModel.Adam_optimize on provided' + + ' reconstruction_device,', reconstruction_device) + t.manual_seed(0) + + for i, iterations in enumerate(epoch_tup): + for loss in model.Adam_optimize(iterations, + dataset, + lr=lr_tup[i], + batch_size=batch_size_tup[i]): + print(model.report()) + if show_plot and model.epoch % 10 == 0: + model.inspect(dataset) + + model.tidy_probes() + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Ensure equivalency between the model reconstructions + assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) + + # Ensure reconstructions have reached a certain loss tolerance. This just + # comes from running a reconstruction when it was working well and + # choosing a rough value. If it triggers this assertion error, something + # changed to make the final quality worse! + assert model.loss_history[-1] < 0.0001 + + +@pytest.mark.slow +def test_LBFGS_RPI(optical_data_ss_cxi, + optical_ptycho_incoherent_pickle, + reconstruction_device, + show_plot): + """ + This test checks out several things with the transmission RPI dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 2) Ensure `recon.model` points to the original `model` + 3) Reconstructions performed by `LBFGS.optimize` and + `model.LBFGS_optimize` calls produce identical results. + 4) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the RPI model works fine and dandy with the + Reconstructors. + """ + with open(optical_ptycho_incoherent_pickle, 'rb') as f: + ptycho_results = pickle.load(f) + + probe = ptycho_results['probe'] + background = ptycho_results['background'] + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(optical_data_ss_cxi) + model = cdtools.models.RPI.from_dataset(dataset, probe, [500, 500], + background=background, n_modes=2, + initialization='random') + + # Prepare two sets of models for the comparative reconstruction + model_recon = deepcopy(model) + + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with cdtools.reconstructors.LBFGS.optimize ****** + print('Running reconstruction using cdtools.reconstructors.LBFGS.' + + 'optimize on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.LBFGS(model=model_recon, dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + reg_factor_tup = ([0.05, 0.05], [0.001, 0.1]) + epoch_tup = (30, 50) + for i, iterations in enumerate(epoch_tup): + for loss in recon.optimize(iterations, + lr=0.4, + regularization_factor=reg_factor_tup[i]): + if show_plot and i == 0: + model_recon.inspect(dataset) + print(model_recon.report()) + + # Check hyperparameter update (or lack thereof) + assert recon.optimizer.param_groups[0]['lr'] == 0.4 + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # Check model pointing + assert id(model_recon) == id(recon.model) + + # ******* Reconstructions with cdtools.reconstructors.LBFGS.optimize ****** + print('Running reconstruction using CDIModel.LBFGS_optimize.' + + 'optimize on provided reconstruction_device,', reconstruction_device) + t.manual_seed(0) + for i, iterations in enumerate(epoch_tup): + for loss in model.LBFGS_optimize(iterations, + dataset, + lr=0.4, + regularization_factor=reg_factor_tup[i]): # noqa + if show_plot and i == 0: + model.inspect(dataset) + print(model.report()) + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Check loss equivalency between the two reconstructions + assert np.allclose(model.loss_history[-1], model_recon.loss_history[-1]) + + # The final loss when testing this was 2.28607e-3. Based on this, we set + # a threshold of 2.3e-3 for the tested loss. If this value has been + # exceeded, the reconstructions have gotten worse. + assert model.loss_history[-1] < 0.0023 + + +@pytest.mark.slow +def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): + """ + This test checks out several things with the Au particle dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 3) Ensure `recon.model` points to the original `model` + 4) Reconstructions performed by `SGD.optimize` and + `model.SGD_optimize` calls produce identical results. + 5) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the FancyPtycho model works fine and dandy with the + Reconstructors. + + The hyperparameters used in this test are not optimized to produce + a super-high-quality reconstruction. Instead, I just need A reconstruction + to do some kind of comparative assessment. + """ + print('\nTesting performance on the standard gold balls dataset ' + + 'with reconstructors.SGD') + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) + pad = 10 + dataset.pad(pad) + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # Make a copy of the model + model_recon = deepcopy(model) + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with cdtools.reconstructors.SGD.optimize ******* + print('Running reconstruction using cdtools.reconstructors.SGD.optimize' + + ' on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.SGD(model=model_recon, dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + epochs = 50 + lr = 0.00000005 + batch_size = 40 + + for loss in recon.optimize(epochs, + lr=lr, + batch_size=batch_size): + print(model_recon.report()) + if show_plot and model_recon.epoch % 10 == 0: + model_recon.inspect(dataset) + + # Check hyperparameter update + assert recon.optimizer.param_groups[0]['lr'] == lr + assert recon.data_loader.batch_size == batch_size + + # Ensure that recon does not have sampler as an attribute (only used in + # multi-GPU) + assert not hasattr(recon, 'sampler') + + # Ensure recon.model points to the original model + assert id(model_recon) == id(recon.model) + + model_recon.tidy_probes() + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # ******* Reconstructions with cdtools.CDIModel.SGD_optimize ******* + print('Running reconstruction using CDIModel.SGD_optimize on provided' + + ' reconstruction_device,', reconstruction_device) + t.manual_seed(0) + + for loss in model.SGD_optimize(epochs, + dataset, + lr=lr, + batch_size=batch_size): + print(model.report()) + if show_plot and model.epoch % 10 == 0: + model.inspect(dataset) + + model.tidy_probes() + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Ensure equivalency between the model reconstructions + assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) + + # The final loss when testing this was 7.12188e-4. Based on this, we set + # a threshold of 7.2e-4 for the tested loss. If this value has been + # exceeded, the reconstructions have gotten worse. + assert model.loss_history[-1] < 0.00072 From 82ce845385170afbad9845e73be1d487e17b7373 Mon Sep 17 00:00:00 2001 From: yoshikisd Date: Mon, 4 Aug 2025 16:51:24 +0000 Subject: [PATCH 05/13] Changed the names of the reconstructors and updated the old optimizer class documentation --- src/cdtools/models/base.py | 167 +++++++++++++------------ src/cdtools/reconstructors/__init__.py | 20 +-- src/cdtools/reconstructors/adam.py | 14 +-- src/cdtools/reconstructors/base.py | 2 +- src/cdtools/reconstructors/lbfgs.py | 16 +-- src/cdtools/reconstructors/sgd.py | 18 +-- tests/test_reconstructors.py | 29 +++-- 7 files changed, 139 insertions(+), 127 deletions(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 41cdcd0e..635cdfbd 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -319,7 +319,6 @@ def checkpoint(self, *args): self.current_checkpoint_id += 1 - def Adam_optimize( self, iterations: int, @@ -336,7 +335,7 @@ def Adam_optimize( ): """ Runs a round of reconstruction using the Adam optimizer from - cdtools.reconstructors.Adam. + cdtools.reconstructors.AdamReconstructor. This is generally accepted to be the most robust algorithm for use with ptychography. Like all the other optimization routines, @@ -375,45 +374,45 @@ def Adam_optimize( only the calculation speed. """ - # We want to have model.Adam_optimize call AND store cdtools.reconstructors.Adam - # to perform reconstructions without creating a new reconstructor each time we - # update the hyperparameters. - # - # The only way to do this is to make the Adam reconstructor an attribute - # of the model. But since the Adam reconstructor also depends on CDIModel, - # this seems to give rise to a circular import error unless - # we import cdtools.reconstructors within this method: + # We want to have model.Adam_optimize call AND store + # cdtools.reconstructors.AdamReconstructor to perform reconstructions + # without creating a new reconstructor each time we update the + # hyperparameters. + # + # The only way to do this is to make the Adam reconstructor an + # attribute of the model. But since the Adam reconstructor also + # depends on CDIModel, this seems to give rise to a circular import + # error unless we import cdtools.reconstructors within this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import Adam - self.reconstructor = Adam(model=self, - dataset=dataset, - subset=subset) + from cdtools.reconstructors import AdamReconstructor + self.reconstructor = AdamReconstructor(model=self, + dataset=dataset, + subset=subset) # Run some reconstructions return self.reconstructor.optimize(iterations=iterations, - batch_size=batch_size, - lr=lr, - betas=betas, - schedule=schedule, - amsgrad=amsgrad, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) - - - def LBFGS_optimize(self, - iterations: int, + batch_size=batch_size, + lr=lr, + betas=betas, + schedule=schedule, + amsgrad=amsgrad, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width) + + def LBFGS_optimize(self, + iterations: int, dataset: CDataset, lr: float = 0.1, - history_size: int = 2, + history_size: int = 2, subset: Union[int, List[int]] = None, - regularization_factor: Union[float, List[float]] =None, + regularization_factor: Union[float, List[float]] =None, thread: bool = True, - calculation_width: int = 10, + calculation_width: int = 10, line_search_fn: str = None): """ Runs a round of reconstruction using the L-BFGS optimizer from - cdtools.reconstructors.LBFGS. + cdtools.reconstructors.LBFGSReconstructor. This algorithm is often less stable that Adam, however in certain situations or geometries it can be shockingly efficient. Like all @@ -436,53 +435,55 @@ def LBFGS_optimize(self, subset : list(int) or int Optional, a pattern index or list of pattern indices to use. regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters - to pass the regularizer method. + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow - interaction with plots during computation. + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. calculation_width : int - Default 10, how many translations to pass through at once for each round of - gradient accumulation. Does not affect the result, only the calculation speed + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. """ - # We want to have model.LBFGS_optimize store cdtools.reconstructors.LBFGS - # as an attribute to run reconstructions without generating new reconstructors - # each time CDIModel.LBFGS_optimize is called. - # - # Since the LBFGS reconstructor also depends on CDIModel, a circular import error - # arises unless we import cdtools.reconstructors within this method: + # We want to have model.LBFGS_optimize store + # cdtools.reconstructors.LBFGSReconstructor as an attribute to run + # reconstructions without generating new reconstructors each time + # CDIModel.LBFGS_optimize is called. + # + # Since the LBFGS reconstructor also depends on CDIModel, a circular + # import error arises unless we import cdtools.reconstructors within + # this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import LBFGS - self.reconstructor = LBFGS(model=self, - dataset=dataset, - subset=subset) + from cdtools.reconstructors import LBFGSReconstructor + self.reconstructor = LBFGSReconstructor(model=self, + dataset=dataset, + subset=subset) # Run some reconstructions return self.reconstructor.optimize(iterations=iterations, - lr=lr, - history_size=history_size, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width, - line_search_fn = line_search_fn) - + lr=lr, + history_size=history_size, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + line_search_fn=line_search_fn) def SGD_optimize(self, - iterations: int, - dataset: CDataset, + iterations: int, + dataset: CDataset, batch_size: int = None, - lr: float = 2e-7, - momentum: float = 0, - dampening: float = 0, + lr: float = 2e-7, + momentum: float = 0, + dampening: float = 0, weight_decay: float = 0, - nesterov: bool = False, - subset: Union[int, List[int]] = None, + nesterov: bool = False, + subset: Union[int, List[int]] = None, regularization_factor: Union[float, List[float]] = None, - thread: bool = True, + thread: bool = True, calculation_width: int = 10): """ Runs a round of reconstruction using the SGD optimizer from - cdtools.reconstructors.SGD. + cdtools.reconstructors.SGDReconstructor. This algorithm is often less stable that Adam, but it is simpler and is the basic workhorse of gradience descent. @@ -519,29 +520,31 @@ def SGD_optimize(self, round of gradient accumulation. """ - # We want to have model.SGD_optimize store cdtools.reconstructors.SGD - # as an attribute to run reconstructions without generating new reconstructors - # each time CDIModel.SGD_optimize is called. - # - # Since the SGD reconstructor also depends on CDIModel, a circular import error - # arises unless we import cdtools.reconstructors within this method: + # We want to have model.SGD_optimize store + # cdtools.reconstructors.SGDReconstructor as an attribute to run + # reconstructions without generating new reconstructors each time + # CDIModel.SGD_optimize is called. + # + # Since the SGD reconstructor also depends on CDIModel, a circular + # import error arises unless we import cdtools.reconstructors within + # this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import SGD - self.reconstructor = SGD(model=self, - dataset=dataset, - subset=subset) - + from cdtools.reconstructors import SGDReconstructor + self.reconstructor = SGDReconstructor(model=self, + dataset=dataset, + subset=subset) + # Run some reconstructions return self.reconstructor.optimize(iterations=iterations, - batch_size=batch_size, - lr=lr, - momentum=momentum, - dampening=dampening, - weight_decay=weight_decay, - nesterov=nesterov, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) + batch_size=batch_size, + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width) def report(self): diff --git a/src/cdtools/reconstructors/__init__.py b/src/cdtools/reconstructors/__init__.py index 84b96ab4..84a1a818 100644 --- a/src/cdtools/reconstructors/__init__.py +++ b/src/cdtools/reconstructors/__init__.py @@ -1,16 +1,22 @@ -"""This module contains optimizers for performing reconstructions +""" +Module `cdtools.tools.reconstructors` contains the `Reconstructor` class and +subclasses which run the ptychography reconstructions on a given model and +dataset. +The reconstructors are designed to resemble so-called 'Trainer' classes that +(in the language of the AI/ML folks) handles the 'training' of a model given +some dataset and optimizer. """ # We define __all__ to be sure that import * only imports what we want __all__ = [ 'Reconstructor', - 'Adam', - 'LBFGS', - 'SGD' + 'AdamReconstructor', + 'LBFGSReconstructor', + 'SGDReconstructor' ] from cdtools.reconstructors.base import Reconstructor -from cdtools.reconstructors.adam import Adam -from cdtools.reconstructors.lbfgs import LBFGS -from cdtools.reconstructors.sgd import SGD +from cdtools.reconstructors.adam import AdamReconstructor +from cdtools.reconstructors.lbfgs import LBFGSReconstructor +from cdtools.reconstructors.sgd import SGDReconstructor diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py index 5a489a06..6eecc9c8 100644 --- a/src/cdtools/reconstructors/adam.py +++ b/src/cdtools/reconstructors/adam.py @@ -1,4 +1,4 @@ -"""This module contains the Adam Reconstructor subclass for performing +"""This module contains the AdamReconstructor subclass for performing optimization ('reconstructions') on ptychographic/CDI models using the Adam optimizer. @@ -12,10 +12,10 @@ from typing import Tuple, List, Union from cdtools.reconstructors import Reconstructor -__all__ = ['Adam'] +__all__ = ['AdamReconstructor'] -class Adam(Reconstructor): +class AdamReconstructor(Reconstructor): """ The Adam Reconstructor subclass handles the optimization ('reconstruction') of ptychographic models and datasets using the Adam optimizer. @@ -154,7 +154,7 @@ def optimize(self, self.scheduler = None # 5) This is analagous to making a call to CDIModel.AD_optimize - return super(Adam, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + return super(AdamReconstructor, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index d5d34a61..fd2cbb22 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -178,7 +178,7 @@ def closure(): loss.backward() # Normalize the accumulating total loss - total_loss += loss.detach() + total_loss += loss.detach() # If we have a regularizer, we can calculate it separately, # and the gradients will add to the minibatch gradient diff --git a/src/cdtools/reconstructors/lbfgs.py b/src/cdtools/reconstructors/lbfgs.py index 0b51dfde..8bbd7ae8 100644 --- a/src/cdtools/reconstructors/lbfgs.py +++ b/src/cdtools/reconstructors/lbfgs.py @@ -1,4 +1,4 @@ -"""This module contains the LBFGS Reconstructor subclass for performing +"""This module contains the LBFGSReconstructor subclass for performing optimization ('reconstructions') on ptychographic/CDI models using the LBFGS optimizer. @@ -12,12 +12,12 @@ from typing import List, Union from cdtools.reconstructors import Reconstructor -__all__ = ['LBFGS'] +__all__ = ['LBFGSReconstructor'] -class LBFGS(Reconstructor): +class LBFGSReconstructor(Reconstructor): """ - The LBFGS Reconstructor subclass handles the optimization + The LBFGSReconstructor subclass handles the optimization ('reconstruction') of ptychographic models and datasets using the LBFGS optimizer. @@ -128,7 +128,7 @@ def optimize(self, line_search_fn=line_search_fn) # 4) This is analagous to making a call to CDIModel.AD_optimize - return super(LBFGS, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + return super(LBFGSReconstructor, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) diff --git a/src/cdtools/reconstructors/sgd.py b/src/cdtools/reconstructors/sgd.py index f2dd7b0c..a8f2ebc0 100644 --- a/src/cdtools/reconstructors/sgd.py +++ b/src/cdtools/reconstructors/sgd.py @@ -1,4 +1,4 @@ -"""This module contains the SGD Reconstructor subclass for performing +"""This module contains the SGDReconstructor subclass for performing optimization ('reconstructions') on ptychographic/CDI models using stochastic gradient descent. @@ -12,13 +12,13 @@ from typing import List, Union from cdtools.reconstructors import Reconstructor -__all__ = ['SGD'] +__all__ = ['SGDReconstructor'] -class SGD(Reconstructor): +class SGDReconstructor(Reconstructor): """ - The Adam Reconstructor subclass handles the optimization ('reconstruction') - of ptychographic models and datasets using the Adam optimizer. + The SGDReconstructor subclass handles the optimization ('reconstruction') + of ptychographic models and datasets using the SGD optimizer. Parameters ---------- @@ -151,7 +151,7 @@ def optimize(self, nesterov=nesterov) # 4) This is analagous to making a call to CDIModel.AD_optimize - return super(SGD, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + return super(SGDReconstructor, self).optimize(iterations, + regularization_factor, + thread, + calculation_width) diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 0f46239c..3995b6fd 100644 --- a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -24,7 +24,7 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): """ print('\nTesting performance on the standard gold balls dataset ' + - 'with reconstructors.Adam') + 'with reconstructors.AdamReconstructor') dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) pad = 10 @@ -48,11 +48,12 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): model_recon.to(device=reconstruction_device) dataset.get_as(device=reconstruction_device) - # ******* Reconstructions with cdtools.reconstructors.Adam.optimize ******* - print('Running reconstruction using cdtools.reconstructors.Adam.optimize' + + # ******* Reconstructions with AdamReconstructor.optimize ******* + print('Running reconstruction using AdamReconstructor.optimize' + ' on provided reconstruction_device,', reconstruction_device) - recon = cdtools.reconstructors.Adam(model=model_recon, dataset=dataset) + recon = cdtools.reconstructors.AdamReconstructor(model=model_recon, + dataset=dataset) t.manual_seed(0) # Run a reconstruction @@ -85,7 +86,7 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): model_recon.inspect(dataset) model_recon.compare(dataset) - # ******* Reconstructions with cdtools.CDIModel.Adam_optimize ******* + # ******* Reconstructions with CDIModel.Adam_optimize ******* print('Running reconstruction using CDIModel.Adam_optimize on provided' + ' reconstruction_device,', reconstruction_device) t.manual_seed(0) @@ -150,11 +151,12 @@ def test_LBFGS_RPI(optical_data_ss_cxi, model_recon.to(device=reconstruction_device) dataset.get_as(device=reconstruction_device) - # ******* Reconstructions with cdtools.reconstructors.LBFGS.optimize ****** - print('Running reconstruction using cdtools.reconstructors.LBFGS.' + + # ******* Reconstructions with LBFGSReconstructor.optimize ****** + print('Running reconstruction using LBFGSReconstructor.' + 'optimize on provided reconstruction_device,', reconstruction_device) - recon = cdtools.reconstructors.LBFGS(model=model_recon, dataset=dataset) + recon = cdtools.reconstructors.LBFGSReconstructor(model=model_recon, + dataset=dataset) t.manual_seed(0) # Run a reconstruction @@ -178,7 +180,7 @@ def test_LBFGS_RPI(optical_data_ss_cxi, # Check model pointing assert id(model_recon) == id(recon.model) - # ******* Reconstructions with cdtools.reconstructors.LBFGS.optimize ****** + # ******* Reconstructions with CDIModel.LBFGS_optimize ****** print('Running reconstruction using CDIModel.LBFGS_optimize.' + 'optimize on provided reconstruction_device,', reconstruction_device) t.manual_seed(0) @@ -223,7 +225,7 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): to do some kind of comparative assessment. """ print('\nTesting performance on the standard gold balls dataset ' + - 'with reconstructors.SGD') + 'with reconstructors.SGDReconstructor') dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) pad = 10 @@ -247,11 +249,12 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): model_recon.to(device=reconstruction_device) dataset.get_as(device=reconstruction_device) - # ******* Reconstructions with cdtools.reconstructors.SGD.optimize ******* - print('Running reconstruction using cdtools.reconstructors.SGD.optimize' + + # ******* Reconstructions with SGDReconstructor.optimize ******* + print('Running reconstruction using SGDReconstructor.optimize' + ' on provided reconstruction_device,', reconstruction_device) - recon = cdtools.reconstructors.SGD(model=model_recon, dataset=dataset) + recon = cdtools.reconstructors.SGDReconstructor(model=model_recon, + dataset=dataset) t.manual_seed(0) # Run a reconstruction From 766a7d1adf038f893e35dcf82d497d357521017b Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Mon, 13 Oct 2025 14:20:38 +0200 Subject: [PATCH 06/13] Move to better supported strategy for avoiding circular imports --- src/cdtools/models/base.py | 17 +---------------- src/cdtools/reconstructors/adam.py | 9 +++++++-- src/cdtools/reconstructors/base.py | 9 +++++++-- src/cdtools/reconstructors/lbfgs.py | 10 ++++++++-- src/cdtools/reconstructors/sgd.py | 10 ++++++++-- 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 635cdfbd..5d3ac9af 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -40,6 +40,7 @@ from scipy import io from contextlib import contextmanager from cdtools.tools.data import nested_dict_to_h5, h5_to_nested_dict, nested_dict_to_numpy, nested_dict_to_torch +from cdtools.reconstructors import AdamReconstructor, LBFGSReconstructor, SGDReconstructor from cdtools.datasets import CDataset from typing import List, Union, Tuple import os @@ -378,13 +379,7 @@ def Adam_optimize( # cdtools.reconstructors.AdamReconstructor to perform reconstructions # without creating a new reconstructor each time we update the # hyperparameters. - # - # The only way to do this is to make the Adam reconstructor an - # attribute of the model. But since the Adam reconstructor also - # depends on CDIModel, this seems to give rise to a circular import - # error unless we import cdtools.reconstructors within this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import AdamReconstructor self.reconstructor = AdamReconstructor(model=self, dataset=dataset, subset=subset) @@ -449,12 +444,7 @@ def LBFGS_optimize(self, # cdtools.reconstructors.LBFGSReconstructor as an attribute to run # reconstructions without generating new reconstructors each time # CDIModel.LBFGS_optimize is called. - # - # Since the LBFGS reconstructor also depends on CDIModel, a circular - # import error arises unless we import cdtools.reconstructors within - # this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import LBFGSReconstructor self.reconstructor = LBFGSReconstructor(model=self, dataset=dataset, subset=subset) @@ -524,12 +514,7 @@ def SGD_optimize(self, # cdtools.reconstructors.SGDReconstructor as an attribute to run # reconstructions without generating new reconstructors each time # CDIModel.SGD_optimize is called. - # - # Since the SGD reconstructor also depends on CDIModel, a circular - # import error arises unless we import cdtools.reconstructors within - # this method: if not hasattr(self, 'reconstructor'): - from cdtools.reconstructors import SGDReconstructor self.reconstructor = SGDReconstructor(model=self, dataset=dataset, subset=subset) diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py index 6eecc9c8..4c54f1f5 100644 --- a/src/cdtools/reconstructors/adam.py +++ b/src/cdtools/reconstructors/adam.py @@ -6,12 +6,17 @@ 'Trainer' classes that (in the language of the AI/ML folks) handles the 'training' of a model given some dataset and optimizer. """ +from __future__ import annotations +from typing import TYPE_CHECKING + import torch as t -from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset -from cdtools.models import CDIModel from typing import Tuple, List, Union from cdtools.reconstructors import Reconstructor +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + __all__ = ['AdamReconstructor'] diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index fd2cbb22..1937661f 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -8,16 +8,21 @@ The subclasses of Reconstructor are required to implement their own data loaders and optimizer adjusters """ +from __future__ import annotations +from typing import TYPE_CHECKING import torch as t from torch.utils import data as td import threading import queue import time -from cdtools.datasets import CDataset -from cdtools.models import CDIModel from typing import List, Union +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets import CDataset + + __all__ = ['Reconstructor'] diff --git a/src/cdtools/reconstructors/lbfgs.py b/src/cdtools/reconstructors/lbfgs.py index 8bbd7ae8..60c6f979 100644 --- a/src/cdtools/reconstructors/lbfgs.py +++ b/src/cdtools/reconstructors/lbfgs.py @@ -6,12 +6,18 @@ 'Trainer' classes that (in the language of the AI/ML folks) handles the 'training' of a model given some dataset and optimizer. """ +from __future__ import annotations +from typing import TYPE_CHECKING + import torch as t -from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset -from cdtools.models import CDIModel from typing import List, Union from cdtools.reconstructors import Reconstructor +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + + __all__ = ['LBFGSReconstructor'] diff --git a/src/cdtools/reconstructors/sgd.py b/src/cdtools/reconstructors/sgd.py index a8f2ebc0..5b7bf4e2 100644 --- a/src/cdtools/reconstructors/sgd.py +++ b/src/cdtools/reconstructors/sgd.py @@ -6,12 +6,18 @@ 'Trainer' classes that (in the language of the AI/ML folks) handles the 'training' of a model given some dataset and optimizer. """ +from __future__ import annotations +from typing import TYPE_CHECKING + import torch as t -from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset -from cdtools.models import CDIModel from typing import List, Union from cdtools.reconstructors import Reconstructor +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + + __all__ = ['SGDReconstructor'] From a089dc3b7266ff6120f50e17638a8f8ba6e57ce2 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Mon, 13 Oct 2025 14:26:07 +0200 Subject: [PATCH 07/13] Switch to a simpler pattern for the CDIModel._optimize functions that preserves the old behavior when the old pattern is used --- src/cdtools/models/base.py | 99 ++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 51 deletions(-) diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index 5d3ac9af..df347b63 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -375,25 +375,24 @@ def Adam_optimize( only the calculation speed. """ - # We want to have model.Adam_optimize call AND store - # cdtools.reconstructors.AdamReconstructor to perform reconstructions - # without creating a new reconstructor each time we update the - # hyperparameters. - if not hasattr(self, 'reconstructor'): - self.reconstructor = AdamReconstructor(model=self, - dataset=dataset, - subset=subset) + reconstructor = AdamReconstructor( + model=self, + dataset=dataset, + subset=subset, + ) # Run some reconstructions - return self.reconstructor.optimize(iterations=iterations, - batch_size=batch_size, - lr=lr, - betas=betas, - schedule=schedule, - amsgrad=amsgrad, - regularization_factor=regularization_factor, # noqa - thread=thread, - calculation_width=calculation_width) + return reconstructor.optimize( + iterations=iterations, + batch_size=batch_size, + lr=lr, + betas=betas, + schedule=schedule, + amsgrad=amsgrad, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + ) def LBFGS_optimize(self, iterations: int, @@ -440,24 +439,23 @@ def LBFGS_optimize(self, round of gradient accumulation. Does not affect the result, only the calculation speed. """ - # We want to have model.LBFGS_optimize store - # cdtools.reconstructors.LBFGSReconstructor as an attribute to run - # reconstructions without generating new reconstructors each time - # CDIModel.LBFGS_optimize is called. - if not hasattr(self, 'reconstructor'): - self.reconstructor = LBFGSReconstructor(model=self, - dataset=dataset, - subset=subset) + reconstructor = LBFGSReconstructor( + model=self, + dataset=dataset, + subset=subset, + ) # Run some reconstructions - return self.reconstructor.optimize(iterations=iterations, - lr=lr, - history_size=history_size, - regularization_factor=regularization_factor, # noqa - thread=thread, - calculation_width=calculation_width, - line_search_fn=line_search_fn) - + return reconstructor.optimize( + iterations=iterations, + lr=lr, + history_size=history_size, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + line_search_fn=line_search_fn, + ) + def SGD_optimize(self, iterations: int, dataset: CDataset, @@ -510,26 +508,25 @@ def SGD_optimize(self, round of gradient accumulation. """ - # We want to have model.SGD_optimize store - # cdtools.reconstructors.SGDReconstructor as an attribute to run - # reconstructions without generating new reconstructors each time - # CDIModel.SGD_optimize is called. - if not hasattr(self, 'reconstructor'): - self.reconstructor = SGDReconstructor(model=self, - dataset=dataset, - subset=subset) + reconstructor = SGDReconstructor( + model=self, + dataset=dataset, + subset=subset, + ) # Run some reconstructions - return self.reconstructor.optimize(iterations=iterations, - batch_size=batch_size, - lr=lr, - momentum=momentum, - dampening=dampening, - weight_decay=weight_decay, - nesterov=nesterov, - regularization_factor=regularization_factor, # noqa - thread=thread, - calculation_width=calculation_width) + return reconstructor.optimize( + iterations=iterations, + batch_size=batch_size, + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + ) def report(self): From aa684f27eb447ae7efd56bed0652943c682cca6c Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Mon, 13 Oct 2025 19:20:42 +0200 Subject: [PATCH 08/13] Change the pattern for Reconstructor so that the optimizer is defined at object creation, and move the dataloader creation logic to the base optimize() function as it was reused in all subclasses --- src/cdtools/reconstructors/adam.py | 45 ++++++++++------ src/cdtools/reconstructors/base.py | 83 +++++++++++++++++++++++------ src/cdtools/reconstructors/lbfgs.py | 33 ++++++------ src/cdtools/reconstructors/sgd.py | 39 +++++++------- 4 files changed, 133 insertions(+), 67 deletions(-) diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py index 4c54f1f5..ac11ee53 100644 --- a/src/cdtools/reconstructors/adam.py +++ b/src/cdtools/reconstructors/adam.py @@ -51,10 +51,12 @@ def __init__(self, dataset: Ptycho2DDataset, subset: List[int] = None): - super().__init__(model, dataset, subset) - # Define the optimizer for use in this subclass - self.optimizer = t.optim.Adam(self.model.parameters()) + optimizer = t.optim.Adam(model.parameters()) + + super().__init__(model, dataset, optimizer, subset=subset) + + def adjust_optimizer(self, lr: int = 0.005, @@ -79,11 +81,13 @@ def adjust_optimizer(self, param_group['betas'] = betas param_group['amsgrad'] = amsgrad + def optimize(self, iterations: int, batch_size: int = 15, lr: float = 0.005, betas: Tuple[float] = (0.9, 0.999), + custom_data_loader = None, schedule: bool = False, amsgrad: bool = False, regularization_factor: Union[float, List[float]] = None, @@ -99,6 +103,12 @@ def optimize(self, (formerly `CDIModel.AD_optimize`) to run a round of reconstruction once the dataloader and optimizer hyperparameters have been set up. + + The `batch_size` parameter sets the batch size for the default + dataloader. If a custom data loader is desired, it can be passed + in to the `custom_data_loader` argument, which will override the + `batch_size` and `shuffle` parameters + Parameters ---------- @@ -115,6 +125,9 @@ def optimize(self, schedule : bool Optional, create a learning rate scheduler (torch.optim.lr_scheduler._LRScheduler). + custom_data_loader : t.utils.data.DataLoader + Optional, a custom DataLoader to use. If set, will override + batch_size. amsgrad : bool Optional, whether to use the AMSGrad variant of this algorithm. regularization_factor : float or list(float) @@ -138,18 +151,13 @@ def optimize(self, f'{regularization_factor}, and schedule = {schedule}.\n' ) - # 1) The subset statement is contained in Reconstructor.__init__ - - # 2) Set up / re-initialize the data laoder - self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) - - # 3) The optimizer is created in self.__init__, but the - # hyperparameters need to be set up with self.adjust_optimizer + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer self.adjust_optimizer(lr=lr, betas=betas, amsgrad=amsgrad) - # 4) Set up the scheduler + # Set up the scheduler if schedule: self.scheduler = \ t.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, @@ -158,8 +166,13 @@ def optimize(self, else: self.scheduler = None - # 5) This is analagous to making a call to CDIModel.AD_optimize - return super(AdamReconstructor, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + # Now, we run the optimize routine defined in the base class + return super(AdamReconstructor, self).optimize( + iterations, + batch_size=batch_size, + custom_data_loader=custom_data_loader, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width, + shuffle=shuffle, + ) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 1937661f..bae992ba 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -40,6 +40,8 @@ class Reconstructor: Model for CDI/ptychography reconstruction dataset: CDataset The dataset to reconstruct against + optimizer: torch.optim.Optimizer + The optimizer to use for the reconstruction subset : list(int) or int Optional, a pattern index or list of pattern indices to use @@ -55,31 +57,32 @@ class Reconstructor: def __init__(self, model: CDIModel, dataset: CDataset, + optimizer: t.optim.Optimizer, subset: Union[int, List[int]] = None): + # Store parameters as attributes of Reconstructor - self.subset = subset - - # Initialize attributes that must be defined by the subclasses - self.optimizer = None - self.scheduler = None - self.data_loader = None - - # Store the original model self.model = model + self.optimizer = optimizer - # Store the dataset + # Store the dataset, clipping it to a subset if needed if subset is not None: # if subset is just one pattern, turn into a list for convenience if isinstance(subset, int): subset = [subset] dataset = td.Subset(dataset, subset) + self.dataset = dataset + # Initialize attributes that must be defined by the subclasses + self.scheduler = None + self.data_loader = None + + def setup_dataloader(self, batch_size: int = None, shuffle: bool = True): """ - Sets up / re-initializes the dataloader. + Sets up or re-initializes the dataloader. Parameters ---------- @@ -96,6 +99,7 @@ def setup_dataloader(self, else: self.data_loader = td.Dataloader(self.dataset) + def adjust_optimizer(self, **kwargs): """ Change hyperparameters for the utilized optimizer. @@ -105,7 +109,8 @@ def adjust_optimizer(self, **kwargs): """ raise NotImplementedError() - def _run_epoch(self, + + def run_epoch(self, stop_event: threading.Event = None, regularization_factor: Union[float, List[float]] = None, calculation_width: int = 10): @@ -133,6 +138,18 @@ def _run_epoch(self, diffraction pattern intensity """ + # Setting this as an explicit catch makes me feel more comfortable + # exposing it as a public method. This way a user won't be confused + # if they try to use this directly + if self.data_loader is None: + raise RuntimeError( + 'No data loader was defined. Please run ' + 'Reconstructor.setup_dataloader() before running ' + 'Reconstructor.run_epoch(), or use Reconstructor.optimize(), ' + 'which does it automatically.' + ) + + # Initialize some tracking variables normalization = 0 loss = 0 @@ -212,9 +229,12 @@ def closure(): def optimize(self, iterations: int, + batch_size: int = 1, + custom_data_loader = None, regularization_factor: Union[float, List[float]] = None, thread: bool = True, - calculation_width: int = 10): + calculation_width: int = 10, + shuffle=True): """ Runs a round of reconstruction using the provided optimizer @@ -233,10 +253,25 @@ def optimize(self, the plots. This behavior can be turned off by setting the keyword argument 'thread' to False. + The `batch_size` parameter sets the batch size for the default + dataloader. If a custom data loader is desired, it can be passed + in to the `custom_data_loader` argument, which will override the + `batch_size` and `shuffle` parameters + + Please see `AdamReconstructor.optimize()` for an example of how to + override this function when designing a subclass + Parameters ---------- iterations : int How many epochs of the algorithm to run. + batch_size : int + Optional, the batch size to use. Default is 1. This is typically + overridden by subclasses with an appropriate default for the + specific optimizer. + custom_data_loader : torch.utils.data.DataLoader + Optional, a custom DataLoader to use. Will override batch_size + if set. regularization_factor : float or list(float) Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method. @@ -247,6 +282,10 @@ def optimize(self, Default 10, how many translations to pass through at once for each round of gradient accumulation. This does not affect the result, but may affect the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + Yields ------ @@ -255,6 +294,11 @@ def optimize(self, diffraction pattern intensity. """ + if custom_data_loader is None: + self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) + else: + self.data_loader = custom_data_loader + # We store the current optimizer as a model parameter so that # it can be saved and loaded for checkpointing self.current_optimizer = self.optimizer @@ -270,8 +314,10 @@ def optimize(self, yield float('nan') continue - yield self._run_epoch(regularization_factor=regularization_factor, # noqa - calculation_width=calculation_width) + yield self.run_epoch( + regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width, + ) # But if we do want to thread, it's annoying: else: @@ -282,9 +328,12 @@ def optimize(self, def target(): try: result_queue.put( - self._run_epoch(stop_event=stop_event, - regularization_factor=regularization_factor, # noqa - calculation_width=calculation_width)) + self.run_epoch( + stop_event=stop_event, + regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width, + ) + ) except Exception as e: # If something bad happens, put the exception into the # result queue diff --git a/src/cdtools/reconstructors/lbfgs.py b/src/cdtools/reconstructors/lbfgs.py index 60c6f979..8907782a 100644 --- a/src/cdtools/reconstructors/lbfgs.py +++ b/src/cdtools/reconstructors/lbfgs.py @@ -53,10 +53,16 @@ def __init__(self, dataset: Ptycho2DDataset, subset: List[int] = None): - super().__init__(model, dataset, subset) - # Define the optimizer for use in this subclass - self.optimizer = t.optim.LBFGS(self.model.parameters()) + optimizer = t.optim.LBFGS(model.parameters()) + + super().__init__( + model, + dataset, + optimizer, + subset=subset, + ) + def adjust_optimizer(self, lr: int = 0.005, @@ -121,20 +127,17 @@ def optimize(self, round of gradient accumulation. Does not affect the result, only the calculation speed. """ - # 1) The subset statement is contained in Reconstructor.__init__ - - # 2) Set up / re-initialize the data loader. For LBFGS, we load - # all the data at once. - self.setup_dataloader(batch_size=len(self.dataset)) - # 3) The optimizer is created in self.__init__, but the - # hyperparameters need to be set up with self.adjust_optimizer + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer self.adjust_optimizer(lr=lr, history_size=history_size, line_search_fn=line_search_fn) - # 4) This is analagous to making a call to CDIModel.AD_optimize - return super(LBFGSReconstructor, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + # Now, we run the optimize routine defined in the base class + return super(LBFGSReconstructor, self).optimize( + iterations, + batch_size=len(self.dataset), + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width) diff --git a/src/cdtools/reconstructors/sgd.py b/src/cdtools/reconstructors/sgd.py index 5b7bf4e2..cb1e26b1 100644 --- a/src/cdtools/reconstructors/sgd.py +++ b/src/cdtools/reconstructors/sgd.py @@ -49,11 +49,17 @@ def __init__(self, dataset: Ptycho2DDataset, subset: List[int] = None): - super().__init__(model, dataset, subset) - # Define the optimizer for use in this subclass - self.optimizer = t.optim.SGD(self.model.parameters()) + optimizer = t.optim.SGD(model.parameters()) + + super().__init__( + model, + dataset, + optimizer, + subset=subset, + ) + def adjust_optimizer(self, lr: int = 0.005, momentum: float = 0, @@ -88,7 +94,7 @@ def adjust_optimizer(self, def optimize(self, iterations: int, - batch_size: int = None, + batch_size: int = 15, lr: float = 2e-7, momentum: float = 0, dampening: float = 0, @@ -139,25 +145,20 @@ def optimize(self, Optional, enable/disable shuffling of the dataset. This option is intended for diagnostic purposes and should be left as True. """ - # 1) The subset statement is contained in Reconstructor.__init__ - - # 2) Set up / re-initialize the data laoder - if batch_size is not None: - self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) - else: - # Use default torch dataloader parameters - self.setup_dataloader(batch_size=1, shuffle=False) - # 3) The optimizer is created in self.__init__, but the - # hyperparameters need to be set up with self.adjust_optimizer + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer self.adjust_optimizer(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) - # 4) This is analagous to making a call to CDIModel.AD_optimize - return super(SGDReconstructor, self).optimize(iterations, - regularization_factor, - thread, - calculation_width) + # Now, we run the optimize routine defined in the base class + return super(SGDReconstructor, self).optimize( + iterations, + batch_size=batch_size, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width, + ) From 3c8de5dc192072fdc5e0753d00183e3c31cca41f Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Thu, 16 Oct 2025 14:16:12 +0200 Subject: [PATCH 09/13] Update the example codes --- examples/fancy_ptycho.py | 20 ++++++++++++++++---- examples/gold_ball_ptycho.py | 11 +++++++---- examples/gold_ball_split.py | 9 ++++++--- examples/simple_ptycho.py | 11 ++++++++++- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 1cf58550..067f78de 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -19,19 +19,31 @@ model.to(device=device) dataset.get_as(device=device) +# Now, to do the reconstruction we will use the more flexible pattern of +# creating an explicit reconstructor. This is what is used behind the +# scenes by the convenience functions model._optimize. For a long +# reconstruction script with multiple steps, it is better to create the +# reconstructor explicitly. +# +# The reconstructor will store the model and dataset and create an appropriate +# optimizer. This allows the optimizer to persist, along with e.g. estimates +# of the moments of individual parameters between loops +recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # The learning rate parameter sets the alpha for Adam. # The beta parameters are (0.9, 0.999) by default # The batch size sets the minibatch size -for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): +for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) # Plotting is expensive, so we only do it every tenth epoch if model.epoch % 10 == 0: model.inspect(dataset) # It's common to chain several different reconstruction loops. Here, we -# started with an aggressive refinement to find the probe, and now we -# polish the reconstruction with a lower learning rate and larger minibatch -for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): +# started with an aggressive refinement to find the probe in the previous +# loop, and now we polish the reconstruction with a lower learning rate +# and larger minibatch +for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index 47476664..49719751 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -29,6 +29,7 @@ probe_fourier_crop=pad ) + # This is a trick that my grandmother taught me, to combat the raster grid # pathology: we randomze the our initial guess of the probe positions. # The units here are pixels in the object array. @@ -42,17 +43,20 @@ model.to(device=device) dataset.get_as(device=device) +# Create the reconstructor +recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # This will save out the intermediate results if an exception is thrown # during the reconstruction with model.save_on_exception( 'example_reconstructions/gold_balls_earlyexit.h5', dataset): - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): + for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): + for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) @@ -64,8 +68,7 @@ # Setting schedule=True automatically lowers the learning rate if # the loss fails to improve after 10 epochs - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, - schedule=True): + for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) diff --git a/examples/gold_ball_split.py b/examples/gold_ball_split.py index 8624b9f4..9fc5b083 100644 --- a/examples/gold_ball_split.py +++ b/examples/gold_ball_split.py @@ -36,15 +36,18 @@ model.to(device=device) dataset.get_as(device=device) + # Create the reconstructor + recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # For batched reconstructions like this, there's no need to live-plot # the progress - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): + for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): + for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, + for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 726af940..217d96b0 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -1,3 +1,12 @@ +""" +Runs a very simple reconstruction using the SimplePtycho model, which was +designed to be an easy introduction to show how the models are made and used. + +For a more realistic example of how to use cdtools for real-world data, +look at fancy_ptycho.py and gold_ball_ptycho.py, both of which use the +more powerful FancyPtycho model and include more information on how to +correct for common sources of error. +""" import cdtools from matplotlib import pyplot as plt @@ -13,7 +22,7 @@ model.to(device=device) dataset.get_as(device=device) -# We run the actual reconstruction +# We run the reconstruction for loss in model.Adam_optimize(100, dataset, batch_size=10): # We print a quick report of the optimization status print(model.report()) From 8328b852d393a164aa89abb2660ae242fdfd21a7 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Thu, 16 Oct 2025 14:43:04 +0200 Subject: [PATCH 10/13] Update the docs to include the reconstructors class and discuss their use in the examples --- docs/source/examples.rst | 10 ++++++++-- docs/source/index.rst | 1 + docs/source/reconstructors.rst | 5 +++++ examples/fancy_ptycho.py | 14 +++++--------- src/cdtools/reconstructors/base.py | 18 ++++++++++-------- 5 files changed, 29 insertions(+), 19 deletions(-) create mode 100644 docs/source/reconstructors.rst diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 517f1b6c..a3ef41eb 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small. -Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at every epoch, to allow some monitoring code to be run. +Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. -Finally, the results can be studied using :code:`model.inspect(dataet)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. +Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. Fancy Ptycho @@ -63,6 +63,12 @@ By default, FancyPtycho will also optimize over the following model parameters, These corrections can be turned off (on) by calling :code:`model..requires_grad = False #(True)`. +Note as well two other changes that are made in this script, when compared to `simple_ptycho.py`. First, a `Reconstructor` object is explicitly created, in this case an `AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to `Reconstructor.optimize()`. + +We use this pattern, instead of the simpler call to `model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. + +In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. `model.LBFGS_optimize()`). + Gold Ball Ptycho ---------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index c6fcb8dc..b8cd288d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,6 +9,7 @@ general datasets models + reconstructors tools/index indices_tables diff --git a/docs/source/reconstructors.rst b/docs/source/reconstructors.rst new file mode 100644 index 00000000..2e78a747 --- /dev/null +++ b/docs/source/reconstructors.rst @@ -0,0 +1,5 @@ +Reconstructors +============== + +.. automodule:: cdtools.reconstructors + :members: diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 067f78de..ebee68f0 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -19,15 +19,11 @@ model.to(device=device) dataset.get_as(device=device) -# Now, to do the reconstruction we will use the more flexible pattern of -# creating an explicit reconstructor. This is what is used behind the -# scenes by the convenience functions model._optimize. For a long -# reconstruction script with multiple steps, it is better to create the -# reconstructor explicitly. -# -# The reconstructor will store the model and dataset and create an appropriate -# optimizer. This allows the optimizer to persist, along with e.g. estimates -# of the moments of individual parameters between loops +# For this script, we use a slightly different pattern where we explicitly +# create a `Reconstructor` class to orchestrate the reconstruction. The +# reconstructor will store the model and dataset and create an appropriate +# optimizer. This allows the optimizer to persist between loops, along with +# e.g. estimates of the moments of individual parameters recon = cdtools.reconstructors.AdamReconstructor(model, dataset) # The learning rate parameter sets the alpha for Adam. diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index bae992ba..5177b796 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -45,14 +45,16 @@ class Reconstructor: subset : list(int) or int Optional, a pattern index or list of pattern indices to use - Important attributes: - - **model** -- Always points to the core model used. - - **optimizer** -- A `torch.optim.Optimizer` that must be defined when - initializing the Reconstructor subclass. - - **scheduler** -- A `torch.optim.lr_scheduler` that may be defined during - the `optimize` method. - - **data_loader** -- A torch.utils.data.DataLoader that is defined by - calling the `setup_dataloader` method. + Attributes + ---------- + model : CDIModel + Points to the core model used. + optimizer : torch.optim.Optimizer + Must be defined when initializing the Reconstructor subclass. + scheduler : torch.optim.lr_scheduler, optional + May be defined during the ``optimize`` method. + data_loader : torch.utils.data.DataLoader + Defined by calling the ``setup_dataloader`` method. """ def __init__(self, model: CDIModel, From f022b5e2c4ba90f9876d842b5b7407aece1cf2f5 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Thu, 16 Oct 2025 15:58:28 +0200 Subject: [PATCH 11/13] Get the tests passing again after the change of the optimization functions in the model classes --- tests/models/test_fancy_ptycho.py | 9 +++++++-- tests/test_reconstructors.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index 4182c93c..8bc4d87e 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -56,8 +56,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): dataset, n_modes=3, oversampling=2, - exponentiate_obj=True, dm_rank=2, + exponentiate_obj=True, probe_support_radius=120, propagation_distance=5e-3, units='mm', @@ -70,7 +70,7 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): model.to(device=reconstruction_device) dataset.get_as(device=reconstruction_device) - for loss in model.Adam_optimize(70, dataset, lr=0.02, batch_size=10): + for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): print(model.report()) if show_plot and model.epoch % 10 == 0: model.inspect(dataset) @@ -79,6 +79,11 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): print(model.report()) if show_plot and model.epoch % 10 == 0: model.inspect(dataset) + + for loss in model.Adam_optimize(25, dataset, lr=0.001, batch_size=50): + print(model.report()) + if show_plot and model.epoch % 10 == 0: + model.inspect(dataset) model.tidy_probes() diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 3995b6fd..2e308285 100644 --- a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -91,7 +91,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): ' reconstruction_device,', reconstruction_device) t.manual_seed(0) - for i, iterations in enumerate(epoch_tup): + # We only need to test the first loop to ensure it's identical + for i, iterations in enumerate(epoch_tup[:1]): for loss in model.Adam_optimize(iterations, dataset, lr=lr_tup[i], @@ -106,14 +107,15 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): model.inspect(dataset) model.compare(dataset) - # Ensure equivalency between the model reconstructions - assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) + # Ensure equivalency between the model reconstructions during the first + # pass, where they should be identical + assert np.allclose(model_recon.loss_history[:epoch_tup[0]], model.loss_history[:epoch_tup[0]]) # Ensure reconstructions have reached a certain loss tolerance. This just # comes from running a reconstruction when it was working well and # choosing a rough value. If it triggers this assertion error, something # changed to make the final quality worse! - assert model.loss_history[-1] < 0.0001 + assert model_recon.loss_history[-1] < 0.0001 @pytest.mark.slow @@ -184,7 +186,7 @@ def test_LBFGS_RPI(optical_data_ss_cxi, print('Running reconstruction using CDIModel.LBFGS_optimize.' + 'optimize on provided reconstruction_device,', reconstruction_device) t.manual_seed(0) - for i, iterations in enumerate(epoch_tup): + for i, iterations in enumerate(epoch_tup[:1]): for loss in model.LBFGS_optimize(iterations, dataset, lr=0.4, @@ -198,12 +200,12 @@ def test_LBFGS_RPI(optical_data_ss_cxi, model.compare(dataset) # Check loss equivalency between the two reconstructions - assert np.allclose(model.loss_history[-1], model_recon.loss_history[-1]) + assert np.allclose(model.loss_history[:epoch_tup[0]], model_recon.loss_history[:epoch_tup[0]]) # The final loss when testing this was 2.28607e-3. Based on this, we set # a threshold of 2.3e-3 for the tested loss. If this value has been # exceeded, the reconstructions have gotten worse. - assert model.loss_history[-1] < 0.0023 + assert model_recon.loss_history[-1] < 0.0023 @pytest.mark.slow From 9d633074d8fa509087a48f26668ec3aa7f70d60c Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 17 Oct 2025 18:37:38 -0300 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: Dayne Yoshiki Sasaki <37006268+yoshikisd@users.noreply.github.com> --- src/cdtools/reconstructors/adam.py | 2 +- src/cdtools/reconstructors/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py index ac11ee53..e298bdb1 100644 --- a/src/cdtools/reconstructors/adam.py +++ b/src/cdtools/reconstructors/adam.py @@ -87,7 +87,7 @@ def optimize(self, batch_size: int = 15, lr: float = 0.005, betas: Tuple[float] = (0.9, 0.999), - custom_data_loader = None, + custom_data_loader: t.utils.data.DataLoader = None, schedule: bool = False, amsgrad: bool = False, regularization_factor: Union[float, List[float]] = None, diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index 5177b796..cf3de8c8 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -232,7 +232,7 @@ def closure(): def optimize(self, iterations: int, batch_size: int = 1, - custom_data_loader = None, + custom_data_loader: torch.utils.data.DataLoader = None, regularization_factor: Union[float, List[float]] = None, thread: bool = True, calculation_width: int = 10, From 7d390157199b90fefddea691a6d5eaba737ba9c9 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 17 Oct 2025 23:41:08 +0200 Subject: [PATCH 13/13] response to dayne's review --- tests/test_reconstructors.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py index 2e308285..132a8c27 100644 --- a/tests/test_reconstructors.py +++ b/tests/test_reconstructors.py @@ -16,7 +16,8 @@ def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): 2) We are only using the single-GPU dataloading method 3) Ensure `recon.model` points to the original `model` 4) Reconstructions performed by `Adam.optimize` and - `model.Adam_optimize` calls produce identical results. + `model.Adam_optimize` calls produce identical results when + run over one round of optimization. 5) The quality of the reconstruction remains below a specified threshold. 5) Ensure that the FancyPtycho model works fine and dandy with the @@ -129,7 +130,8 @@ def test_LBFGS_RPI(optical_data_ss_cxi, hyperparameters 2) Ensure `recon.model` points to the original `model` 3) Reconstructions performed by `LBFGS.optimize` and - `model.LBFGS_optimize` calls produce identical results. + `model.LBFGS_optimize` calls produce identical results when + run over one round of reconstruction. 4) The quality of the reconstruction remains below a specified threshold. 5) Ensure that the RPI model works fine and dandy with the @@ -216,7 +218,8 @@ def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): hyperparameters 3) Ensure `recon.model` points to the original `model` 4) Reconstructions performed by `SGD.optimize` and - `model.SGD_optimize` calls produce identical results. + `model.SGD_optimize` calls produce identical results + when run over one round of reconstruction. 5) The quality of the reconstruction remains below a specified threshold. 5) Ensure that the FancyPtycho model works fine and dandy with the