10 changes: 8 additions & 2 deletions docs/source/examples.rst
@@ -31,9 +31,9 @@ When reading this script, note the basic workflow. After the data is loaded, a m

Next, the model is moved to the GPU using the :code:`model.to` function. Any device understood by :code:`torch.Tensor.to` can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the :code:`dataset.get_as` function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using :code:`dataset.to`, but the speedup is empirically quite small.
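The access-time conversion described above can be illustrated with a small plain-Python sketch. This is not the cdtools implementation, and the class and method names are made up for illustration; it only shows the idea that :code:`get_as` registers a conversion applied when patterns are fetched, leaving the stored patterns untouched:

```python
# Illustrative sketch only: stored patterns are never modified; the
# requested conversion is applied at access time, mimicking the idea
# behind `dataset.get_as`. Names here are invented, not the cdtools API.
class LazyConvertDataset:
    def __init__(self, patterns):
        self.patterns = patterns      # stored patterns stay untouched
        self._convert = lambda p: p   # identity until get_as is called

    def get_as(self, convert):
        # Remember how to convert each pattern when it is accessed
        self._convert = convert

    def __getitem__(self, idx):
        return self._convert(self.patterns[idx])

ds = LazyConvertDataset([10, 20, 30])
ds.get_as(lambda p: ('gpu', p))  # stand-in for a device transfer
```

Fetching :code:`ds[1]` then returns the converted pattern, while :code:`ds.patterns` still holds the originals.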

-Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at every epoch, to allow some monitoring code to be run.
+Once the device is selected, a reconstruction is run using :code:`model.Adam_optimize`. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run.
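The yield-per-epoch pattern can be sketched in a few lines of plain Python (the function body here is a stand-in, not the real optimizer):

```python
# Minimal sketch of the generator pattern: the optimizer yields once at
# the end of each epoch, handing control back to the caller so that
# monitoring code can run between epochs.
def optimize_sketch(epochs):
    loss = 100.0
    for epoch in range(epochs):
        loss *= 0.5   # stand-in for one full epoch of parameter updates
        yield loss    # control returns to the caller here

losses = []
for loss in optimize_sketch(5):
    losses.append(loss)  # monitoring code runs here, between epochs
```

Each pass through the :code:`for` loop body corresponds to the monitoring code (printing a report, plotting) in the example scripts.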

-Finally, the results can be studied using :code:`model.inspect(dataet)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
+Finally, the results can be studied using :code:`model.inspect(dataset)`, which creates or updates a set of plots showing the current state of the model parameters. :code:`model.compare(dataset)` is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.


Fancy Ptycho
@@ -63,6 +63,12 @@ By default, FancyPtycho will also optimize over the following model parameters,

These corrections can be turned off (or back on) by setting :code:`model.<parameter>.requires_grad = False` (or :code:`True`).

+Note as well two other changes that are made in this script, when compared to :code:`simple_ptycho.py`. First, a :code:`Reconstructor` object is explicitly created, in this case an :code:`AdamReconstructor`. This object stores a model, dataset, and PyTorch optimizer. It is then used to orchestrate the later reconstruction using a call to :code:`Reconstructor.optimize()`.
+
+We use this pattern, instead of the simpler call to :code:`model.Adam_optimize()`, because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: in this script, we run two optimization loops. The first loop aggressively refines the probe, with a low minibatch size and a high learning rate. The second loop has a smaller learning rate and a larger batch size, which allow for a more precise final estimate of the object.
+
+In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistent information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. :code:`model.LBFGS_optimize()`).
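Why persisting the optimizer state matters can be seen with a toy scalar Adam implementation in plain Python (this is an illustration of the general Adam algorithm, not the cdtools or PyTorch code). The moment estimates :code:`m` and :code:`v` and the step counter :code:`t` live on the optimizer object, so a second optimization loop with a different learning rate starts from warmed-up moments rather than from zero:

```python
# Toy scalar Adam whose state (m, v, t) persists across loops, mirroring
# the reason a Reconstructor-style object holds onto its optimizer.
class TinyAdam:
    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-8):
        self.m, self.v, self.t = 0.0, 0.0, 0
        self.beta1, self.beta2, self.eps = beta1, beta2, eps

    def step(self, x, grad, lr):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)   # bias-corrected mean
        v_hat = self.v / (1 - self.beta2 ** self.t)   # bias-corrected variance
        return x - lr * m_hat / (v_hat ** 0.5 + self.eps)

opt = TinyAdam()
x = 0.0
# First loop: aggressive refinement with a high learning rate
for _ in range(200):
    x = opt.step(x, 2 * (x - 3.0), lr=0.1)    # gradient of (x - 3)^2
# Second loop: polish with a low learning rate; opt.m, opt.v, and opt.t
# carry over, so it begins with warmed-up moment estimates
for _ in range(200):
    x = opt.step(x, 2 * (x - 3.0), lr=0.01)
```

Creating a fresh :code:`TinyAdam` before the second loop would instead reset :code:`m`, :code:`v`, and :code:`t` to zero, which is the "zero out all the persistent information" option mentioned above.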


Gold Ball Ptycho
----------------
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -9,6 +9,7 @@
   general
   datasets
   models
+   reconstructors
   tools/index
   indices_tables

5 changes: 5 additions & 0 deletions docs/source/reconstructors.rst
@@ -0,0 +1,5 @@
+Reconstructors
+==============
+
+.. automodule:: cdtools.reconstructors
+   :members:
16 changes: 12 additions & 4 deletions examples/fancy_ptycho.py
@@ -19,19 +19,27 @@
model.to(device=device)
dataset.get_as(device=device)

+# For this script, we use a slightly different pattern where we explicitly
+# create a `Reconstructor` class to orchestrate the reconstruction. The
+# reconstructor will store the model and dataset and create an appropriate
+# optimizer. This allows the optimizer to persist between loops, along with
+# e.g. estimates of the moments of individual parameters
+recon = cdtools.reconstructors.AdamReconstructor(model, dataset)

# The learning rate parameter sets the alpha for Adam.
# The beta parameters are (0.9, 0.999) by default
# The batch size sets the minibatch size
-for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10):
+for loss in recon.optimize(50, lr=0.02, batch_size=10):
    print(model.report())
    # Plotting is expensive, so we only do it every tenth epoch
    if model.epoch % 10 == 0:
        model.inspect(dataset)

# It's common to chain several different reconstruction loops. Here, we
-# started with an aggressive refinement to find the probe, and now we
-# polish the reconstruction with a lower learning rate and larger minibatch
-for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50):
+# started with an aggressive refinement to find the probe in the previous
+# loop, and now we polish the reconstruction with a lower learning rate
+# and larger minibatch
+for loss in recon.optimize(50, lr=0.005, batch_size=50):
    print(model.report())
    if model.epoch % 10 == 0:
        model.inspect(dataset)
Expand Down
11 changes: 7 additions & 4 deletions examples/gold_ball_ptycho.py
@@ -29,6 +29,7 @@
    probe_fourier_crop=pad
)


# This is a trick that my grandmother taught me, to combat the raster grid
# pathology: we randomize our initial guess of the probe positions.
# The units here are pixels in the object array.
@@ -42,17 +42,20 @@
model.to(device=device)
dataset.get_as(device=device)

+# Create the reconstructor
+recon = cdtools.reconstructors.AdamReconstructor(model, dataset)

# This will save out the intermediate results if an exception is thrown
# during the reconstruction
with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):

-    for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
+    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)

-    for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
+    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
@@ -64,8 +68,7 @@

    # Setting schedule=True automatically lowers the learning rate if
    # the loss fails to improve after 10 epochs
-    for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
-                                    schedule=True):
+    for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
9 changes: 6 additions & 3 deletions examples/gold_ball_split.py
@@ -36,15 +36,18 @@
model.to(device=device)
dataset.get_as(device=device)

+# Create the reconstructor
+recon = cdtools.reconstructors.AdamReconstructor(model, dataset)

# For batched reconstructions like this, there's no need to live-plot
# the progress
-for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
+for loss in recon.optimize(20, lr=0.005, batch_size=50):
    print(model.report())

-for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
+for loss in recon.optimize(50, lr=0.002, batch_size=100):
    print(model.report())

-for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
+for loss in recon.optimize(100, lr=0.001, batch_size=100,
                           schedule=True):
    print(model.report())

11 changes: 10 additions & 1 deletion examples/simple_ptycho.py
@@ -1,3 +1,12 @@
"""
Runs a very simple reconstruction using the SimplePtycho model, which was
designed to be an easy introduction to show how the models are made and used.

For a more realistic example of how to use cdtools for real-world data,
look at fancy_ptycho.py and gold_ball_ptycho.py, both of which use the
more powerful FancyPtycho model and include more information on how to
correct for common sources of error.
"""
import cdtools
from matplotlib import pyplot as plt

@@ -13,7 +22,7 @@
model.to(device=device)
dataset.get_as(device=device)

-# We run the actual reconstruction
+# We run the reconstruction
for loss in model.Adam_optimize(100, dataset, batch_size=10):
    # We print a quick report of the optimization status
    print(model.report())
4 changes: 2 additions & 2 deletions src/cdtools/__init__.py
@@ -4,9 +4,9 @@
warnings.filterwarnings("ignore",
message='To copy construct from a tensor, ')

-__all__ = ['tools', 'datasets', 'models']
+__all__ = ['tools', 'datasets', 'models', 'reconstructors']

from cdtools import tools
from cdtools import datasets
from cdtools import models

+from cdtools import reconstructors