<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="./">
<head>
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Tutorial — CDTools 0.2.0 documentation</title>
<link rel="stylesheet" type="text/css" href="_static/pygments.css?v=fa44fd50" />
<link rel="stylesheet" type="text/css" href="_static/css/theme.css?v=19f00094" />
<!--[if lt IE 9]>
<script src="_static/js/html5shiv.min.js"></script>
<![endif]-->
<script src="_static/jquery.js?v=5d32c60e"></script>
<script src="_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="_static/documentation_options.js?v=938c9ccc"></script>
<script src="_static/doctools.js?v=9a2dae69"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/js/theme.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="General Reference" href="general.html" />
<link rel="prev" title="Examples" href="examples.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="index.html" class="icon icon-home">
CDTools
</a>
<div class="version">
0.2
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="index.html">Introduction to CDTools</a></li>
<li class="toctree-l1"><a class="reference internal" href="installation.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="examples.html">Examples</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Tutorial</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#datasets">Datasets</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#basic-idea">Basic Idea</a></li>
<li class="toctree-l3"><a class="reference internal" href="#writing-the-skeleton">Writing the Skeleton</a></li>
<li class="toctree-l3"><a class="reference internal" href="#initialization">Initialization</a></li>
<li class="toctree-l3"><a class="reference internal" href="#dataset-interface">Dataset Interface</a></li>
<li class="toctree-l3"><a class="reference internal" href="#loading-and-saving">Loading and Saving</a></li>
<li class="toctree-l3"><a class="reference internal" href="#inspecting">Inspecting</a></li>
<li class="toctree-l3"><a class="reference internal" href="#notes">Notes</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#models">Models</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#id1">Basic Idea</a></li>
<li class="toctree-l3"><a class="reference internal" href="#id2">Writing the Skeleton</a></li>
<li class="toctree-l3"><a class="reference internal" href="#initialization-from-python">Initialization from Python</a></li>
<li class="toctree-l3"><a class="reference internal" href="#initialization-from-dataset">Initialization from Dataset</a></li>
<li class="toctree-l3"><a class="reference internal" href="#the-forward-model">The Forward Model</a></li>
<li class="toctree-l3"><a class="reference internal" href="#plotting">Plotting</a></li>
<li class="toctree-l3"><a class="reference internal" href="#saving">Saving</a></li>
<li class="toctree-l3"><a class="reference internal" href="#testing">Testing</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="general.html">General Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="datasets.html">Datasets</a></li>
<li class="toctree-l1"><a class="reference internal" href="models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="tools/index.html">Tools</a></li>
<li class="toctree-l1"><a class="reference internal" href="indices_tables.html">Indices and tables</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="index.html">CDTools</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Tutorial</li>
<li class="wy-breadcrumbs-aside">
<a href="_sources/tutorial.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="tutorial">
<h1>Tutorial<a class="headerlink" href="#tutorial" title="Link to this heading"></a></h1>
<p>The following tutorial gives a peek under the hood, and is intended for someone who might want to write their own variant of a ptychography model or modify an existing model to meet a specific need. If you just need to use CDTools for a reconstruction, or are just starting to work with the code, the examples section is a great first introduction.</p>
<p>In the first section of the tutorial, we will discuss how the datasets are defined and go through the process of defining a new dataset type. Following that, we will go through the process of defining a simplified model for standard ptychography.</p>
<section id="datasets">
<h2>Datasets<a class="headerlink" href="#datasets" title="Link to this heading"></a></h2>
<p>In this section, we will write a bare-bones dataset class for 2D ptychography data to demonstrate the process of writing a new dataset class. At the end of the tutorial, we will have written the file examples/tutorial_basic_ptycho_dataset.py, which can be consulted for reference.</p>
<section id="basic-idea">
<h3>Basic Idea<a class="headerlink" href="#basic-idea" title="Link to this heading"></a></h3>
<p>At its core, a dataset object for CDTools is just an object that implements the dataset interface from pytorch. For this reason, the base class (<code class="code docutils literal notranslate"><span class="pre">CDataset</span></code>) from which all datasets are derived is itself a subclass of <code class="code docutils literal notranslate"><span class="pre">torch.utils.data.Dataset</span></code>. In addition, CDataset implements an extra layer that separates the device (CPU or GPU) the data is stored on from the device it returns data on. This allows for GPU-based reconstructions on datasets that are too large to fit into GPU memory in their entirety.</p>
<p>The pytorch Dataset interface is very simple: a dataset only has to define two methods, <code class="code docutils literal notranslate"><span class="pre">__len__()</span></code> and <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code>. Thus, we can always access the data in a Dataset <code class="code docutils literal notranslate"><span class="pre">mydata</span></code> using the syntax <code class="code docutils literal notranslate"><span class="pre">mydata[index]</span></code> or <code class="code docutils literal notranslate"><span class="pre">mydata[slice]</span></code>. Implementing these methods will be the first task in defining a new dataset.</p>
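<p>To illustrate how small this interface is, the hypothetical class below implements only <code class="code docutils literal notranslate"><span class="pre">__len__()</span></code> and <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code>. It is a sketch rather than a CDTools class: it stores NumPy arrays instead of tensors so that it runs without PyTorch, and because NumPy indexing understands both integers and slices, both access styles work.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyPatternDataset:
    """Hypothetical map-style dataset: only __len__ and __getitem__ are defined.

    This mirrors the two-method protocol of torch.utils.data.Dataset, but
    uses plain NumPy arrays so that it runs without PyTorch installed.
    """

    def __init__(self, patterns):
        self.patterns = np.asarray(patterns)

    def __len__(self):
        # One entry per diffraction pattern (the first axis)
        return self.patterns.shape[0]

    def __getitem__(self, index):
        # NumPy indexing already understands ints, slices and index arrays
        return self.patterns[index]


mydata = ToyPatternDataset(np.arange(5 * 2 * 2).reshape(5, 2, 2))
print(len(mydata))        # 5
print(mydata[3].shape)    # (2, 2)
print(mydata[1:4].shape)  # (3, 2, 2)
</pre></div>
</div>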
<p>In CDTools datasets, the layer that allows for separation between the device that data is stored on and the device that data is loaded onto is implemented in the <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code> function. Instead of overriding this function directly, one should override the <code class="code docutils literal notranslate"><span class="pre">_load()</span></code> function, which is used internally by <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code>.</p>
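<p>A minimal sketch of this split, assuming a hypothetical base class <code class="code docutils literal notranslate"><span class="pre">ToyCDataset</span></code> in place of the real <code class="code docutils literal notranslate"><span class="pre">CDataset</span></code>: the base class owns <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code> and converts whatever <code class="code docutils literal notranslate"><span class="pre">_load()</span></code> returns, while the subclass only implements <code class="code docutils literal notranslate"><span class="pre">_load()</span></code>. A NumPy dtype stands in for the torch device and dtype that CDTools actually converts to; the real conversion logic is not reproduced here.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyCDataset:
    """Hypothetical stand-in for a base class that separates storage from output.

    Data stays in whatever form _load() returns; __getitem__ converts it to
    the requested output dtype on every access. In CDTools the analogous
    conversion targets a torch device and dtype; here a NumPy dtype stands
    in for both.
    """

    def __init__(self, output_dtype=np.float32):
        self.output_dtype = output_dtype

    def __getitem__(self, index):
        inputs, output = self._load(index)
        # Only the conversion layer lives here; subclasses never touch it
        return inputs, np.asarray(output, dtype=self.output_dtype)

    def _load(self, index):
        raise NotImplementedError("Subclasses must implement _load()")


class ToyPatterns(ToyCDataset):
    def __init__(self, patterns, **kwargs):
        super().__init__(**kwargs)
        # Storage format: float64 here, independent of what callers receive
        self.patterns = np.asarray(patterns, dtype=np.float64)

    def __len__(self):
        return self.patterns.shape[0]

    def _load(self, index):
        # Return data in its storage format; the base class converts it
        return (index,), self.patterns[index]


ds = ToyPatterns(np.ones((4, 3, 3)))
inputs, pattern = ds[2]
print(inputs, pattern.dtype)  # (2,) float32
</pre></div>
</div>
<p>The subclass never mentions the output dtype, which is the point of the design: the storage format and the delivery format can be changed independently.</p>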
<p>In addition to acting as a pytorch Dataset, CDTools Datasets also work as interfaces to .cxi files. Therefore, when writing a new dataset, it is important to also override the methods <code class="code docutils literal notranslate"><span class="pre">to_cxi()</span></code> and <code class="code docutils literal notranslate"><span class="pre">from_cxi()</span></code>, which handle writing to and reading from .cxi files, respectively.</p>
<p>The final piece of the puzzle is the <code class="code docutils literal notranslate"><span class="pre">inspect()</span></code> method. It is not required for every dataset, but it is extremely valuable as a simple way to explore a dataset visually. We therefore highly recommend implementing it; it should create a plot or interactive plot that lets a user visualize the data they have loaded.</p>
</section>
<section id="writing-the-skeleton">
<h3>Writing the Skeleton<a class="headerlink" href="#writing-the-skeleton" title="Link to this heading"></a></h3>
<p>We can start with the basic skeleton for this file. In addition to our standard imports, we also import the base CDataset class and the data tools. We then define an <code class="code docutils literal notranslate"><span class="pre">__all__</span></code> list as good practice, and set up the inheritance of our class.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">torch</span> <span class="k">as</span> <span class="nn">t</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">cdtools.datasets</span> <span class="kn">import</span> <span class="n">CDataset</span>
<span class="kn">from</span> <span class="nn">cdtools.tools</span> <span class="kn">import</span> <span class="n">data</span> <span class="k">as</span> <span class="n">cdtdata</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'BasicPtychoDataset'</span><span class="p">]</span>
<span class="k">class</span> <span class="nc">BasicPtychoDataset</span><span class="p">(</span><span class="n">CDataset</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""The standard dataset for a 2D ptychography scan"""</span>
    <span class="k">pass</span>
</pre></div>
</div>
</section>
<section id="initialization">
<h3>Initialization<a class="headerlink" href="#initialization" title="Link to this heading"></a></h3>
<p>The next thing to implement is the initialization code. Here we can leverage some of the work already done in the base CDataset class. There are a number of kinds of metadata that can be stored in a .cxi file that aren’t related to the kind of experiment you’re performing, such as the sample ID, start and end times, and so on. The CDataset initialization routine handles loading and storing these various kinds of metadata, so we can start our own initialization routine by leveraging it:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</pre></div>
</div>
<p>Of course, there is also some data that are unique to this kind of dataset. In this case, those data are the probe translations and the measured diffraction patterns. Therefore, we extend this definition to the following:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">translations</span><span class="p">,</span> <span class="n">patterns</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Initialize the dataset from python objects"""</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">translations</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">translations</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">patterns</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
</pre></div>
</div>
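<p>The same forwarding pattern can be sketched without CDTools. In the hypothetical example below, <code class="code docutils literal notranslate"><span class="pre">ToyBase</span></code> plays the role of <code class="code docutils literal notranslate"><span class="pre">CDataset</span></code> and its keyword arguments are illustrative; the subclass consumes the two arguments it understands and passes everything else up unchanged. Copying the inputs, as <code class="code docutils literal notranslate"><span class="pre">.clone()</span></code> does in the tutorial code, means later changes to the caller’s arrays do not leak into the dataset.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyBase:
    """Hypothetical base class that owns experiment-independent metadata."""

    def __init__(self, entry_info=None, sample_info=None, mask=None):
        self.entry_info = entry_info
        self.sample_info = sample_info
        self.mask = mask


class ToyPtychoDataset(ToyBase):
    def __init__(self, translations, patterns, *args, **kwargs):
        # Everything not specific to ptychography is forwarded untouched
        super().__init__(*args, **kwargs)
        # np.array copies by default, so the dataset owns independent data
        self.translations = np.array(translations, dtype=np.float32)
        self.patterns = np.array(patterns, dtype=np.float32)


translations = np.zeros((3, 2))
patterns = np.ones((3, 8, 8))
ds = ToyPtychoDataset(translations, patterns, sample_info={"name": "test"})
translations[0, 0] = 99.0     # modifying the input afterwards...
print(ds.translations[0, 0])  # ...leaves the stored copy at 0.0
print(ds.sample_info)
</pre></div>
</div>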
</section>
<section id="dataset-interface">
<h3>Dataset Interface<a class="headerlink" href="#dataset-interface" title="Link to this heading"></a></h3>
<p>The next set of functions to write are those that plug into the dataset interface. We want <code class="code docutils literal notranslate"><span class="pre">len(dataset)</span></code> to return the number of diffraction patterns, which is straightforward to implement.</p>
<p>For the <code class="code docutils literal notranslate"><span class="pre">_load()</span></code> implementation, we need to consider what format the data should be returned in. The standard for all CDTools datasets is to return a tuple of (inputs, output). The inputs should always be defined as a tuple of inputs, even if there is only one input for this kind of data. As we will see later in the section on constructing models, this makes it possible to write the automatic differentiation code in a way that is applicable to every model.</p>
<p>In this case, our “inputs” will be a tuple of (pattern index, probe translation). This is not the only reasonable choice; it would also be possible, for example, to define the input as just a pattern index (and store the probe translations in the model). For simple ptychography models with no error correction, it’s also possible to take just a probe translation as input, with no index. Requiring both is the compromise implemented in the default ptychography models defined in CDTools, and therefore we will follow that format here.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
    <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">_load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">translations</span><span class="p">[</span><span class="n">index</span><span class="p">]),</span> <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
</pre></div>
</div>
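<p>To see why the inputs are always a tuple, consider the hypothetical, NumPy-only sketch below (<code class="code docutils literal notranslate"><span class="pre">ToyDataset</span></code>, <code class="code docutils literal notranslate"><span class="pre">total_loss</span></code>, and <code class="code docutils literal notranslate"><span class="pre">constant_model</span></code> are illustrative names, not CDTools functions). Because every item unpacks as <code class="code docutils literal notranslate"><span class="pre">(inputs,</span> <span class="pre">output)</span></code>, one loop can call <code class="code docutils literal notranslate"><span class="pre">model(*inputs)</span></code> without knowing how many inputs a particular model takes.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyDataset:
    """Hypothetical dataset following the ((inputs...), output) convention."""

    def __init__(self, translations, patterns):
        self.translations = np.asarray(translations)
        self.patterns = np.asarray(patterns)

    def __len__(self):
        return self.patterns.shape[0]

    def __getitem__(self, index):
        return (index, self.translations[index]), self.patterns[index]


def total_loss(model, dataset):
    """Model-agnostic loop: it never needs to know what the inputs mean."""
    loss = 0.0
    for i in range(len(dataset)):
        inputs, measured = dataset[i]
        simulated = model(*inputs)  # unpack however many inputs there are
        loss += float(np.sum((simulated - measured) ** 2))
    return loss


def constant_model(index, translation):
    # A trivial "model" that ignores its inputs and predicts ones
    return np.ones((2, 2))


ds = ToyDataset(np.zeros((3, 2)), np.ones((3, 2, 2)))
print(total_loss(constant_model, ds))  # 0.0
</pre></div>
</div>
<p>A model that takes different inputs would need a matching dataset, but no change to <code class="code docutils literal notranslate"><span class="pre">total_loss</span></code>.</p>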
<p>Remember that there is no need to worry here about which device or datatype the data is stored as, because the relevant conversions are performed by the <code class="code docutils literal notranslate"><span class="pre">__getitem__()</span></code> method defined in the superclass. However, we do generally implement a method, <code class="code docutils literal notranslate"><span class="pre">to()</span></code>, that moves the data between devices and datatypes. This lets a user speed up data loading onto the GPU by preloading the data there, for example, provided there is enough memory.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">to</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Sends the relevant data to the given device and dtype"""</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">translations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">translations</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
</pre></div>
</div>
<p>Here we can see that we first make sure to call the superclass function to handle sending any information (such as a pixel mask, or detector background) that would have been defined in CDataset to the relevant device. Then we handle the new objects that are defined specifically for this kind of dataset.</p>
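<p>The ordering can be sketched with hypothetical classes in which a NumPy dtype stands in for the device and dtype arguments that a torch tensor’s <code class="code docutils literal notranslate"><span class="pre">to()</span></code> accepts. The base class converts what it owns (here, an optional background), and the subclass converts only what it added.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyBase:
    """Hypothetical base class holding data shared by every dataset type."""

    def __init__(self, background=None):
        self.background = background

    def to(self, dtype):
        # Convert whatever the base class owns (here, an optional background)
        if self.background is not None:
            self.background = self.background.astype(dtype)


class ToyPtycho(ToyBase):
    def __init__(self, translations, patterns, **kwargs):
        super().__init__(**kwargs)
        self.translations = np.asarray(translations)
        self.patterns = np.asarray(patterns)

    def to(self, dtype):
        # First let the base class handle its own attributes...
        super().to(dtype)
        # ...then convert the attributes this subclass introduced
        self.translations = self.translations.astype(dtype)
        self.patterns = self.patterns.astype(dtype)


ds = ToyPtycho(np.zeros((2, 2)), np.ones((2, 4, 4)), background=np.ones((4, 4)))
ds.to(np.float16)
print(ds.background.dtype, ds.patterns.dtype)  # float16 float16
</pre></div>
</div>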
</section>
<section id="loading-and-saving">
<h3>Loading and Saving<a class="headerlink" href="#loading-and-saving" title="Link to this heading"></a></h3>
<p>Now we turn to writing the tools to load and save data. First, to load the data, we override <code class="code docutils literal notranslate"><span class="pre">from_cxi()</span></code>, which is a factory method. In this case, we start by using the superclass to load the metadata. Then we explicitly load and add the data that’s specific to this dataset class.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">from_cxi</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cxi_file</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Generates a new CDataset from a .cxi file directly"""</span>
    <span class="c1"># Generate a base dataset</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="n">CDataset</span><span class="o">.</span><span class="n">from_cxi</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">)</span>
    <span class="c1"># Mutate the class to this subclass (BasicPtychoDataset)</span>
    <span class="n">dataset</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">=</span> <span class="bp">cls</span>
    <span class="c1"># Load the data that is only relevant for this class</span>
    <span class="n">patterns</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">cdtdata</span><span class="o">.</span><span class="n">get_data</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">)</span>
    <span class="n">translations</span> <span class="o">=</span> <span class="n">cdtdata</span><span class="o">.</span><span class="n">get_ptycho_translations</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">)</span>
    <span class="c1"># And now re-add it, keeping the axes so they can be written back out</span>
    <span class="n">dataset</span><span class="o">.</span><span class="n">translations</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">translations</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    <span class="n">dataset</span><span class="o">.</span><span class="n">patterns</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">patterns</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
    <span class="n">dataset</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
    <span class="k">return</span> <span class="n">dataset</span>
</pre></div>
</div>
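<p>The factory trick of building a base-class instance and then reassigning <code class="code docutils literal notranslate"><span class="pre">__class__</span></code> is plain Python, so it can be shown with hypothetical stand-ins: <code class="code docutils literal notranslate"><span class="pre">ToyBase.from_record</span></code> plays the role of <code class="code docutils literal notranslate"><span class="pre">CDataset.from_cxi</span></code>, and a dict replaces the .cxi file.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>class ToyBase:
    """Hypothetical base class with its own factory method."""

    def __init__(self, sample_name=None):
        self.sample_name = sample_name

    @classmethod
    def from_record(cls, record):
        # The base factory only understands the shared metadata
        return ToyBase(sample_name=record.get("sample_name"))


class ToyPtycho(ToyBase):
    @classmethod
    def from_record(cls, record):
        # 1. Let the base class parse the shared metadata
        dataset = ToyBase.from_record(record)
        # 2. Re-label the instance as the subclass that was asked for
        dataset.__class__ = cls
        # 3. Attach the data only this subclass knows about
        dataset.patterns = record["patterns"]
        return dataset


record = {"sample_name": "test sample", "patterns": [[1, 2], [3, 4]]}
ds = ToyPtycho.from_record(record)
print(type(ds).__name__, ds.sample_name, len(ds.patterns))  # ToyPtycho test sample 2
</pre></div>
</div>
<p>Reassigning <code class="code docutils literal notranslate"><span class="pre">__class__</span></code> skips the subclass’s <code class="code docutils literal notranslate"><span class="pre">__init__()</span></code>, which is why every subclass-specific attribute has to be attached explicitly inside the factory method.</p>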
<p>Now to save the data, we override <code class="code docutils literal notranslate"><span class="pre">to_cxi()</span></code>, in a fairly self-explanatory way.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">to_cxi</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cxi_file</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Saves out a BasicPtychoDataset as a .cxi file"""</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">to_cxi</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">)</span>
    <span class="c1"># A dataset built from python objects may have no axes attribute</span>
    <span class="n">axes</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'axes'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
    <span class="n">cdtdata</span><span class="o">.</span><span class="n">add_data</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patterns</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="n">axes</span><span class="p">)</span>
    <span class="n">cdtdata</span><span class="o">.</span><span class="n">add_ptycho_translations</span><span class="p">(</span><span class="n">cxi_file</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">translations</span><span class="p">)</span>
</pre></div>
</div>
<p>Note that these functions should be defined to work on h5py file objects representing the .cxi files (.cxi files are just .h5 files with a special formatting).</p>
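<p>Since h5py files behave like nested mappings, the reading and writing logic can be written so that it also runs on a plain dict, which is convenient for quick tests. The sketch below is hypothetical: <code class="code docutils literal notranslate"><span class="pre">write_ptycho</span></code>, <code class="code docutils literal notranslate"><span class="pre">read_ptycho</span></code>, and the key paths are illustrative, and in CDTools the <code class="code docutils literal notranslate"><span class="pre">cdtdata</span></code> helpers decide the actual file layout.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np

# Illustrative key paths only; in CDTools the cdtdata helpers decide the layout
PATTERNS_KEY = "entry_1/data_1/data"
TRANSLATIONS_KEY = "entry_1/sample_1/geometry_1/translation"


def write_ptycho(cxi_file, patterns, translations):
    """Write arrays into any mapping that supports item assignment.

    With h5py, assigning to a slash-separated path also creates the
    intermediate groups.
    """
    cxi_file[PATTERNS_KEY] = np.asarray(patterns)
    cxi_file[TRANSLATIONS_KEY] = np.asarray(translations)


def read_ptycho(cxi_file):
    """Read the arrays back; [()] loads a full h5py dataset into memory."""
    patterns = np.asarray(cxi_file[PATTERNS_KEY][()])
    translations = np.asarray(cxi_file[TRANSLATIONS_KEY][()])
    return patterns, translations


# A plain dict stands in for an open h5py.File so the sketch runs anywhere
fake_file = {}
write_ptycho(fake_file, np.ones((2, 4, 4)), np.zeros((2, 3)))
patterns, translations = read_ptycho(fake_file)
print(patterns.shape, translations.shape)  # (2, 4, 4) (2, 3)
</pre></div>
</div>
<p>With h5py installed, the same two functions can be given an open file object, for example one returned by <code class="code docutils literal notranslate"><span class="pre">h5py.File(filename,</span> <span class="pre">'w')</span></code> for writing or <code class="code docutils literal notranslate"><span class="pre">h5py.File(filename,</span> <span class="pre">'r')</span></code> for reading.</p>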
</section>
<section id="inspecting">
<h3>Inspecting<a class="headerlink" href="#inspecting" title="Link to this heading"></a></h3>
<p>The final piece of the puzzle is writing a function to look at your data! This is an important thing to work on for a dataset class that you intend to use regularly, as being able to easily peruse your raw data has incalculable value. Here, we satisfy ourselves with just plotting a random diffraction pattern.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">inspect</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Plots a random diffraction pattern"""</span>
    <span class="n">index</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">()</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">patterns</span><span class="p">[</span><span class="n">index</span><span class="p">,:,:]</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</pre></div>
</div>
</section>
<section id="notes">
<h3>Notes<a class="headerlink" href="#notes" title="Link to this heading"></a></h3>
<p>This is a bare-bones class, set up to demonstrate the minimum necessary to develop a new type of dataset class. As a result, it doesn’t implement a number of things that are useful or valuable in practice (and which the default Ptycho2DDataset does implement). That includes a useful data inspector, the ability to load datasets directly from filenames, and default tweaks to how metadata such as backgrounds and masks are loaded.</p>
</section>
</section>
<section id="models">
<h2>Models<a class="headerlink" href="#models" title="Link to this heading"></a></h2>
<p>In this section, we will write a basic model for 2D ptychography reconstructions. At the end of this tutorial, we will have written the class defined in examples/tutorial_simple_ptycho_model.py.</p>
<section id="id1">
<h3>Basic Idea<a class="headerlink" href="#id1" title="Link to this heading"></a></h3>
<p>Just like CDTools Datasets subclass pytorch Datasets, CDTools models subclass pytorch modules. However, the concept of a CDTools model differs slightly from that of a pytorch module, because CDTools models also contain a few standard methods to run automatic differentiation reconstructions on themselves.</p>
<p>This isn’t necessarily the cleanest or most portable approach, but we’ve found that it feels very natural from the perspective of an end user interacting with the toolbox only through the reconstruction scripts.</p>
<p>The heart of each model is the <code class="code docutils literal notranslate"><span class="pre">model.forward()</span></code> function. In any CDTools model, this function maps a set of parameters describing which diffraction pattern to simulate onto the simulated pattern. When paired with an appropriate dataset for a reconstruction, it maps the “inputs” defined by the dataset to the corresponding “outputs”.</p>
<p>For ptychography, these inputs are usually the index of the exposure within the dataset (which is used to retrieve exposure-to-exposure information, like probe intensity factors) and the object translation.</p>
<p>A simple forward model is defined in the top level <code class="code docutils literal notranslate"><span class="pre">CDIModel</span></code> class from which all other models are derived, and rarely needs to be overridden. The definition is quite simple:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span>
    <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">measurement</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">forward_propagator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interaction</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)))</span>
</pre></div>
</div>
<p>So we can see that to fully implement this forward model, we have to define the three functions <code class="code docutils literal notranslate"><span class="pre">model.interaction()</span></code>, <code class="code docutils literal notranslate"><span class="pre">model.forward_propagator()</span></code>, and <code class="code docutils literal notranslate"><span class="pre">model.measurement()</span></code>, which simulate conceptual stages in the diffraction process.</p>
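<p>As an illustration of how the three stages fit together, here is a hypothetical NumPy sketch of a far-field ptychography model, under the assumptions stated in its docstring (thin object, integer-pixel translations, an FFT as the propagator). It is not the CDTools implementation, which works with torch tensors. It also shows how the exposure index from the dataset can select a per-exposure quantity, here an intensity factor.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


class ToyPtychoForward:
    """Hypothetical far-field ptychography forward model in NumPy.

    Assumes a thin object, integer-pixel translations, and a far-field
    propagator implemented as a centered 2D FFT.
    """

    def __init__(self, probe, obj, intensity_factors):
        self.probe = probe                          # complex (N, N) illumination
        self.obj = obj                              # complex object, larger than the probe
        self.intensity_factors = intensity_factors  # one intensity scale per exposure

    def interaction(self, index, translation):
        # Exit wave: probe (scaled for this exposure) times the illuminated patch
        row, col = (int(v) for v in translation)
        n = self.probe.shape[0]
        patch = self.obj[row:row + n, col:col + n]
        return np.sqrt(self.intensity_factors[index]) * self.probe * patch

    def forward_propagator(self, wavefield):
        # Far-field propagation; norm="ortho" conserves the total intensity
        return np.fft.fftshift(np.fft.fft2(wavefield, norm="ortho"))

    def measurement(self, wavefield):
        # The detector records intensity and loses the phase
        return np.abs(wavefield) ** 2

    def forward(self, *args):
        # Same composition as the CDIModel.forward() definition
        return self.measurement(self.forward_propagator(self.interaction(*args)))


rng = np.random.default_rng(0)
probe = np.ones((8, 8), dtype=complex)
obj = np.exp(1j * rng.uniform(0, 1, size=(16, 16)))  # pure phase object
model = ToyPtychoForward(probe, obj, intensity_factors=np.array([1.0, 0.5]))
pattern = model.forward(1, (4, 2))  # exposure index 1, translation (4, 2)
print(pattern.shape)                   # (8, 8)
print(round(float(pattern.sum()), 6))  # 32.0, i.e. 0.5 * 64 unit-intensity pixels
</pre></div>
</div>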
<p>In addition to the core model definition, a few other functions are needed to make the model useful. The model needs an <em>initializer</em> to create itself from a dataset, an appropriate <em>loss function</em> for use with automatic differentiation, a way to plot the progress of a reconstruction, and a way to save the results of a reconstruction in a useful format.</p>
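<p>For the loss function, one common choice in ptychography is the mean squared error between simulated and measured amplitudes (the square roots of the intensities). The standalone function below is a sketch of that idea in NumPy; the name <code class="code docutils literal notranslate"><span class="pre">amplitude_mse</span></code> and its optional mask argument are illustrative, and the losses used by CDTools itself may be defined differently.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import numpy as np


def amplitude_mse(simulated, measured, mask=None):
    """Mean squared error between simulated and measured amplitudes.

    Both arguments are intensities. An optional boolean mask selects the
    detector pixels to include (e.g. to exclude bad pixels).
    """
    diff = np.sqrt(simulated) - np.sqrt(measured)
    if mask is not None:
        diff = diff[mask]
    return float(np.mean(diff ** 2))


measured = np.array([[4.0, 9.0], [16.0, 25.0]])
print(amplitude_mse(measured, measured))          # 0.0
print(amplitude_mse(np.zeros((2, 2)), measured))  # (4 + 9 + 16 + 25) / 4 = 13.5
</pre></div>
</div>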
</section>
<section id="id2">
<h3>Writing the Skeleton<a class="headerlink" href="#id2" title="Link to this heading"></a></h3>
<p>Once again, we start with the basic skeleton:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span> <span class="k">as</span> <span class="nn">t</span>
<span class="kn">from</span> <span class="nn">cdtools.models</span> <span class="kn">import</span> <span class="n">CDIModel</span>
<span class="kn">from</span> <span class="nn">cdtools</span> <span class="kn">import</span> <span class="n">tools</span>
<span class="kn">from</span> <span class="nn">cdtools.tools</span> <span class="kn">import</span> <span class="n">plotting</span> <span class="k">as</span> <span class="n">p</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'SimplePtycho'</span><span class="p">]</span>

<span class="k">class</span> <span class="nc">SimplePtycho</span><span class="p">(</span><span class="n">CDIModel</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""A simple ptychography model to demonstrate the structure of a model</span>
<span class="sd">    """</span>
</pre></div>
</div>
<p>Note that we imported the full tools package, as we will find ourselves using many low-level functions defined there to implement the model.</p>
</section>
<section id="initialization-from-python">
<h3>Initialization from Python<a class="headerlink" href="#initialization-from-python" title="Link to this heading"></a></h3>
<p>Two initialization functions need to be written. First, we write the <code class="code docutils literal notranslate"><span class="pre">__init__()</span></code> function, which initializes the model from a collection of Python objects describing the system. We then write an initializer that creates a model from a dataset, using it to set up the various parameters.</p>
<p>There are no requirements on the arguments to any particular model's initialization function, other than that they contain enough information to run the simulations! They should be chosen on a model-by-model basis to produce the most sensible code.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">wavelength</span><span class="p">,</span>
        <span class="n">probe_basis</span><span class="p">,</span>
        <span class="n">probe_guess</span><span class="p">,</span>
        <span class="n">obj_guess</span><span class="p">,</span>
        <span class="n">min_translation</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">],</span>
<span class="p">):</span>
    <span class="c1"># We initialize the superclass</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

    <span class="c1"># We register all the constants, like wavelength, as buffers. This</span>
    <span class="c1"># lets the model hook into some nice pytorch features, like using</span>
    <span class="c1"># model.to, and broadcasting the model state across multiple GPUs</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'wavelength'</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">wavelength</span><span class="p">))</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'min_translation'</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">min_translation</span><span class="p">))</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'probe_basis'</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">probe_basis</span><span class="p">))</span>

    <span class="c1"># We cast the probe and object to 64-bit complex tensors</span>
    <span class="n">probe_guess</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">probe_guess</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">complex64</span><span class="p">)</span>
    <span class="n">obj_guess</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="n">obj_guess</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">complex64</span><span class="p">)</span>

    <span class="c1"># We rescale the probe here so it learns at the same rate as the</span>
    <span class="c1"># object when using optimizers, like Adam, which set the stepsize</span>
    <span class="c1"># to a fixed maximum</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">'probe_norm'</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">probe_guess</span><span class="p">)))</span>

    <span class="c1"># And we store the probe and object guesses as parameters, so</span>
    <span class="c1"># they can get optimized by pytorch</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">probe</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">probe_guess</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_norm</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">obj</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">obj_guess</span><span class="p">)</span>
</pre></div>
</div>
<p>The first thing to notice about this model is that all the fixed, geometric information is stored with the <code class="code docutils literal notranslate"><span class="pre">module.register_buffer()</span></code> function. This is what makes it possible to move all the relevant tensors between devices with a single call to <code class="code docutils literal notranslate"><span class="pre">module.to()</span></code>, for example. It stores the tensor as an object attribute, but it also registers it so that pytorch knows this attribute is part of the model's state.</p>
<p>The supporting information we need is the wavelength of the illumination, the basis of the probe array in real space, and an offset to define the zero point of the translation.</p>
<p>The final two pieces of information that we need to save are the probe and object, and both of these get defined as <code class="code docutils literal notranslate"><span class="pre">t.nn.Parameter</span></code> objects instead of Tensors. As a result, they get registered as parameters in the pytorch module, and will therefore be optimized over in any later reconstructions. In addition, the <code class="code docutils literal notranslate"><span class="pre">requires_grad</span></code> flag is set to <code class="code docutils literal notranslate"><span class="pre">True</span></code>, which means that the information needed for gradient calculations will be stored on every Tensor that results from a calculation including a Parameter.</p>
<p>A list of all parameters associated with the module can be found by calling <code class="code docutils literal notranslate"><span class="pre">module.parameters()</span></code>.</p>
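<p>The distinction between buffers and parameters can be checked directly. The small sketch below registers one of each on a hypothetical <code class="code docutils literal notranslate"><span class="pre">BufferDemo</span></code> module (the wavelength value is arbitrary) and lists what pytorch tracks.</p>

```python
import torch as t


class BufferDemo(t.nn.Module):
    """Hypothetical module with one buffer and one parameter."""

    def __init__(self):
        super().__init__()
        # Fixed, geometric information: part of the state, never optimized
        self.register_buffer('wavelength', t.as_tensor(1.5e-10))
        # Reconstruction target: optimized, with gradients tracked
        self.obj = t.nn.Parameter(t.ones(3, 3, dtype=t.complex64))


demo = BufferDemo()
print([name for name, _ in demo.named_parameters()])  # ['obj']
print([name for name, _ in demo.named_buffers()])     # ['wavelength']
print(sorted(demo.state_dict().keys()))               # ['obj', 'wavelength']
print(demo.obj.requires_grad, demo.wavelength.requires_grad)  # True False
```

<p>Both kinds of tensor end up in the <code class="code docutils literal notranslate"><span class="pre">state_dict</span></code>, but only the parameter is handed to an optimizer through <code class="code docutils literal notranslate"><span class="pre">module.parameters()</span></code>.</p>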
<p>Any additional targets of reconstruction, such as exposure-to-exposure illumination weights, translation offsets, or a detector background, would be added to the model as parameters in the same way.</p>
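<p>As a sketch of what such an extension could look like, the hypothetical module below adds a detector background as a Parameter in the measurement stage. The class name and the purely additive background model are assumptions for illustration, not part of cdtools.</p>

```python
import torch as t


class BackgroundMeasurement(t.nn.Module):
    """Hypothetical measurement stage with a learned additive background."""

    def __init__(self, det_shape):
        super().__init__()
        # Registered as a Parameter, so an optimizer will update it
        self.background = t.nn.Parameter(t.zeros(det_shape))

    def forward(self, wavefields):
        # Simulated intensity plus one background shared by all exposures
        return t.abs(wavefields) ** 2 + self.background


meas = BackgroundMeasurement((4, 4))
sim = meas(t.ones(2, 4, 4, dtype=t.complex64))
loss = ((sim - 2.0) ** 2).mean()
loss.backward()
# The gradient is negative everywhere, pushing the background upward
print(meas.background.grad)
```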
<p>One final note is that we actually store a scaled version of the probe. This is a specific case of a general policy designed around making it easy to use the Adam optimizer.</p>
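<p>The Adam optimizer is designed so that the learning rate sets an approximate upper bound on the step size taken by each parameter in any single iteration, independent of the scale of that parameter's gradient. Therefore, it is important to make sure that <em>all parameters of the model are of order unity</em>. To enable this, we divide the probe by the maximum of its amplitude, stored as the <code class="code docutils literal notranslate"><span class="pre">probe_norm</span></code> buffer, so that the stored probe has a peak amplitude of 1 and the pixel values within the illuminated region are of order 1.</p>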
<p>This is important to remember when adding additional error models. Rescaling all the parameters to have a typical amplitude near 1 is the best way to get well-behaved reconstructions.</p>
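<p>A quick way to see why scale matters is to take a single Adam step on two parameters whose gradients differ by six orders of magnitude. This is a standalone pytorch sketch, not part of the model.</p>

```python
import torch as t

# Two parameters with identical values but wildly different gradients
big = t.nn.Parameter(t.tensor([1.0]))
small = t.nn.Parameter(t.tensor([1.0]))
opt = t.optim.Adam([big, small], lr=0.01)

loss = 1000.0 * big.sum() + 0.001 * small.sum()
loss.backward()
opt.step()

# On the first step, Adam moves each parameter by about lr,
# regardless of how large or small its gradient was.
print(1.0 - big.item(), 1.0 - small.item())  # both approximately 0.01
```

<p>A parameter whose natural values were around 0.001 would be overwhelmed by steps of this size, while one with values around 1000 would barely move, which is why every parameter should be rescaled to order unity.</p>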
</section>
<section id="initialization-from-dataset">
<h3>Initialization from Dataset<a class="headerlink" href="#initialization-from-dataset" title="Link to this heading"></a></h3>
<p>To initialize the model from a dataset, we start by extracting the relevant information from the dataset, and then call the constructor we defined above:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">from_dataset</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
    <span class="c1"># We get the key geometry information from the dataset</span>
    <span class="n">wavelength</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">wavelength</span>
    <span class="n">det_basis</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">detector_geometry</span><span class="p">[</span><span class="s1">'basis'</span><span class="p">]</span>
    <span class="n">det_shape</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
    <span class="n">distance</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">detector_geometry</span><span class="p">[</span><span class="s1">'distance'</span><span class="p">]</span>

    <span class="c1"># Then, we generate the probe geometry</span>
    <span class="n">ewg</span> <span class="o">=</span> <span class="n">tools</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">exit_wave_geometry</span>
    <span class="n">probe_basis</span> <span class="o">=</span> <span class="n">ewg</span><span class="p">(</span><span class="n">det_basis</span><span class="p">,</span> <span class="n">det_shape</span><span class="p">,</span> <span class="n">wavelength</span><span class="p">,</span> <span class="n">distance</span><span class="p">)</span>

    <span class="c1"># Next generate the object geometry from the probe geometry and</span>
    <span class="c1"># the translations</span>
    <span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="n">translations</span><span class="p">),</span> <span class="n">patterns</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[:]</span>
    <span class="n">pix_translations</span> <span class="o">=</span> <span class="n">tools</span><span class="o">.</span><span class="n">interactions</span><span class="o">.</span><span class="n">translations_to_pixel</span><span class="p">(</span>
        <span class="n">probe_basis</span><span class="p">,</span>
        <span class="n">translations</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">obj_size</span><span class="p">,</span> <span class="n">min_translation</span> <span class="o">=</span> <span class="n">tools</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">calc_object_setup</span><span class="p">(</span>
        <span class="n">det_shape</span><span class="p">,</span>
        <span class="n">pix_translations</span><span class="p">,</span>
    <span class="p">)</span>

    <span class="c1"># Finally, initialize the probe and object using this information</span>
    <span class="n">probe</span> <span class="o">=</span> <span class="n">tools</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">SHARP_style_probe</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
    <span class="n">obj</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">obj_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">complex64</span><span class="p">)</span>

    <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span>
        <span class="n">wavelength</span><span class="p">,</span>
        <span class="n">probe_basis</span><span class="p">,</span>
        <span class="n">probe</span><span class="p">,</span>
        <span class="n">obj</span><span class="p">,</span>
        <span class="n">min_translation</span><span class="o">=</span><span class="n">min_translation</span>
    <span class="p">)</span>
</pre></div>
</div>
<p>Here, we start by pulling the basic geometric information from the dataset. Then, we use a number of the basic tools to do calculations such as finding the probe basis from the detector geometry, or calculating how big our object array should be.</p>
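<p>As a rough illustration of the object-size calculation, the hypothetical function below finds the smallest object array that contains every probe window, given translations already converted to pixel units. It is a sketch under that assumption, not the source of <code class="code docutils literal notranslate"><span class="pre">tools.initializers.calc_object_setup</span></code>, which may, for example, add extra padding.</p>

```python
import torch as t


def object_setup_sketch(probe_shape, pix_translations):
    """Hypothetical stand-in for an object-size calculation.

    pix_translations: (N, 2) pixel offsets of the probe window.
    Returns the object shape needed to contain every window, and the
    minimum translation, which is subtracted later so offsets start at 0.
    """
    min_translation = t.min(pix_translations, dim=0).values
    max_translation = t.max(pix_translations, dim=0).values
    # Round up so that fractional offsets still fit inside the array
    extent = t.ceil(max_translation - min_translation).to(t.int64)
    obj_shape = tuple(int(e) + s for e, s in zip(extent, probe_shape))
    return obj_shape, min_translation


shape, min_t = object_setup_sketch(
    (4, 4), t.tensor([[0.0, 0.0], [2.5, 3.0], [-1.0, 1.0]]))
print(shape, min_t)
```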
<p>Once we have the basic setup ready, we use one of the initialization functions (in this case, <code class="code docutils literal notranslate"><span class="pre">tools.initializers.SHARP_style_probe</span></code>) to find a sensible initialization for the probe. This particular initialization is based on the approach used in the SHARP package, where the square root of the mean diffraction pattern intensity is used to estimate the structure of the illumination at focus.</p>
<p>Finally, we initialize the object as a uniform array of ones and pass everything we have collected to the constructor.</p>
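<p>The idea behind the SHARP-style probe initialization can be sketched in a few lines: take the square root of the mean intensity as a far-field amplitude, assign it a flat phase, and propagate it back to the sample plane. This is a hypothetical illustration, not the cdtools code; it assumes the patterns are stored with the zero frequency at the array center and uses a centered, orthonormal inverse FFT.</p>

```python
import torch as t


def sharp_style_probe_sketch(patterns):
    """Rough illustration of a SHARP-style probe guess (not cdtools' code).

    patterns: (N, H, W) real tensor of far-field intensities, assumed to
    have the zero frequency at the array center.
    """
    # The square root of the mean intensity estimates the far-field amplitude
    amplitude = t.sqrt(t.mean(patterns, dim=0))
    # Assign a flat phase and propagate back to the sample plane
    spectrum = t.fft.ifftshift(amplitude.to(t.complex64), dim=(-2, -1))
    return t.fft.fftshift(t.fft.ifft2(spectrum, norm='ortho'), dim=(-2, -1))


pats = t.rand(3, 8, 8)
probe = sharp_style_probe_sketch(pats)
print(probe.shape, probe.dtype)
```

<p>With an orthonormal transform, the total intensity of this probe equals the total intensity of the mean diffraction pattern, so the guess also starts with a sensible overall scale.</p>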
</section>
<section id="the-forward-model">
<h3>The Forward Model<a class="headerlink" href="#the-forward-model" title="Link to this heading"></a></h3>
<p>First, we have to implement the interaction model, as below:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">interaction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">,</span> <span class="n">translations</span><span class="p">):</span>
    <span class="c1"># We map from real-space to pixel-space units</span>
    <span class="n">pix_trans</span> <span class="o">=</span> <span class="n">tools</span><span class="o">.</span><span class="n">interactions</span><span class="o">.</span><span class="n">translations_to_pixel</span><span class="p">(</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="p">,</span>
        <span class="n">translations</span><span class="p">)</span>
    <span class="n">pix_trans</span> <span class="o">-=</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_translation</span>

    <span class="c1"># This function extracts the appropriate window from the object and</span>
    <span class="c1"># multiplies the object and probe functions</span>
    <span class="k">return</span> <span class="n">tools</span><span class="o">.</span><span class="n">interactions</span><span class="o">.</span><span class="n">ptycho_2D_round</span><span class="p">(</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">probe_norm</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe</span><span class="p">,</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">obj</span><span class="p">,</span>
        <span class="n">pix_trans</span><span class="p">)</span>
</pre></div>
</div>
<p>Here, we take input in the form of an index and a translation. Note that this input format must match the format output by the associated dataset that we will use for reconstruction, in this case BasicPtychoDataset.</p>
<p>We start by mapping the translation, given in real space, into pixel coordinates, and subtract the stored minimum translation so that the offsets index into the object array. Then, we use an “off-the-shelf” interaction model, <code class="code docutils literal notranslate"><span class="pre">ptycho_2D_round</span></code>, which models a standard 2D ptychography interaction but rounds the translations to the nearest whole pixel rather than attempting subpixel shifts.</p>
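<p>The two steps described above can be sketched in plain pytorch. Both functions are hypothetical stand-ins written for illustration; in particular, the (3, 2) real-space basis convention and the use of a pseudoinverse are assumptions, not a description of how cdtools implements <code class="code docutils literal notranslate"><span class="pre">translations_to_pixel</span></code> or <code class="code docutils literal notranslate"><span class="pre">ptycho_2D_round</span></code>. The subtraction of the minimum translation is omitted for brevity.</p>

```python
import torch as t


def translations_to_pixel_sketch(basis, translations):
    """Hypothetical: solve basis @ pix = translation for pixel offsets.

    basis: (3, 2), the real-space step of one pixel along each array axis.
    translations: (N, 3) real-space positions.
    """
    # A 3x2 basis is not square, so use a least-squares (pseudoinverse) solve
    return translations @ t.linalg.pinv(basis).T


def ptycho_2d_round_sketch(probe, obj, pix_translations):
    """Multiply the probe by the object window at each rounded offset."""
    h, w = probe.shape[-2:]
    offsets = t.round(pix_translations).to(t.int64).tolist()
    windows = [obj[i:i + h, j:j + w] for i, j in offsets]
    return probe * t.stack(windows)


basis = t.tensor([[1e-6, 0.0], [0.0, 1e-6], [0.0, 0.0]], dtype=t.float64)
trans = t.tensor([[2e-6, 3e-6, 0.0], [0.4e-6, 1.6e-6, 0.0]], dtype=t.float64)
pix = translations_to_pixel_sketch(basis, trans)
exit_waves = ptycho_2d_round_sketch(t.ones(2, 2), t.arange(64.0).reshape(8, 8), pix)
print(pix)
print(exit_waves)
```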
<p>The next three definitions amount to just choosing an off-the-shelf function to simulate each step in the chain.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward_propagator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">wavefields</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tools</span><span class="o">.</span><span class="n">propagators</span><span class="o">.</span><span class="n">far_field</span><span class="p">(</span><span class="n">wavefields</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">measurement</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">wavefields</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tools</span><span class="o">.</span><span class="n">measurements</span><span class="o">.</span><span class="n">intensity</span><span class="p">(</span><span class="n">wavefields</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sim_data</span><span class="p">,</span> <span class="n">real_data</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tools</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">amplitude_mse</span><span class="p">(</span><span class="n">real_data</span><span class="p">,</span> <span class="n">sim_data</span><span class="p">)</span>
</pre></div>
</div>
<p>The forward propagator maps the exit wave to the wavefield at the surface of the detector, here using a far-field propagator. The measurement maps that wavefield to measured intensities, and the loss defines the quantity that the reconstruction attempts to minimize. The loss function we’ve chosen, the amplitude mean squared error, compares the square roots of the measured and simulated intensities. It tends to be the most reliable choice, and it can easily be overridden by an end user.</p>
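<p>For intuition, the three off-the-shelf stages can be approximated in a few lines each. These are hypothetical stand-ins, not the cdtools functions, whose normalization and shift conventions may differ; here the far-field propagator is assumed to be a centered, orthonormal FFT and the loss a plain mean over all pixels.</p>

```python
import torch as t


def far_field_sketch(wavefields):
    # Centered, orthonormal 2D FFT over the last two dimensions
    shifted = t.fft.ifftshift(wavefields, dim=(-2, -1))
    return t.fft.fftshift(t.fft.fft2(shifted, norm='ortho'), dim=(-2, -1))


def intensity_sketch(wavefields):
    # Detectors measure |E|^2, discarding the phase
    return t.abs(wavefields) ** 2


def amplitude_mse_sketch(real_data, sim_data):
    # Compare amplitudes (square roots of intensities), not raw intensities
    return t.mean((t.sqrt(real_data) - t.sqrt(sim_data)) ** 2)


wave = t.randn(3, 8, 8, dtype=t.complex64)
sim = intensity_sketch(far_field_sketch(wave))
# The orthonormal FFT conserves total intensity (Parseval's theorem)
print(sim.sum().item(), intensity_sketch(wave).sum().item())
```

<p>Comparing amplitudes rather than intensities keeps bright and dim pixels on a more even footing, since the square root compresses the dynamic range of the patterns.</p>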
</section>
<section id="plotting">
<h3>Plotting<a class="headerlink" href="#plotting" title="Link to this heading"></a></h3>
<p>The base CDIModel class has a function, <code class="code docutils literal notranslate"><span class="pre">model.inspect()</span></code>, which looks for a class variable called <code class="code docutils literal notranslate"><span class="pre">plot_list</span></code> and plots everything contained within. The plot list should be formatted as a list of tuples, with each tuple containing:</p>
<ul class="simple">
<li><p>The title of the plot</p></li>
<li><p>A function that takes in the model and a figure, and draws the relevant plot into that figure</p></li>
<li><p>Optionally, a function that takes in the model and returns whether or not the plot should be generated</p></li>
</ul>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># This lists all the plots to display on a call to model.inspect()</span>
<span class="n">plot_list</span> <span class="o">=</span> <span class="p">[</span>
    <span class="p">(</span><span class="s1">'Probe Amplitude'</span><span class="p">,</span>
     <span class="k">lambda</span> <span class="bp">self</span><span class="p">,</span> <span class="n">fig</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">plot_amplitude</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">probe</span><span class="p">,</span> <span class="n">fig</span><span class="o">=</span><span class="n">fig</span><span class="p">,</span> <span class="n">basis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="p">)),</span>
    <span class="p">(</span><span class="s1">'Probe Phase'</span><span class="p">,</span>
     <span class="k">lambda</span> <span class="bp">self</span><span class="p">,</span> <span class="n">fig</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">plot_phase</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">probe</span><span class="p">,</span> <span class="n">fig</span><span class="o">=</span><span class="n">fig</span><span class="p">,</span> <span class="n">basis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="p">)),</span>
    <span class="p">(</span><span class="s1">'Object Amplitude'</span><span class="p">,</span>
     <span class="k">lambda</span> <span class="bp">self</span><span class="p">,</span> <span class="n">fig</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">plot_amplitude</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obj</span><span class="p">,</span> <span class="n">fig</span><span class="o">=</span><span class="n">fig</span><span class="p">,</span> <span class="n">basis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="p">)),</span>
    <span class="p">(</span><span class="s1">'Object Phase'</span><span class="p">,</span>
     <span class="k">lambda</span> <span class="bp">self</span><span class="p">,</span> <span class="n">fig</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">plot_phase</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">obj</span><span class="p">,</span> <span class="n">fig</span><span class="o">=</span><span class="n">fig</span><span class="p">,</span> <span class="n">basis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="p">))</span>
<span class="p">]</span>
</pre></div>
</div>
<p>In this case, we’ve made use of the convenience plotting functions defined in <code class="code docutils literal notranslate"><span class="pre">tools.plotting</span></code>.</p>
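<p>To show how the optional third element of each tuple can be used, here is a toy consumer of a <code class="code docutils literal notranslate"><span class="pre">plot_list</span></code>. The names <code class="code docutils literal notranslate"><span class="pre">ToyModel</span></code> and <code class="code docutils literal notranslate"><span class="pre">toy_inspect</span></code> are hypothetical, and this is a simplified sketch of the pattern rather than the actual <code class="code docutils literal notranslate"><span class="pre">model.inspect()</span></code> implementation.</p>

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen so the sketch runs headless
from matplotlib import pyplot as plt
import torch as t


class ToyModel:
    def __init__(self):
        self.obj = t.rand(8, 8)
        self.background = None  # hypothetical optional component

    plot_list = [
        ('Object Amplitude',
         lambda self, fig: fig.add_subplot().imshow(self.obj.abs().numpy())),
        ('Background',
         lambda self, fig: fig.add_subplot().imshow(self.background.numpy()),
         # Only plot the background if the model actually has one
         lambda self: self.background is not None),
    ]


def toy_inspect(model):
    """Simplified, hypothetical version of how a plot_list could be consumed."""
    figures = {}
    for entry in model.plot_list:
        title, plot_func = entry[0], entry[1]
        # The optional third element decides whether to draw this plot at all
        if len(entry) == 3 and not entry[2](model):
            continue
        fig = plt.figure()
        plot_func(model, fig)
        fig.suptitle(title)
        figures[title] = fig
    return figures


print(list(toy_inspect(ToyModel())))  # ['Object Amplitude']
```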
</section>
<section id="saving">
<h3>Saving<a class="headerlink" href="#saving" title="Link to this heading"></a></h3>
<p>By default, a function <code class="code docutils literal notranslate"><span class="pre">model.save_results()</span></code> is defined, which returns a python dictionary with an entry, <code class="code docutils literal notranslate"><span class="pre">'state_dict'</span></code>, containing all the registered parameters and buffers in the model. It also contains a basic record of the model’s training history. This function is used internally by <code class="code docutils literal notranslate"><span class="pre">model.save_to_h5()</span></code>, as well as all other convenience functions for saving results.</p>
<p>Sometimes, it is also useful to return a more user-friendly version of the results, such as a properly rescaled version of the probe. To make this possible, <code class="code docutils literal notranslate"><span class="pre">model.save_results()</span></code> is often overridden:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">save_results</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
    <span class="c1"># This will save out everything needed to recreate the object</span>
    <span class="c1"># in the same state, but it's not the best formatted.</span>
    <span class="n">base_results</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">save_results</span><span class="p">()</span>

    <span class="c1"># So we also save out the main results in a more usable format</span>
    <span class="n">probe_basis</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_basis</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">probe</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">probe</span> <span class="o">=</span> <span class="n">probe</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_norm</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">wavelength</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">wavelength</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>

    <span class="n">results</span> <span class="o">=</span> <span class="p">{</span>
        <span class="s1">'probe_basis'</span><span class="p">:</span> <span class="n">probe_basis</span><span class="p">,</span>
        <span class="s1">'probe'</span><span class="p">:</span> <span class="n">probe</span><span class="p">,</span>
        <span class="s1">'obj'</span><span class="p">:</span> <span class="n">obj</span><span class="p">,</span>
        <span class="s1">'wavelength'</span><span class="p">:</span> <span class="n">wavelength</span><span class="p">,</span>
    <span class="p">}</span>

    <span class="k">return</span> <span class="p">{</span><span class="o">**</span><span class="n">base_results</span><span class="p">,</span> <span class="o">**</span><span class="n">results</span><span class="p">}</span>
</pre></div>
</div>
<p>However, it is perfectly possible to write a new ptychography model without overriding <code class="code docutils literal notranslate"><span class="pre">model.save_results()</span></code>.</p>
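<p>The override above relies on two standard mechanisms, sketched here with a hypothetical <code class="code docutils literal notranslate"><span class="pre">Tiny</span></code> module: a <code class="code docutils literal notranslate"><span class="pre">state_dict</span></code> that restores parameters and buffers into a fresh instance, and <code class="code docutils literal notranslate"><span class="pre">{**a, **b}</span></code> dictionary merging, in which entries from the right-hand dictionary take precedence.</p>

```python
import torch as t


class Tiny(t.nn.Module):
    """Hypothetical stand-in for a model with one buffer and one parameter."""

    def __init__(self):
        super().__init__()
        self.register_buffer('probe_norm', t.as_tensor(5.0))
        self.probe = t.nn.Parameter(t.rand(4, 4))


trained = Tiny()
base_results = {'state_dict': trained.state_dict()}

# The state_dict alone is enough to restore a fresh instance exactly
restored = Tiny()
restored.load_state_dict(base_results['state_dict'])

# In {**a, **b}, entries from b are added, and win if a key appears in both
merged_example = {**{'a': 1, 'b': 2}, **{'b': 3}}
print(merged_example)  # {'a': 1, 'b': 3}

# A user-friendly extra entry: the probe with its normalization undone
results = {'probe': (trained.probe_norm * trained.probe).detach().cpu().numpy()}
full = {**base_results, **results}
```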
</section>
<section id="testing">
<h3>Testing<a class="headerlink" href="#testing" title="Link to this heading"></a></h3>
<p>We can test this model with a simple script, in examples/tutorial_finale.py. By filling in the backend here, we’ve been able to create a ptychography model that can be accessed and used in reconstructions via the same interface as the models we discussed in the examples section.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">tutorial_basic_ptycho_dataset</span> <span class="kn">import</span> <span class="n">BasicPtychoDataset</span>
<span class="kn">from</span> <span class="nn">tutorial_simple_ptycho</span> <span class="kn">import</span> <span class="n">SimplePtycho</span>
<span class="kn">from</span> <span class="nn">h5py</span> <span class="kn">import</span> <span class="n">File</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="n">filename</span> <span class="o">=</span> <span class="s1">'example_data/lab_ptycho_data.cxi'</span>
<span class="k">with</span> <span class="n">File</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="n">BasicPtychoDataset</span><span class="o">.</span><span class="n">from_cxi</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>

<span class="n">dataset</span><span class="o">.</span><span class="n">inspect</span><span class="p">()</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">SimplePtycho</span><span class="o">.</span><span class="n">from_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>

<span class="c1"># Move everything to the GPU. Use 'mps' on Apple silicon, or 'cpu' without a GPU</span>
<span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">dataset</span><span class="o">.</span><span class="n">get_as</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>

<span class="k">for</span> <span class="n">loss</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">Adam_optimize</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span>
    <span class="n">model</span><span class="o">.</span><span class="n">inspect</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">report</span><span class="p">())</span>

<span class="n">model</span><span class="o">.</span><span class="n">inspect</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">compare</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
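<p>Rather than hard-coding a device string in the script, the device could be chosen at runtime. The helper below is a hypothetical sketch using standard pytorch availability checks; it returns a plain string, which could then be passed to <code class="code docutils literal notranslate"><span class="pre">model.to(device=...)</span></code> and <code class="code docutils literal notranslate"><span class="pre">dataset.get_as(device=...)</span></code> in place of the literal.</p>

```python
import torch as t


def pick_device():
    """Hypothetical helper: prefer CUDA, then Apple's MPS, then the CPU."""
    if t.cuda.is_available():
        return 'cuda'
    # torch.backends.mps is only present in newer pytorch releases
    mps = getattr(t.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return 'mps'
    return 'cpu'


device = pick_device()
print(device)
```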
<p>Happy modeling!</p>
</section>
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="examples.html" class="btn btn-neutral float-left" title="Examples" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="general.html" class="btn btn-neutral float-right" title="General Reference" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>© Copyright 2019-2024, Abraham Levitan.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>