Skip to content

Commit d57015b

Browse files
committed
added a doctest for the retry wrapper, and used the retry wrapper to try and make the gaussian process experiment more robust
1 parent 5817504 commit d57015b

4 files changed

Lines changed: 70 additions & 30 deletions

File tree

adbo/gp_exp.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import gpflow
99
from sithom.time import timeit
1010
from sithom.plot import plot_defaults, label_subplots, get_dim
11+
from worst.utils import retry_wrapper
1112
from .constants import DATA_PATH, EXP_PATH
1213

1314
# from tf.keras.metrics import R2Score, RootMeanSquaredError
@@ -199,6 +200,7 @@ def fit_gp(
199200

200201

201202
@timeit
203+
@retry_wrapper(max_retries=10)
202204
def run_single_fit(
203205
norm_x: bool = True,
204206
norm_y: bool = True,
@@ -298,6 +300,13 @@ def run_exp(
298300
- Fit the GP model to the training set.
299301
- Evaluate the model on the test set (log liklihood, r2, rmse).
300302
- Save the results in a csv file and a tex file.
303+
304+
Args:
305+
kernels (tuple, optional): Kernels to use. Defaults to ("Matern52", "Matern12", "SE", "RationalQuadratic").
306+
mean_functions (tuple, optional): Mean functions to use. Defaults to ("Constant", "Linear", "Polynomial", "Zero", "Exponential", "Periodic").
307+
n_train (int, optional): Number of training samples. Defaults to 25.
308+
repeats (int, optional): Number of repeats for each experiment. Defaults to 100.
309+
301310
"""
302311
results = pd.DataFrame(
303312
columns=[
@@ -370,59 +379,64 @@ def run_exp(
370379
ignore_index=True,
371380
)
372381

382+
# write plain CSV with all columns
383+
results.to_csv(
384+
os.path.join(DATA_PATH, f"gp_results_n_train_{n_train}.csv"),
385+
index=False,
386+
)
387+
save_results_tex(n_train=n_train)
388+
389+
390+
def save_results_tex(n_train: int = 25):
391+
df = pd.read_csv(os.path.join(DATA_PATH, f"gp_results_n_train_{n_train}.csv"))
392+
373393
# --- format <mean ± SEM> strings for LaTeX output, bolding the best ---
374-
pretty = results.copy()
375394

376395
# nice kernel names for LaTeX
377-
pretty["kernel"] = (
378-
pretty["kernel"]
379-
.str.replace("Matern52", r"Matern-\(\frac{5}{2}\)")
380-
.str.replace("Matern32", r"Matern-\(\frac{3}{2}\)")
381-
.str.replace("Matern12", r"Matern-\(\frac{1}{2}\)")
396+
df["kernel"] = (
397+
df["kernel"]
398+
.str.replace("Matern52", r"Mat\'ern-\({5}/{2}\)")
399+
.str.replace("Matern32", r"Mat\'ern-\({3}/{2}\)")
400+
.str.replace("Matern12", r"Mat\'ern-\({1}/{2}\)")
382401
.str.replace("SE", "SE")
383402
)
384403

385404
# find best values
386405
best_vals = {
387-
"log_likelihood": results["log_likelihood"].max(), # highest is best
388-
"rmse": results["rmse"].min(), # lowest is best
389-
"r2": results["r2"].max(), # highest is best
406+
"log_likelihood": df["log_likelihood"].max(), # highest is best
407+
"rmse": df["rmse"].min(), # lowest is best
408+
"r2": df["r2"].max(), # highest is best
390409
}
391410

392411
# helper: format mean ± sem, bold if mean equals best (allow tiny tol)
393-
def fmt_val(metric, mean, sem):
394-
base = f"{mean:.3f} $\\pm$ {sem:.3f}"
412+
def fmt_val(metric: str, mean: float, sem: float) -> str:
413+
base = f"{mean:.3f} \\(\\pm\\) {sem:.3f}"
395414
if abs(mean - best_vals[metric]) < 1e-9:
396-
return f"\\textbf{{{mean:.3f}}} \\(\\pm\\) \\textbf{{sem:.3f}}"
415+
return f"\\textbf{{{mean:.3f}}} \\(\\pm\\) \\textbf{{{sem:.3f}}}"
397416
return base
398417

399418
for metric in ("log_likelihood", "rmse", "r2"):
400419
sem_col = metric + "_sem"
401-
pretty[metric] = pretty.apply(
420+
df[metric] = df.apply(
402421
lambda row, m=metric: fmt_val(m, row[m], row[sem_col]), axis=1
403422
)
404-
pretty = pretty.drop(columns=[sem_col])
423+
df = df.drop(columns=[sem_col])
405424

406425
# rename columns for LaTeX header
407-
pretty = pretty.rename(
426+
df = df.rename(
408427
columns={
409-
"log_likelihood": r"Mean Log Likelihood \(\bar{\mathcal{L}}\) ",
428+
"log_likelihood": r"\(\bar{\mathcal{L}}\) ",
410429
"rmse": "RMSE [m]",
411430
"r2": r"\(r^2\)",
412431
"kernel": "Kernel",
413-
"mean_function": "Mean Function",
414-
"norm_x": "Norm. $x$",
415-
"norm_y": "Norm. $y$",
432+
"mean_function": "Mean F.",
433+
"norm_x": r"Norm. \(x\)",
434+
"norm_y": r"Norm. \(y\)",
416435
}
417436
)
418437

419-
# write plain CSV with all columns
420-
results.to_csv(
421-
os.path.join(DATA_PATH, f"gp_results_n_train_{n_train}.csv"),
422-
index=False,
423-
)
424-
# write LaTeX table with pretty strings
425-
pretty.to_latex(
438+
# write LaTeX table with df strings
439+
df.to_latex(
426440
os.path.join(DATA_PATH, f"gp_results_n_train_{n_train}.tex"),
427441
index=False,
428442
escape=False,
@@ -434,6 +448,8 @@ def fmt_val(metric, mean, sem):
434448
# python -m adbo.gp_exp &> gp_exp.log
435449
# gather_data()
436450
# run_exp(kernels=("Matern52",), mean_functions=("Constant",), n_train=25)
451+
# save_results_tex(n_train=50) # 100)
452+
437453
run_exp(
438454
kernels=(
439455
"Matern52",
@@ -455,5 +471,5 @@ def fmt_val(metric, mean, sem):
455471
# "Exponential",
456472
# "Periodic",
457473
),
458-
n_train=25,
474+
n_train=50,
459475
)

pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# content of pytest.ini
22
[pytest]
33

4-
addopts = --doctest-modules --ignore=tcpips/exp --ignore=tcpips/pangeo.py --ignore=tcpips/example --ignore=tcpips/regrid.py --ignore=worst --ignore=chavas15 --ignore=adforce/fort15.py --ignore=adforce/oldwrap.py --ignore=ipynb --ignore=img --ignore=.secret --ignore=.vscode --ignore=w22/ps.py
4+
addopts = --doctest-modules --ignore=tcpips/exp --ignore=tcpips/pangeo.py --ignore=tcpips/example --ignore=tcpips/regrid.py --ignore=worst/tens.py --ignore=worst/vary_gamma_beta.py --ignore=worst/vary_noise.py --ignore=vary_samples_ns.py --ignore=chavas15 --ignore=adforce/fort15.py --ignore=adforce/oldwrap.py --ignore=ipynb --ignore=img --ignore=.secret --ignore=.vscode --ignore=w22/ps.py
55

66
doctest_encoding = latin1
77
doctest_optionflags = NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL

worst/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,28 @@ def retry_wrapper(max_retries: int = 10) -> callable:
145145
146146
Returns:
147147
callable: Wrapper function to recall the function in case of failure.
148+
149+
Example::
150+
>>> @retry_wrapper(max_retries=5)
151+
... def my_function():
152+
... # Some code that will fail
153+
... assert False
154+
>>> try:
155+
... my_function()
156+
... except Exception as e:
157+
... print(e)
158+
Exception:
159+
Retrying my_function 1/5
160+
Exception:
161+
Retrying my_function 2/5
162+
Exception:
163+
Retrying my_function 3/5
164+
Exception:
165+
Retrying my_function 4/5
166+
Exception:
167+
Retrying my_function 5/5
168+
Max retries exceeded for function my_function
169+
148170
"""
149171

150172
def retry_decorator(func: callable) -> callable:

worst/vary_gamma_beta.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Vary the shape and scale parameters of the GEV distribution."""
2+
13
import os
24
from typing import List
35
import hydra
@@ -32,8 +34,8 @@ def try_fit(
3234
3335
Args:
3436
z_star (float): z_star.
35-
beta (float): beta.
36-
gamma (float): gamma.
37+
beta (float): beta. Scale parameter.
38+
gamma (float): gamma. Shape parameter.
3739
ns (int): Number of samples.
3840
seed (int): Seed.
3941
quantiles (List[float], optional): Quantiles. Defaults to [1/100, 1/500].

0 commit comments

Comments
 (0)