diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 517f1b6c..a3ef41eb 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small. -Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at every epoch, to allow some monitoring code to be run. +Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. -Finally, the results can be studied using :code:`model.inspect(dataet)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. +Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset. Fancy Ptycho @@ -63,6 +63,12 @@ By default, FancyPtycho will also optimize over the following model parameters, These corrections can be turned off (on) by calling :code:`model..requires_grad = False #(True)`. +Note as well two other changes that are made in this script, when compared to `simple_ptycho.py`. First, a `Reconstructor` object is explicitly created, in this case an `AdamReconstructor`. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to `Reconstructor.optimize()`. + +We use this pattern, instead of the simpler call to `model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: In this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimation of the object. + +In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistant information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. `model.LBFGS_optimize()`). + Gold Ball Ptycho ---------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index c6fcb8dc..b8cd288d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,6 +9,7 @@ general datasets models + reconstructors tools/index indices_tables diff --git a/docs/source/reconstructors.rst b/docs/source/reconstructors.rst new file mode 100644 index 00000000..2e78a747 --- /dev/null +++ b/docs/source/reconstructors.rst @@ -0,0 +1,5 @@ +Reconstructors +============== + +.. automodule:: cdtools.reconstructors + :members: diff --git a/examples/fancy_ptycho.py b/examples/fancy_ptycho.py index 1cf58550..ebee68f0 100644 --- a/examples/fancy_ptycho.py +++ b/examples/fancy_ptycho.py @@ -19,19 +19,27 @@ model.to(device=device) dataset.get_as(device=device) +# For this script, we use a slightly different pattern where we explicitly +# create a `Reconstructor` class to orchestrate the reconstruction. The +# reconstructor will store the model and dataset and create an appropriate +# optimizer. This allows the optimizer to persist between loops, along with +# e.g. estimates of the moments of individual parameters +recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # The learning rate parameter sets the alpha for Adam. # The beta parameters are (0.9, 0.999) by default # The batch size sets the minibatch size -for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): +for loss in recon.optimize(50, lr=0.02, batch_size=10): print(model.report()) # Plotting is expensive, so we only do it every tenth epoch if model.epoch % 10 == 0: model.inspect(dataset) # It's common to chain several different reconstruction loops. Here, we -# started with an aggressive refinement to find the probe, and now we -# polish the reconstruction with a lower learning rate and larger minibatch -for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50): +# started with an aggressive refinement to find the probe in the previous +# loop, and now we polish the reconstruction with a lower learning rate +# and larger minibatch +for loss in recon.optimize(50, lr=0.005, batch_size=50): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) diff --git a/examples/gold_ball_ptycho.py b/examples/gold_ball_ptycho.py index 47476664..49719751 100644 --- a/examples/gold_ball_ptycho.py +++ b/examples/gold_ball_ptycho.py @@ -29,6 +29,7 @@ probe_fourier_crop=pad ) + # This is a trick that my grandmother taught me, to combat the raster grid # pathology: we randomze the our initial guess of the probe positions. # The units here are pixels in the object array. @@ -42,17 +43,20 @@ model.to(device=device) dataset.get_as(device=device) +# Create the reconstructor +recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # This will save out the intermediate results if an exception is thrown # during the reconstruction with model.save_on_exception( 'example_reconstructions/gold_balls_earlyexit.h5', dataset): - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): + for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): + for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) @@ -64,8 +68,7 @@ # Setting schedule=True automatically lowers the learning rate if # the loss fails to improve after 10 epochs - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, - schedule=True): + for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) if model.epoch % 10 == 0: model.inspect(dataset) diff --git a/examples/gold_ball_split.py b/examples/gold_ball_split.py index 8624b9f4..9fc5b083 100644 --- a/examples/gold_ball_split.py +++ b/examples/gold_ball_split.py @@ -36,15 +36,18 @@ model.to(device=device) dataset.get_as(device=device) + # Create the reconstructor + recon = cdtools.reconstructors.AdamReconstructor(model, dataset) + # For batched reconstructions like this, there's no need to live-plot # the progress - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): + for loss in recon.optimize(20, lr=0.005, batch_size=50): print(model.report()) - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): + for loss in recon.optimize(50, lr=0.002, batch_size=100): print(model.report()) - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, + for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True): print(model.report()) diff --git a/examples/simple_ptycho.py b/examples/simple_ptycho.py index 726af940..217d96b0 100644 --- a/examples/simple_ptycho.py +++ b/examples/simple_ptycho.py @@ -1,3 +1,12 @@ +""" +Runs a very simple reconstruction using the SimplePtycho model, which was +designed to be an easy introduction to show how the models are made and used. + +For a more realistic example of how to use cdtools for real-world data, +look at fancy_ptycho.py and gold_ball_ptycho.py, both of which use the +more powerful FancyPtycho model and include more information on how to +correct for common sources of error. +""" import cdtools from matplotlib import pyplot as plt @@ -13,7 +22,7 @@ model.to(device=device) dataset.get_as(device=device) -# We run the actual reconstruction +# We run the reconstruction for loss in model.Adam_optimize(100, dataset, batch_size=10): # We print a quick report of the optimization status print(model.report()) diff --git a/src/cdtools/__init__.py b/src/cdtools/__init__.py index 96842075..91322091 100644 --- a/src/cdtools/__init__.py +++ b/src/cdtools/__init__.py @@ -4,9 +4,9 @@ warnings.filterwarnings("ignore", message='To copy construct from a tensor, ') -__all__ = ['tools', 'datasets', 'models'] +__all__ = ['tools', 'datasets', 'models', 'reconstructors'] from cdtools import tools from cdtools import datasets from cdtools import models - +from cdtools import reconstructors diff --git a/src/cdtools/models/base.py b/src/cdtools/models/base.py index b455f452..df347b63 100644 --- a/src/cdtools/models/base.py +++ b/src/cdtools/models/base.py @@ -40,6 +40,10 @@ from scipy import io from contextlib import contextmanager from cdtools.tools.data import nested_dict_to_h5, h5_to_nested_dict, nested_dict_to_numpy, nested_dict_to_torch +from cdtools.reconstructors import AdamReconstructor, LBFGSReconstructor, SGDReconstructor +from cdtools.datasets import CDataset +from typing import List, Union, Tuple +import os __all__ = ['CDIModel'] @@ -316,202 +320,23 @@ def checkpoint(self, *args): self.current_checkpoint_id += 1 - - - def AD_optimize(self, iterations, data_loader, optimizer,\ - scheduler=None, regularization_factor=None, thread=True, - calculation_width=10): - """Runs a round of reconstruction using the provided optimizer - - This is the basic automatic differentiation reconstruction tool - which all the other, algorithm-specific tools, use. It is a - generator which yields the average loss each epoch, ending after - the specified number of iterations. - - By default, the computation will be run in a separate thread. This - is done to enable live plotting with matplotlib during a reconstruction. - If the computation was done in the main thread, this would freeze - the plots. This behavior can be turned off by setting the keyword - argument 'thread' to False. - - Parameters - ---------- - iterations : int - How many epochs of the algorithm to run - data_loader : torch.utils.data.DataLoader - A data loader loading the CDataset to reconstruct - optimizer : torch.optim.Optimizer - The optimizer to run the reconstruction with - scheduler : torch.optim.lr_scheduler._LRScheduler - Optional, a learning rate scheduler to use - regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method - thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation - calculation_width : int - Default 10, how many translations to pass through at once for each round of gradient accumulation. This does not affect the result, but may affect the calculation speed. - - Yields - ------ - loss : float - The summed loss over the latest epoch, divided by the total diffraction pattern intensity - """ - - def run_epoch(stop_event=None): - """Runs one full epoch of the reconstruction.""" - # First, initialize some tracking variables - normalization = 0 - loss = 0 - N = 0 - t0 = time.time() - - # The data loader is responsible for setting the minibatch - # size, so each set is a minibatch - for inputs, patterns in data_loader: - normalization += t.sum(patterns).cpu().numpy() - N += 1 - def closure(): - optimizer.zero_grad() - - # We further break up the minibatch into a set of chunks. - # This lets us use larger minibatches than can fit - # on the GPU at once, while still doing batch processing - # for efficiency - input_chunks = [[inp[i:i + calculation_width] - for inp in inputs] - for i in range(0, len(inputs[0]), - calculation_width)] - pattern_chunks = [patterns[i:i + calculation_width] - for i in range(0, len(inputs[0]), - calculation_width)] - - total_loss = 0 - for inp, pats in zip(input_chunks, pattern_chunks): - # This check allows for graceful exit when threading - if stop_event is not None and stop_event.is_set(): - exit() - - # Run the simulation - sim_patterns = self.forward(*inp) - - # Calculate the loss - if hasattr(self, 'mask'): - loss = self.loss(pats,sim_patterns, mask=self.mask) - else: - loss = self.loss(pats,sim_patterns) - - # And accumulate the gradients - loss.backward() - total_loss += loss.detach() - - # If we have a regularizer, we can calculate it separately, - # and the gradients will add to the minibatch gradient - if regularization_factor is not None \ - and hasattr(self, 'regularizer'): - loss = self.regularizer(regularization_factor) - loss.backward() - - return total_loss - - # This takes the step for this minibatch - loss += optimizer.step(closure).detach().cpu().numpy() - - - loss /= normalization - - # We step the scheduler after the full epoch - if scheduler is not None: - scheduler.step(loss) - - self.loss_history.append(loss) - self.epoch = len(self.loss_history) - self.latest_iteration_time = time.time() - t0 - self.training_history += self.report() + '\n' - return loss - - # We store the current optimizer as a model parameter so that - # it can be saved and loaded for checkpointing - self.current_optimizer = optimizer - - # If we don't want to run in a different thread, this is easy - if not thread: - for it in range(iterations): - if self.skip_computation(): - self.epoch = self.epoch + 1 - if len(self.loss_history) >= 1: - yield self.loss_history[-1] - else: - yield float('nan') - continue - - yield run_epoch() - - - # But if we do want to thread, it's annoying: - else: - # Here we set up the communication with the computation thread - result_queue = queue.Queue() - stop_event = threading.Event() - def target(): - try: - result_queue.put(run_epoch(stop_event)) - except Exception as e: - # If something bad happens, put the exception into the - # result queue - result_queue.put(e) - - # And this actually starts and monitors the thread - for it in range(iterations): - if self.skip_computation(): - self.epoch = self.epoch + 1 - if len(self.loss_history) >= 1: - yield self.loss_history[-1] - else: - yield float('nan') - continue - - calc = threading.Thread(target=target, name='calculator', daemon=True) - try: - calc.start() - while calc.is_alive(): - if hasattr(self, 'figs'): - self.figs[0].canvas.start_event_loop(0.01) - else: - calc.join() - - except KeyboardInterrupt as e: - stop_event.set() - print('\nAsking execution thread to stop cleanly - please be patient.') - calc.join() - raise e - - res = result_queue.get() - - # If something went wrong in the thead, we'll get an exception - if isinstance(res, Exception): - raise res - - yield res - - # And finally, we unset the current optimizer: - self.current_optimizer = None - - def Adam_optimize( self, - iterations, - dataset, - batch_size=15, - lr=0.005, - betas=(0.9, 0.999), - schedule=False, - amsgrad=False, - subset=None, - regularization_factor=None, + iterations: int, + dataset: CDataset, + batch_size: int = 15, + lr: float = 0.005, + betas: Tuple[float] = (0.9, 0.999), + schedule: bool = False, + amsgrad: bool = False, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] = None, thread=True, calculation_width=10 ): - """Runs a round of reconstruction using the Adam optimizer + """ + Runs a round of reconstruction using the Adam optimizer from + cdtools.reconstructors.AdamReconstructor. This is generally accepted to be the most robust algorithm for use with ptychography. Like all the other optimization routines, @@ -521,125 +346,132 @@ def Adam_optimize( Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. batch_size : int - Optional, the size of the minibatches to use + Optional, the size of the minibatches to use. lr : float - Optional, The learning rate (alpha) to use. Defaultis 0.005. 0.05 is typically the highest possible value with any chance of being stable - betas : tuple + Optional, The learning rate (alpha) to use. Defaultis 0.005. + 0.05 is typically the highest possible value with any chance + of being stable. + betas : tuple(float) Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). - schedule : float - Optional, whether to use the ReduceLROnPlateau scheduler + schedule : bool + Optional, whether to use the ReduceLROnPlateau scheduler. + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. subset : list(int) or int Optional, a pattern index or list of pattern indices to use - regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + regularization_factor : float or list(float). + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. calculation_width : int - Default 10, how many translations to pass through at once for each round of gradient accumulation. Does not affect the result, only the calculation speed - + Default 10, how many translations to pass through at once for + each round of gradient accumulation. Does not affect the result, + only the calculation speed. + """ - - self.training_history += ( - f'Planning {iterations} epochs of Adam, with a learning rate = ' - f'{lr}, batch size = {batch_size}, regularization_factor = ' - f'{regularization_factor}, and schedule = {schedule}.\n' + reconstructor = AdamReconstructor( + model=self, + dataset=dataset, + subset=subset, ) - - if subset is not None: - # if subset is just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader - data_loader = torchdata.DataLoader(dataset, - batch_size=batch_size, - shuffle=True) - - # Define the optimizer - optimizer = t.optim.Adam( - self.parameters(), - lr = lr, + # Run some reconstructions + return reconstructor.optimize( + iterations=iterations, + batch_size=batch_size, + lr=lr, betas=betas, - amsgrad=amsgrad) - - # Define the scheduler - if schedule: - scheduler = t.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2,threshold=1e-9) - else: - scheduler = None - - return self.AD_optimize(iterations, data_loader, optimizer, - scheduler=scheduler, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) - + schedule=schedule, + amsgrad=amsgrad, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + ) - def LBFGS_optimize(self, iterations, dataset, - lr=0.1,history_size=2, subset=None, - regularization_factor=None, thread=True, - calculation_width=10, line_search_fn=None): - """Runs a round of reconstruction using the L-BFGS optimizer + def LBFGS_optimize(self, + iterations: int, + dataset: CDataset, + lr: float = 0.1, + history_size: int = 2, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] =None, + thread: bool = True, + calculation_width: int = 10, + line_search_fn: str = None): + """ + Runs a round of reconstruction using the L-BFGS optimizer from + cdtools.reconstructors.LBFGSReconstructor. This algorithm is often less stable that Adam, however in certain situations or geometries it can be shockingly efficient. Like all the other optimization routines, it is defined as a generator function which yields the average loss each epoch. - Note: There is no batch size, because it is a usually a bad idea to use + NOTE: There is no batch size, because it is a usually a bad idea to use LBFGS on anything but all the data at onece Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. lr : float - Optional, the learning rate to use + Optional, the learning rate to use. history_size : int Optional, the length of the history to use. subset : list(int) or int - Optional, a pattern index or list of pattern indices to ues + Optional, a pattern index or list of pattern indices to use. regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation. - + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. """ - if subset is not None: - # if just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader. This basically does nothing but load all the - # data at once - data_loader = torchdata.DataLoader(dataset, batch_size=len(dataset)) - - - # Define the optimizer - optimizer = t.optim.LBFGS(self.parameters(), - lr = lr, history_size=history_size, - line_search_fn=line_search_fn) - - return self.AD_optimize(iterations, data_loader, optimizer, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) - - - def SGD_optimize(self, iterations, dataset, batch_size=None, - lr=0.01, momentum=0, dampening=0, weight_decay=0, - nesterov=False, subset=None, regularization_factor=None, - thread=True, calculation_width=10): - """Runs a round of reconstruction using the SGD optimizer + reconstructor = LBFGSReconstructor( + model=self, + dataset=dataset, + subset=subset, + ) + + # Run some reconstructions + return reconstructor.optimize( + iterations=iterations, + lr=lr, + history_size=history_size, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + line_search_fn=line_search_fn, + ) + + def SGD_optimize(self, + iterations: int, + dataset: CDataset, + batch_size: int = None, + lr: float = 2e-7, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + subset: Union[int, List[int]] = None, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10): + """ + Runs a round of reconstruction using the SGD optimizer from + cdtools.reconstructors.SGDReconstructor. This algorithm is often less stable that Adam, but it is simpler and is the basic workhorse of gradience descent. @@ -647,51 +479,54 @@ def SGD_optimize(self, iterations, dataset, batch_size=None, Parameters ---------- iterations : int - How many epochs of the algorithm to run + How many epochs of the algorithm to run. dataset : CDataset - The dataset to reconstruct against + The dataset to reconstruct against. batch_size : int - Optional, the size of the minibatches to use + Optional, the size of the minibatches to use. lr : float - Optional, the learning rate to use + Optional, the learning rate to use. momentum : float Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. subset : list(int) or int - Optional, a pattern index or list of pattern indices to use + Optional, a pattern index or list of pattern indices to use. regularization_factor : float or list(float) - Optional, if the model has a regularizer defined, the set of parameters to pass the regularizer method + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. thread : bool - Default True, whether to run the computation in a separate thread to allow interaction with plots during computation + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. calculation_width : int - Default 1, how many translations to pass through at once for each round of gradient accumulation + Default 10, how many translations to pass through at once for each + round of gradient accumulation. """ + reconstructor = SGDReconstructor( + model=self, + dataset=dataset, + subset=subset, + ) - if subset is not None: - # if just one pattern, turn into a list for convenience - if type(subset) == type(1): - subset = [subset] - dataset = torchdata.Subset(dataset, subset) - - # Make a dataloader - if batch_size is not None: - data_loader = torchdata.DataLoader(dataset, batch_size=batch_size, - shuffle=True) - else: - data_loader = torchdata.DataLoader(dataset) - - - # Define the optimizer - optimizer = t.optim.SGD(self.parameters(), - lr = lr, momentum=momentum, - dampening=dampening, - weight_decay=weight_decay, - nesterov=nesterov) - - return self.AD_optimize(iterations, data_loader, optimizer, - regularization_factor=regularization_factor, - thread=thread, - calculation_width=calculation_width) + # Run some reconstructions + return reconstructor.optimize( + iterations=iterations, + batch_size=batch_size, + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + regularization_factor=regularization_factor, # noqa + thread=thread, + calculation_width=calculation_width, + ) def report(self): diff --git a/src/cdtools/reconstructors/__init__.py b/src/cdtools/reconstructors/__init__.py new file mode 100644 index 00000000..84a1a818 --- /dev/null +++ b/src/cdtools/reconstructors/__init__.py @@ -0,0 +1,22 @@ +""" +Module `cdtools.tools.reconstructors` contains the `Reconstructor` class and +subclasses which run the ptychography reconstructions on a given model and +dataset. + +The reconstructors are designed to resemble so-called 'Trainer' classes that +(in the language of the AI/ML folks) handles the 'training' of a model given +some dataset and optimizer. +""" + +# We define __all__ to be sure that import * only imports what we want +__all__ = [ + 'Reconstructor', + 'AdamReconstructor', + 'LBFGSReconstructor', + 'SGDReconstructor' +] + +from cdtools.reconstructors.base import Reconstructor +from cdtools.reconstructors.adam import AdamReconstructor +from cdtools.reconstructors.lbfgs import LBFGSReconstructor +from cdtools.reconstructors.sgd import SGDReconstructor diff --git a/src/cdtools/reconstructors/adam.py b/src/cdtools/reconstructors/adam.py new file mode 100644 index 00000000..e298bdb1 --- /dev/null +++ b/src/cdtools/reconstructors/adam.py @@ -0,0 +1,178 @@ +"""This module contains the AdamReconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +the Adam optimizer. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch as t +from typing import Tuple, List, Union +from cdtools.reconstructors import Reconstructor + +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + +__all__ = ['AdamReconstructor'] + + +class AdamReconstructor(Reconstructor): + """ + The Adam Reconstructor subclass handles the optimization ('reconstruction') + of ptychographic models and datasets using the Adam optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.Adam` to perform + optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during the + `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + # Define the optimizer for use in this subclass + optimizer = t.optim.Adam(model.parameters()) + + super().__init__(model, dataset, optimizer, subset=subset) + + + + def adjust_optimizer(self, + lr: int = 0.005, + betas: Tuple[float] = (0.9, 0.999), + amsgrad: bool = False): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + betas : tuple + Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['betas'] = betas + param_group['amsgrad'] = amsgrad + + + def optimize(self, + iterations: int, + batch_size: int = 15, + lr: float = 0.005, + betas: Tuple[float] = (0.9, 0.999), + custom_data_loader: t.utils.data.DataLoader = None, + schedule: bool = False, + amsgrad: bool = False, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + shuffle: bool = True): + """ + Runs a round of reconstruction using the Adam optimizer + + Formerly `CDIModel.Adam_optimize` + + This calls the Reconstructor.optimize superclass method + (formerly `CDIModel.AD_optimize`) to run a round of reconstruction + once the dataloader and optimizer hyperparameters have been + set up. + + The `batch_size` parameter sets the batch size for the default + dataloader. If a custom data loader is desired, it can be passed + in to the `custom_data_loader` argument, which will override the + `batch_size` and `shuffle` parameters + + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + batch_size : int + Optional, the size of the minibatches to use. + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + betas : tuple + Optional, the beta_1 and beta_2 to use. Default is (0.9, 0.999). + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + custom_data_loader : t.utils.data.DataLoader + Optional, a custom DataLoader to use. If set, will override + batch_size. + amsgrad : bool + Optional, whether to use the AMSGrad variant of this algorithm. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + # Update the training history + self.model.training_history += ( + f'Planning {iterations} epochs of Adam, with a learning rate = ' + f'{lr}, batch size = {batch_size}, regularization_factor = ' + f'{regularization_factor}, and schedule = {schedule}.\n' + ) + + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + betas=betas, + amsgrad=amsgrad) + + # Set up the scheduler + if schedule: + self.scheduler = \ + t.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, + factor=0.2, + threshold=1e-9) + else: + self.scheduler = None + + # Now, we run the optimize routine defined in the base class + return super(AdamReconstructor, self).optimize( + iterations, + batch_size=batch_size, + custom_data_loader=custom_data_loader, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width, + shuffle=shuffle, + ) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py new file mode 100644 index 00000000..cf3de8c8 --- /dev/null +++ b/src/cdtools/reconstructors/base.py @@ -0,0 +1,381 @@ +"""This module contains the base Reconstructor class for performing +optimization ('reconstructions') on ptychographic/CDI models. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. + +The subclasses of Reconstructor are required to implement +their own data loaders and optimizer adjusters +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch as t +from torch.utils import data as td +import threading +import queue +import time +from typing import List, Union + +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets import CDataset + + +__all__ = ['Reconstructor'] + + +class Reconstructor: + """ + Reconstructor handles the optimization ('reconstruction') of ptychographic + models given a CDIModel (or subclass) and corresponding CDataset. + + This is a base model that defines all functions Reconstructor subclasses + must implement. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction + dataset: CDataset + The dataset to reconstruct against + optimizer: torch.optim.Optimizer + The optimizer to use for the reconstruction + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use + + Attributes + ---------- + model : CDIModel + Points to the core model used. + optimizer : torch.optim.Optimizer + Must be defined when initializing the Reconstructor subclass. + scheduler : torch.optim.lr_scheduler, optional + May be defined during the ``optimize`` method. + data_loader : torch.utils.data.DataLoader + Defined by calling the ``setup_dataloader`` method. + """ + def __init__(self, + model: CDIModel, + dataset: CDataset, + optimizer: t.optim.Optimizer, + subset: Union[int, List[int]] = None): + + # Store parameters as attributes of Reconstructor + self.model = model + self.optimizer = optimizer + + # Store the dataset, clipping it to a subset if needed + if subset is not None: + # if subset is just one pattern, turn into a list for convenience + if isinstance(subset, int): + subset = [subset] + dataset = td.Subset(dataset, subset) + + self.dataset = dataset + + # Initialize attributes that must be defined by the subclasses + self.scheduler = None + self.data_loader = None + + + def setup_dataloader(self, + batch_size: int = None, + shuffle: bool = True): + """ + Sets up or re-initializes the dataloader. + + Parameters + ---------- + batch_size : int + Optional, the size of the minibatches to use + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + if batch_size is not None: + self.data_loader = td.DataLoader(self.dataset, + batch_size=batch_size, + shuffle=shuffle) + else: + self.data_loader = td.Dataloader(self.dataset) + + + def adjust_optimizer(self, **kwargs): + """ + Change hyperparameters for the utilized optimizer. + + For each optimizer, the keyword arguments should be manually defined + as parameters. + """ + raise NotImplementedError() + + + def run_epoch(self, + stop_event: threading.Event = None, + regularization_factor: Union[float, List[float]] = None, + calculation_width: int = 10): + """ + Runs one full epoch of the reconstruction. Intended to be called + by Reconstructor.optimize. + + Parameters + ---------- + stop_event : threading.Event + Default None, causes the reconstruction to stop when an exception + occurs in Optimizer.optimize. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. This does not affect the result, + but may affect the calculation speed. + + Returns + ------ + loss : float + The summed loss over the latest epoch, divided by the total + diffraction pattern intensity + """ + + # Setting this as an explicit catch makes me feel more comfortable + # exposing it as a public method. This way a user won't be confused + # if they try to use this directly + if self.data_loader is None: + raise RuntimeError( + 'No data loader was defined. Please run ' + 'Reconstructor.setup_dataloader() before running ' + 'Reconstructor.run_epoch(), or use Reconstructor.optimize(), ' + 'which does it automatically.' + ) + + + # Initialize some tracking variables + normalization = 0 + loss = 0 + N = 0 + t0 = time.time() + + # The data loader is responsible for setting the minibatch + # size, so each set is a minibatch + for inputs, patterns in self.data_loader: + normalization += t.sum(patterns).cpu().numpy() + N += 1 + + def closure(): + self.optimizer.zero_grad() + + # We further break up the minibatch into a set of chunks. + # This lets us use larger minibatches than can fit + # on the GPU at once, while still doing batch processing + # for efficiency + input_chunks = [[inp[i:i + calculation_width] + for inp in inputs] + for i in range(0, len(inputs[0]), + calculation_width)] + pattern_chunks = [patterns[i:i + calculation_width] + for i in range(0, len(inputs[0]), + calculation_width)] + + total_loss = 0 + + for inp, pats in zip(input_chunks, pattern_chunks): + # This check allows for graceful exit when threading + if stop_event is not None and stop_event.is_set(): + exit() + + # Run the simulation + sim_patterns = self.model.forward(*inp) + + # Calculate the loss + if hasattr(self, 'mask'): + loss = self.model.loss(pats, + sim_patterns, + mask=self.model.mask) + else: + loss = self.model.loss(pats, + sim_patterns) + + # And accumulate the gradients + loss.backward() + + # Normalize the accumulating total loss + total_loss += loss.detach() + + # If we have a regularizer, we can calculate it separately, + # and the gradients will add to the minibatch gradient + if regularization_factor is not None \ + and hasattr(self.model, 'regularizer'): + + loss = self.model.regularizer(regularization_factor) + loss.backward() + + return total_loss + + # This takes the step for this minibatch + loss += self.optimizer.step(closure).detach().cpu().numpy() + + loss /= normalization + + # We step the scheduler after the full epoch + if self.scheduler is not None: + self.scheduler.step(loss) + + self.model.loss_history.append(loss) + self.model.epoch = len(self.model.loss_history) + self.model.latest_iteration_time = time.time() - t0 + self.model.training_history += self.model.report() + '\n' + return loss + + def optimize(self, + iterations: int, + batch_size: int = 1, + custom_data_loader: torch.utils.data.DataLoader = None, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + shuffle=True): + """ + Runs a round of reconstruction using the provided optimizer + + Formerly CDIModel.AD_optimize + + This is the basic automatic differentiation reconstruction tool + which all the other, algorithm-specific tools, use. It is a + generator which yields the average loss each epoch, ending after + the specified number of iterations. + + By default, the computation will be run in a separate thread. This + is done to enable live plotting with matplotlib during a + reconstruction. + + If the computation was done in the main thread, this would freeze + the plots. This behavior can be turned off by setting the keyword + argument 'thread' to False. + + The `batch_size` parameter sets the batch size for the default + dataloader. If a custom data loader is desired, it can be passed + in to the `custom_data_loader` argument, which will override the + `batch_size` and `shuffle` parameters + + Please see `AdamReconstructor.optimize()` for an example of how to + override this function when designing a subclass + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + batch_size : int + Optional, the batch size to use. Default is 1. This is typically + overridden by subclasses with an appropriate default for the + specific optimizer. + custom_data_loader : torch.utils.data.DataLoader + Optional, a custom DataLoader to use. Will override batch_size + if set. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. This does not affect the result, + but may affect the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + + + Yields + ------ + loss : float + The summed loss over the latest epoch, divided by the total + diffraction pattern intensity. + """ + + if custom_data_loader is None: + self.setup_dataloader(batch_size=batch_size, shuffle=shuffle) + else: + self.data_loader = custom_data_loader + + # We store the current optimizer as a model parameter so that + # it can be saved and loaded for checkpointing + self.current_optimizer = self.optimizer + + # If we don't want to run in a different thread, this is easy + if not thread: + for it in range(iterations): + if self.model.skip_computation(): + self.epoch = self.epoch + 1 + if len(self.model.loss_history) >= 1: + yield self.model.loss_history[-1] + else: + yield float('nan') + continue + + yield self.run_epoch( + regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width, + ) + + # But if we do want to thread, it's annoying: + else: + # Here we set up the communication with the computation thread + result_queue = queue.Queue() + stop_event = threading.Event() + + def target(): + try: + result_queue.put( + self.run_epoch( + stop_event=stop_event, + regularization_factor=regularization_factor, # noqa + calculation_width=calculation_width, + ) + ) + except Exception as e: + # If something bad happens, put the exception into the + # result queue + result_queue.put(e) + + # And this actually starts and monitors the thread + for it in range(iterations): + if self.model.skip_computation(): + self.model.epoch = self.model.epoch + 1 + if len(self.model.loss_history) >= 1: + yield self.model.loss_history[-1] + else: + yield float('nan') + continue + + calc = threading.Thread(target=target, + name='calculator', + daemon=True) + try: + calc.start() + while calc.is_alive(): + if hasattr(self.model, 'figs'): + self.model.figs[0].canvas.start_event_loop(0.01) + else: + calc.join() + + except KeyboardInterrupt as e: + stop_event.set() + print('\nAsking execution thread to stop cleanly - ' + + 'please be patient.') + calc.join() + raise e + + res = result_queue.get() + + # If something went wrong in the thead, we'll get an exception + if isinstance(res, Exception): + raise res + + yield res + + # And finally, we unset the current optimizer: + self.current_optimizer = None diff --git a/src/cdtools/reconstructors/lbfgs.py b/src/cdtools/reconstructors/lbfgs.py new file mode 100644 index 00000000..8907782a --- /dev/null +++ b/src/cdtools/reconstructors/lbfgs.py @@ -0,0 +1,143 @@ +"""This module contains the LBFGSReconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +the LBFGS optimizer. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch as t +from typing import List, Union +from cdtools.reconstructors import Reconstructor + +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + + +__all__ = ['LBFGSReconstructor'] + + +class LBFGSReconstructor(Reconstructor): + """ + The LBFGSReconstructor subclass handles the optimization + ('reconstruction') of ptychographic models and datasets using the LBFGS + optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + schedule : bool + Optional, create a learning rate scheduler + (torch.optim.lr_scheduler._LRScheduler). + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.LBFGS` to + perform optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during + the `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + # Define the optimizer for use in this subclass + optimizer = t.optim.LBFGS(model.parameters()) + + super().__init__( + model, + dataset, + optimizer, + subset=subset, + ) + + + def adjust_optimizer(self, + lr: int = 0.005, + history_size: int = 2, + line_search_fn: str = None): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + history_size : int + Optional, the length of the history to use. + line_search_fn : str + Optional, either `strong_wolfe` or None + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['history_size'] = history_size + param_group['line_search_fn'] = line_search_fn + + def optimize(self, + iterations: int, + lr: float = 0.1, + history_size: int = 2, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + line_search_fn: str = None): + """ + Runs a round of reconstruction using the LBFGS optimizer + + Formerly `CDIModel.LBFGS_optimize` + + This algorithm is often less stable that Adam, however in certain + situations or geometries it can be shockingly efficient. Like all + the other optimization routines, it is defined as a generator + function which yields the average loss each epoch. + + NOTE: There is no batch size, because it is a usually a bad idea to use + LBFGS on anything but all the data at onece + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + lr : float + Optional, The learning rate (alpha) to use. Default is 0.1. + history_size : int + Optional, the length of the history to use. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + """ + + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + history_size=history_size, + line_search_fn=line_search_fn) + + # Now, we run the optimize routine defined in the base class + return super(LBFGSReconstructor, self).optimize( + iterations, + batch_size=len(self.dataset), + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width) diff --git a/src/cdtools/reconstructors/sgd.py b/src/cdtools/reconstructors/sgd.py new file mode 100644 index 00000000..cb1e26b1 --- /dev/null +++ b/src/cdtools/reconstructors/sgd.py @@ -0,0 +1,164 @@ +"""This module contains the SGDReconstructor subclass for performing +optimization ('reconstructions') on ptychographic/CDI models using +stochastic gradient descent. + +The Reconstructor class is designed to resemble so-called +'Trainer' classes that (in the language of the AI/ML folks) handles +the 'training' of a model given some dataset and optimizer. +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch as t +from typing import List, Union +from cdtools.reconstructors import Reconstructor + +if TYPE_CHECKING: + from cdtools.models import CDIModel + from cdtools.datasets.ptycho_2d_dataset import Ptycho2DDataset + + +__all__ = ['SGDReconstructor'] + + +class SGDReconstructor(Reconstructor): + """ + The SGDReconstructor subclass handles the optimization ('reconstruction') + of ptychographic models and datasets using the SGD optimizer. + + Parameters + ---------- + model: CDIModel + Model for CDI/ptychography reconstruction. + dataset: Ptycho2DDataset + The dataset to reconstruct against. + subset : list(int) or int + Optional, a pattern index or list of pattern indices to use. + + Important attributes: + - **model** -- Always points to the core model used. + - **optimizer** -- This class by default uses `torch.optim.Adam` to perform + optimizations. + - **scheduler** -- A `torch.optim.lr_scheduler` that is defined during the + `optimize` method. + - **data_loader** -- A torch.utils.data.DataLoader that is defined by + calling the `setup_dataloader` method. + """ + def __init__(self, + model: CDIModel, + dataset: Ptycho2DDataset, + subset: List[int] = None): + + # Define the optimizer for use in this subclass + optimizer = t.optim.SGD(model.parameters()) + + super().__init__( + model, + dataset, + optimizer, + subset=subset, + ) + + + def adjust_optimizer(self, + lr: int = 0.005, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False): + """ + Change hyperparameters for the utilized optimizer. + + Parameters + ---------- + lr : float + Optional, The learning rate (alpha) to use. Default is 0.005. 0.05 + is typically the highest possible value with any chance of being + stable. + momentum : float + Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. + """ + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + param_group['momentum'] = momentum + param_group['dampening'] = dampening + param_group['weight_decay'] = weight_decay + param_group['nesterov'] = nesterov + + def optimize(self, + iterations: int, + batch_size: int = 15, + lr: float = 2e-7, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + regularization_factor: Union[float, List[float]] = None, + thread: bool = True, + calculation_width: int = 10, + shuffle: bool = True): + """ + Runs a round of reconstruction using the Adam optimizer + + Formerly `CDIModel.Adam_optimize` + + This calls the Reconstructor.optimize superclass method + (formerly `CDIModel.AD_optimize`) to run a round of reconstruction + once the dataloader and optimizer hyperparameters have been + set up. + + Parameters + ---------- + iterations : int + How many epochs of the algorithm to run. + batch_size : int + Optional, the size of the minibatches to use. + lr : float + Optional, The learning rate to use. The default is 2e-7. + momentum : float + Optional, the length of the history to use. + dampening : float + Optional, dampening for the momentum. + weight_decay : float + Optional, weight decay (L2 penalty). + nesterov : bool + Optional, enables Nesterov momentum. Only applicable when momentum + is non-zero. + regularization_factor : float or list(float) + Optional, if the model has a regularizer defined, the set of + parameters to pass the regularizer method. + thread : bool + Default True, whether to run the computation in a separate thread + to allow interaction with plots during computation. + calculation_width : int + Default 10, how many translations to pass through at once for each + round of gradient accumulation. Does not affect the result, only + the calculation speed. + shuffle : bool + Optional, enable/disable shuffling of the dataset. This option + is intended for diagnostic purposes and should be left as True. + """ + + # The optimizer is created in self.__init__, but the + # hyperparameters need to be set up with self.adjust_optimizer + self.adjust_optimizer(lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov) + + # Now, we run the optimize routine defined in the base class + return super(SGDReconstructor, self).optimize( + iterations, + batch_size=batch_size, + regularization_factor=regularization_factor, + thread=thread, + calculation_width=calculation_width, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 850935d2..f0faea57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import pytest import torch as t - # # # The following few fixtures define some standard data files @@ -381,6 +380,18 @@ def lab_ptycho_cxi(pytestconfig): '/examples/example_data/lab_ptycho_data.cxi' +@pytest.fixture(scope='module') +def optical_data_ss_cxi(pytestconfig): + return str(pytestconfig.rootpath) + \ + '/examples/example_data/Optical_Data_ss.cxi' + + +@pytest.fixture(scope='module') +def optical_ptycho_incoherent_pickle(pytestconfig): + return str(pytestconfig.rootpath) + \ + '/examples/example_data/Optical_ptycho_incoherent.pickle' + + @pytest.fixture(scope='module') def example_nested_dicts(pytestconfig): example_tensor = t.as_tensor(np.array([1, 4.5, 7])) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index e5422d7d..8bc4d87e 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -56,8 +56,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): dataset, n_modes=3, oversampling=2, - exponentiate_obj=True, dm_rank=2, + exponentiate_obj=True, probe_support_radius=120, propagation_distance=5e-3, units='mm', @@ -70,7 +70,7 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): model.to(device=reconstruction_device) dataset.get_as(device=reconstruction_device) - for loss in model.Adam_optimize(70, dataset, lr=0.02, batch_size=10): + for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10): print(model.report()) if show_plot and model.epoch % 10 == 0: model.inspect(dataset) @@ -79,59 +79,8 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): print(model.report()) if show_plot and model.epoch % 10 == 0: model.inspect(dataset) - - model.tidy_probes() - - if show_plot: - model.inspect(dataset) - model.compare(dataset) - - # If this fails, the reconstruction has gotten worse - assert model.loss_history[-1] < 0.001 - - -@pytest.mark.slow -def test_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): - - print('\nTesting performance on the standard gold balls dataset') - - dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) - - pad = 10 - dataset.pad(pad) - - model = cdtools.models.FancyPtycho.from_dataset( - dataset, - n_modes=3, - probe_support_radius=50, - propagation_distance=2e-6, - units='um', - probe_fourier_crop=pad - ) - - model.translation_offsets.data += \ - 0.7 * t.randn_like(model.translation_offsets) - - # Not much probe intensity instability in this dataset, no need for this - model.weights.requires_grad = False - - print('Running reconstruction on provided --reconstruction_device,', - reconstruction_device) - model.to(device=reconstruction_device) - dataset.get_as(device=reconstruction_device) - - for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50): - print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) - - for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100): - print(model.report()) - if show_plot and model.epoch % 10 == 0: - model.inspect(dataset) - - for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100, - schedule=True): + + for loss in model.Adam_optimize(25, dataset, lr=0.001, batch_size=50): print(model.report()) if show_plot and model.epoch % 10 == 0: model.inspect(dataset) @@ -142,7 +91,6 @@ def test_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): model.inspect(dataset) model.compare(dataset) - # This just comes from running a reconstruction when it was working well - # and choosing a rough value. If it triggers this assertion error, - # something changed to make the final quality worse! - assert model.loss_history[-1] < 0.0001 + # If this fails, the reconstruction has gotten worse + assert model.loss_history[-1] < 0.0013 + diff --git a/tests/test_reconstructors.py b/tests/test_reconstructors.py new file mode 100644 index 00000000..132a8c27 --- /dev/null +++ b/tests/test_reconstructors.py @@ -0,0 +1,319 @@ +import pytest +import cdtools +import torch as t +import numpy as np +import pickle +from matplotlib import pyplot as plt +from copy import deepcopy + + +@pytest.mark.slow +def test_Adam_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): + """ + This test checks out several things with the Au particle dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 2) We are only using the single-GPU dataloading method + 3) Ensure `recon.model` points to the original `model` + 4) Reconstructions performed by `Adam.optimize` and + `model.Adam_optimize` calls produce identical results when + run over one round of optimization. + 5) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the FancyPtycho model works fine and dandy with the + Reconstructors. + """ + + print('\nTesting performance on the standard gold balls dataset ' + + 'with reconstructors.AdamReconstructor') + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) + pad = 10 + dataset.pad(pad) + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # Make a copy of the model + model_recon = deepcopy(model) + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with AdamReconstructor.optimize ******* + print('Running reconstruction using AdamReconstructor.optimize' + + ' on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.AdamReconstructor(model=model_recon, + dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + epoch_tup = (20, 50, 100) + lr_tup = (0.005, 0.002, 0.001) + batch_size_tup = (50, 100, 100) + + for i, iterations in enumerate(epoch_tup): + for loss in recon.optimize(iterations, + lr=lr_tup[i], + batch_size=batch_size_tup[i]): + print(model_recon.report()) + if show_plot and model_recon.epoch % 10 == 0: + model_recon.inspect(dataset) + + # Check hyperparameter update + assert recon.optimizer.param_groups[0]['lr'] == lr_tup[i] + assert recon.data_loader.batch_size == batch_size_tup[i] + + # Ensure that recon does not have sampler as an attribute (only used in + # multi-GPU) + assert not hasattr(recon, 'sampler') + + # Ensure recon.model points to the original model + assert id(model_recon) == id(recon.model) + + model_recon.tidy_probes() + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # ******* Reconstructions with CDIModel.Adam_optimize ******* + print('Running reconstruction using CDIModel.Adam_optimize on provided' + + ' reconstruction_device,', reconstruction_device) + t.manual_seed(0) + + # We only need to test the first loop to ensure it's identical + for i, iterations in enumerate(epoch_tup[:1]): + for loss in model.Adam_optimize(iterations, + dataset, + lr=lr_tup[i], + batch_size=batch_size_tup[i]): + print(model.report()) + if show_plot and model.epoch % 10 == 0: + model.inspect(dataset) + + model.tidy_probes() + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Ensure equivalency between the model reconstructions during the first + # pass, where they should be identical + assert np.allclose(model_recon.loss_history[:epoch_tup[0]], model.loss_history[:epoch_tup[0]]) + + # Ensure reconstructions have reached a certain loss tolerance. This just + # comes from running a reconstruction when it was working well and + # choosing a rough value. If it triggers this assertion error, something + # changed to make the final quality worse! + assert model_recon.loss_history[-1] < 0.0001 + + +@pytest.mark.slow +def test_LBFGS_RPI(optical_data_ss_cxi, + optical_ptycho_incoherent_pickle, + reconstruction_device, + show_plot): + """ + This test checks out several things with the transmission RPI dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 2) Ensure `recon.model` points to the original `model` + 3) Reconstructions performed by `LBFGS.optimize` and + `model.LBFGS_optimize` calls produce identical results when + run over one round of reconstruction. + 4) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the RPI model works fine and dandy with the + Reconstructors. + """ + with open(optical_ptycho_incoherent_pickle, 'rb') as f: + ptycho_results = pickle.load(f) + + probe = ptycho_results['probe'] + background = ptycho_results['background'] + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(optical_data_ss_cxi) + model = cdtools.models.RPI.from_dataset(dataset, probe, [500, 500], + background=background, n_modes=2, + initialization='random') + + # Prepare two sets of models for the comparative reconstruction + model_recon = deepcopy(model) + + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with LBFGSReconstructor.optimize ****** + print('Running reconstruction using LBFGSReconstructor.' + + 'optimize on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.LBFGSReconstructor(model=model_recon, + dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + reg_factor_tup = ([0.05, 0.05], [0.001, 0.1]) + epoch_tup = (30, 50) + for i, iterations in enumerate(epoch_tup): + for loss in recon.optimize(iterations, + lr=0.4, + regularization_factor=reg_factor_tup[i]): + if show_plot and i == 0: + model_recon.inspect(dataset) + print(model_recon.report()) + + # Check hyperparameter update (or lack thereof) + assert recon.optimizer.param_groups[0]['lr'] == 0.4 + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # Check model pointing + assert id(model_recon) == id(recon.model) + + # ******* Reconstructions with CDIModel.LBFGS_optimize ****** + print('Running reconstruction using CDIModel.LBFGS_optimize.' + + 'optimize on provided reconstruction_device,', reconstruction_device) + t.manual_seed(0) + for i, iterations in enumerate(epoch_tup[:1]): + for loss in model.LBFGS_optimize(iterations, + dataset, + lr=0.4, + regularization_factor=reg_factor_tup[i]): # noqa + if show_plot and i == 0: + model.inspect(dataset) + print(model.report()) + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Check loss equivalency between the two reconstructions + assert np.allclose(model.loss_history[:epoch_tup[0]], model_recon.loss_history[:epoch_tup[0]]) + + # The final loss when testing this was 2.28607e-3. Based on this, we set + # a threshold of 2.3e-3 for the tested loss. If this value has been + # exceeded, the reconstructions have gotten worse. + assert model_recon.loss_history[-1] < 0.0023 + + +@pytest.mark.slow +def test_SGD_gold_balls(gold_ball_cxi, reconstruction_device, show_plot): + """ + This test checks out several things with the Au particle dataset + 1) Calls to Reconstructor.adjust_optimizer is updating the + hyperparameters + 3) Ensure `recon.model` points to the original `model` + 4) Reconstructions performed by `SGD.optimize` and + `model.SGD_optimize` calls produce identical results + when run over one round of reconstruction. + 5) The quality of the reconstruction remains below a specified + threshold. + 5) Ensure that the FancyPtycho model works fine and dandy with the + Reconstructors. + + The hyperparameters used in this test are not optimized to produce + a super-high-quality reconstruction. Instead, I just need A reconstruction + to do some kind of comparative assessment. + """ + print('\nTesting performance on the standard gold balls dataset ' + + 'with reconstructors.SGDReconstructor') + + dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(gold_ball_cxi) + pad = 10 + dataset.pad(pad) + model = cdtools.models.FancyPtycho.from_dataset( + dataset, + n_modes=3, + probe_support_radius=50, + propagation_distance=2e-6, + units='um', + probe_fourier_crop=pad + ) + + model.translation_offsets.data += 0.7 * \ + t.randn_like(model.translation_offsets) + model.weights.requires_grad = False + + # Make a copy of the model + model_recon = deepcopy(model) + model.to(device=reconstruction_device) + model_recon.to(device=reconstruction_device) + dataset.get_as(device=reconstruction_device) + + # ******* Reconstructions with SGDReconstructor.optimize ******* + print('Running reconstruction using SGDReconstructor.optimize' + + ' on provided reconstruction_device,', reconstruction_device) + + recon = cdtools.reconstructors.SGDReconstructor(model=model_recon, + dataset=dataset) + t.manual_seed(0) + + # Run a reconstruction + epochs = 50 + lr = 0.00000005 + batch_size = 40 + + for loss in recon.optimize(epochs, + lr=lr, + batch_size=batch_size): + print(model_recon.report()) + if show_plot and model_recon.epoch % 10 == 0: + model_recon.inspect(dataset) + + # Check hyperparameter update + assert recon.optimizer.param_groups[0]['lr'] == lr + assert recon.data_loader.batch_size == batch_size + + # Ensure that recon does not have sampler as an attribute (only used in + # multi-GPU) + assert not hasattr(recon, 'sampler') + + # Ensure recon.model points to the original model + assert id(model_recon) == id(recon.model) + + model_recon.tidy_probes() + + if show_plot: + model_recon.inspect(dataset) + model_recon.compare(dataset) + + # ******* Reconstructions with cdtools.CDIModel.SGD_optimize ******* + print('Running reconstruction using CDIModel.SGD_optimize on provided' + + ' reconstruction_device,', reconstruction_device) + t.manual_seed(0) + + for loss in model.SGD_optimize(epochs, + dataset, + lr=lr, + batch_size=batch_size): + print(model.report()) + if show_plot and model.epoch % 10 == 0: + model.inspect(dataset) + + model.tidy_probes() + + if show_plot: + model.inspect(dataset) + model.compare(dataset) + + # Ensure equivalency between the model reconstructions + assert np.allclose(model_recon.loss_history[-1], model.loss_history[-1]) + + # The final loss when testing this was 7.12188e-4. Based on this, we set + # a threshold of 7.2e-4 for the tested loss. If this value has been + # exceeded, the reconstructions have gotten worse. + assert model.loss_history[-1] < 0.00072