88import gpflow
99from sithom .time import timeit
1010from sithom .plot import plot_defaults , label_subplots , get_dim
11+ from worst .utils import retry_wrapper
1112from .constants import DATA_PATH , EXP_PATH
1213
1314# from tf.keras.metrics import R2Score, RootMeanSquaredError
@@ -199,6 +200,7 @@ def fit_gp(
199200
200201
201202@timeit
203+ @retry_wrapper (max_retries = 10 )
202204def run_single_fit (
203205 norm_x : bool = True ,
204206 norm_y : bool = True ,
@@ -298,6 +300,13 @@ def run_exp(
298300 - Fit the GP model to the training set.
299301 - Evaluate the model on the test set (log likelihood, r2, rmse).
300302 - Save the results in a csv file and a tex file.
303+
304+ Args:
305+ kernels (tuple, optional): Kernels to use. Defaults to ("Matern52", "Matern12", "SE", "RationalQuadratic").
306+ mean_functions (tuple, optional): Mean functions to use. Defaults to ("Constant", "Linear", "Polynomial", "Zero", "Exponential", "Periodic").
307+ n_train (int, optional): Number of training samples. Defaults to 25.
308+ repeats (int, optional): Number of repeats for each experiment. Defaults to 100.
309+
301310 """
302311 results = pd .DataFrame (
303312 columns = [
@@ -370,59 +379,64 @@ def run_exp(
370379 ignore_index = True ,
371380 )
372381
382+ # write plain CSV with all columns
383+ results .to_csv (
384+ os .path .join (DATA_PATH , f"gp_results_n_train_{ n_train } .csv" ),
385+ index = False ,
386+ )
387+ save_results_tex (n_train = n_train )
388+
389+
390+ def save_results_tex (n_train : int = 25 ):
391+ df = pd .read_csv (os .path .join (DATA_PATH , f"gp_results_n_train_{ n_train } .csv" ))
392+
373393 # --- format <mean ± SEM> strings for LaTeX output, bolding the best ---
374- pretty = results .copy ()
375394
376395 # nice kernel names for LaTeX
377- pretty ["kernel" ] = (
378- pretty ["kernel" ]
379- .str .replace ("Matern52" , r"Matern -\(\frac {5}{2}\)" )
380- .str .replace ("Matern32" , r"Matern -\(\frac {3}{2}\)" )
381- .str .replace ("Matern12" , r"Matern -\(\frac {1}{2}\)" )
396+ df ["kernel" ] = (
397+ df ["kernel" ]
398+ .str .replace ("Matern52" , r"Mat\'ern -\({5}/ {2}\)" )
399+ .str .replace ("Matern32" , r"Mat\'ern -\({3}/ {2}\)" )
400+ .str .replace ("Matern12" , r"Mat\'ern -\({1}/ {2}\)" )
382401 .str .replace ("SE" , "SE" )
383402 )
384403
385404 # find best values
386405 best_vals = {
387- "log_likelihood" : results ["log_likelihood" ].max (), # highest is best
388- "rmse" : results ["rmse" ].min (), # lowest is best
389- "r2" : results ["r2" ].max (), # highest is best
406+ "log_likelihood" : df ["log_likelihood" ].max (), # highest is best
407+ "rmse" : df ["rmse" ].min (), # lowest is best
408+ "r2" : df ["r2" ].max (), # highest is best
390409 }
391410
392411 # helper: format mean ± sem, bold if mean equals best (allow tiny tol)
393- def fmt_val (metric , mean , sem ) :
394- base = f"{ mean :.3f} $ \\ pm$ { sem :.3f} "
412+ def fmt_val (metric : str , mean : float , sem : float ) -> str :
413+ base = f"{ mean :.3f} \\ ( \\ pm \\ ) { sem :.3f} "
395414 if abs (mean - best_vals [metric ]) < 1e-9 :
396- return f"\\ textbf{{{ mean :.3f} }} \\ (\\ pm\\ ) \\ textbf{{sem:.3f}}"
415+ return f"\\ textbf{{{ mean :.3f} }} \\ (\\ pm\\ ) \\ textbf{{{ sem :.3f} }}"
397416 return base
398417
399418 for metric in ("log_likelihood" , "rmse" , "r2" ):
400419 sem_col = metric + "_sem"
401- pretty [metric ] = pretty .apply (
420+ df [metric ] = df .apply (
402421 lambda row , m = metric : fmt_val (m , row [m ], row [sem_col ]), axis = 1
403422 )
404- pretty = pretty .drop (columns = [sem_col ])
423+ df = df .drop (columns = [sem_col ])
405424
406425 # rename columns for LaTeX header
407- pretty = pretty .rename (
426+ df = df .rename (
408427 columns = {
409- "log_likelihood" : r"Mean Log Likelihood \(\bar{\mathcal{L}}\) " ,
428+ "log_likelihood" : r"\(\bar{\mathcal{L}}\) " ,
410429 "rmse" : "RMSE [m]" ,
411430 "r2" : r"\(r^2\)" ,
412431 "kernel" : "Kernel" ,
413- "mean_function" : "Mean Function " ,
414- "norm_x" : "Norm. $x$ " ,
415- "norm_y" : "Norm. $y$ " ,
432+ "mean_function" : "Mean F. " ,
433+ "norm_x" : r "Norm. \(x\) " ,
434+ "norm_y" : r "Norm. \(y\) " ,
416435 }
417436 )
418437
419- # write plain CSV with all columns
420- results .to_csv (
421- os .path .join (DATA_PATH , f"gp_results_n_train_{ n_train } .csv" ),
422- index = False ,
423- )
424- # write LaTeX table with pretty strings
425- pretty .to_latex (
438+ # write LaTeX table with the formatted mean ± SEM strings
439+ df .to_latex (
426440 os .path .join (DATA_PATH , f"gp_results_n_train_{ n_train } .tex" ),
427441 index = False ,
428442 escape = False ,
@@ -434,6 +448,8 @@ def fmt_val(metric, mean, sem):
434448 # python -m adbo.gp_exp &> gp_exp.log
435449 # gather_data()
436450 # run_exp(kernels=("Matern52",), mean_functions=("Constant",), n_train=25)
451+ # save_results_tex(n_train=50) # 100)
452+
437453 run_exp (
438454 kernels = (
439455 "Matern52" ,
@@ -455,5 +471,5 @@ def fmt_val(metric, mean, sem):
455471 # "Exponential",
456472 # "Periodic",
457473 ),
458- n_train = 25 ,
474+ n_train = 50 ,
459475 )
0 commit comments