Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/cdtools/models/fancy_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,28 @@ def get_probes(idx):
**kwargs),


def plot_translations_and_originals(self, fig, dataset):
"""Only used to make a plot for the plot list."""
p.plot_translations(
dataset.translations,
fig=fig,
units=self.units,
label='original translations',
color='#CCCCCC',
marker='o',
)
p.plot_translations(
self.corrected_translations(dataset),
fig=fig,
units=self.units,
clear_fig=False,
label='refined translations',
color='k',
marker='.'
)
plt.legend()


plot_list = [
('',
lambda self, fig, dataset: self.plot_wavefront_variation(
Expand Down Expand Up @@ -895,7 +917,7 @@ def get_probes(idx):
lambda self: self.exponentiate_obj),

('Corrected Translations',
lambda self, fig, dataset: p.plot_translations(self.corrected_translations(dataset), fig=fig, units=self.units)),
lambda self, fig, dataset: self.plot_translations_and_originals(fig, dataset)),
('Background',
lambda self, fig: p.plot_amplitude(self.background**2, fig=fig)),
('Quantum Efficiency Mask',
Expand Down
30 changes: 24 additions & 6 deletions src/cdtools/tools/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def plot_colorized(im, fig=None, basis=None, units='$\\mu$m', **kwargs):
units=units, show_cbar=False, **kwargs)


def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, **kwargs):
def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, invert_xaxis=True, clear_fig=True, label=None, color=None, marker='.', **kwargs):
"""Plots a set of probe translations in a nicely formatted way

Parameters
Expand All @@ -537,6 +537,14 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver
Whether to plot lines indicating the path taken
invert_xaxis : bool
Default is True. This flips the x axis to match the convention from .cxi files of viewing the image from the beam's perspective
clear_fig : bool
Default is True. Whether to clear the figure before plotting.
label : str
Default is None. A label to give the plotted markers for a legend.
color : str
Default is None. The color to plot the markers in. By default, will follow the matplotlib color cycle.
color : str
Default is '.'. The marker style to plot with.
\\**kwargs
All other args are passed to fig.add_subplot(111, \\**kwargs)

Expand All @@ -554,18 +562,28 @@ def plot_translations(translations, fig=None, units='$\\mu$m', lines=True, inver
ax = fig.add_subplot(111, **kwargs)
else:
plt.figure(fig.number)
plt.gcf().clear()
if clear_fig:
plt.gcf().clear()

if isinstance(translations, t.Tensor):
translations = translations.detach().cpu().numpy()

translations = translations * factor
plt.plot(translations[:,0], translations[:,1],'k.')

linestyle = '-' if lines else 'None'
linewidth = 1 if lines else 0
plt.plot(translations[:,0], translations[:,1],
marker=marker, linestyle=linestyle,
label=label, color=color,
linewidth=linewidth)

if invert_xaxis:
plt.gca().invert_xaxis()
ax = plt.gca()
x_min, x_max = ax.get_xlim()
# Protect against flipping twice if plotting on top of existing graph
if x_min <= x_max:
ax.invert_xaxis()

if lines:
plt.plot(translations[:,0], translations[:,1],'b-', linewidth=0.5)
plt.xlabel('X (' + units + ')')
plt.ylabel('Y (' + units + ')')

Expand Down