diff --git a/tc1d/tc1d.py b/tc1d/tc1d.py index 1a3adb0..88a418b 100644 --- a/tc1d/tc1d.py +++ b/tc1d/tc1d.py @@ -14,6 +14,11 @@ from scipy.linalg import solve from sklearn.model_selection import ParameterGrid from neighpy import NASearcher, NAAppraiser +import sys # BG: Required for MPI shutdown +import emcee # BG: For MCMC sampling +import copy +from emcee.mpi_pool import MPIPool # BG: For parallel MCMC with MPI +import corner # BG: Corner plots for MCMC # Import madtrax functions from madtrax import madtrax_apatite, madtrax_zircon @@ -613,37 +618,49 @@ def calculate_erosion_rate( vx_surf = vx_array[0] vx_max = vx_surf - # Constant erosion rate with a step-function change at a specified time - # Convert to inputting rates directly? elif params["ero_type"] == 2: - interval1 = myr2sec(params["ero_option2"]) - rate1 = kilo2base(params["ero_option1"]) / interval1 - transition_time1 = myr2sec(params["ero_option2"]) - # Handle case where ero_option4 and ero_option5 are not specified - if abs(params["ero_option4"]) <= 1.0e-8: - # Set ero_option4 to model duration - interval2 = t_total - myr2sec(params["ero_option2"]) - rate2 = kilo2base(params["ero_option3"]) / interval2 - rate3 = 0.0 - transition_time2 = t_total - else: - # Third rate/interval used - interval2 = myr2sec(params["ero_option4"] - params["ero_option2"]) - rate2 = kilo2base(params["ero_option3"]) / interval2 - interval3 = t_total - myr2sec(params["ero_option4"]) - rate3 = kilo2base(params["ero_option5"]) / interval3 - transition_time2 = myr2sec(params["ero_option4"]) - # First stage of erosion - if current_time < transition_time1: - vx_array[:] = rate1 - # Second stage of erosion - elif current_time < transition_time2: - vx_array[:] = rate2 - # Third stage of erosion - else: - vx_array[:] = rate3 + # BG: Extended step-function erosion model to support multiple intervals + # Collect erosion thicknesses (odd options) and transition times (even options) + thicknesses = [] + transition_times_sec = [] + total_time_Myr = params["t_total"] if isinstance(params["t_total"], float) else params["t_total"][0] + # Ensure total_time_Myr is in Myr (params["t_total"] is likely already in Myr here) + for opt_idx in range(1, 11, 2): # Check ero_option1,3,5,7,9 + thick = params.get(f"ero_option{opt_idx}", 0.0) + trans_idx = opt_idx + 1 + # Stop if no corresponding transition or transition is 0 (end of intervals) + if trans_idx > 10: + thicknesses.append(thick) + break + trans_time = params.get(f"ero_option{trans_idx}", 0.0) + # If transition time is zero or beyond total time, this is the final interval + if (not trans_time) or (trans_time >= total_time_Myr - 1e-8): + thicknesses.append(thick) + break + # Otherwise, record this interval and continue + thicknesses.append(thick) + transition_times_sec.append(myr2sec(trans_time)) + # Compute erosion rates (m/s) for each interval + rates = [] + prev_time = 0.0 + for j, thick_km in enumerate(thicknesses): + if j < len(transition_times_sec): + interval_duration = transition_times_sec[j] - prev_time + prev_time = transition_times_sec[j] + else: + interval_duration = t_total - prev_time # t_total is in seconds + # Convert thickness (km) to m, then rate = thickness / duration + rates.append(kilo2base(thick_km) / interval_duration if interval_duration > 0 else 0.0) + # Assign velocity according to current time + # current_time is the simulation time (seconds) passed into this function + applied_rate = rates[-1] # default to last rate + for k, t_sec in enumerate(transition_times_sec): + if current_time < t_sec: + applied_rate = rates[k] + break + vx_array[:] = applied_rate vx_surf = vx_array[0] - vx_max = max(abs(rate1), abs(rate2), abs(rate3)) + vx_max = max(abs(r) for r in rates) # Exponential erosion rate decay with a set characteristic time # Convert to inputting rate directly? @@ -770,18 +787,18 @@ def calculate_exhumation_magnitude( ero_option6, ero_option7, ero_option8, - t_total, + ero_option9=0.0, # BG: Added optional erosion option 9 + ero_option10=0.0, # BG: Added optional erosion option 10 + t_total=0.0, ): """Calculates erosion magnitude in kilometers.""" - # Initialize fault exhumation magnitude for ero_type = 7 - fault_magnitude = 0.0 # Constant erosion rate if ero_type == 1: magnitude = ero_option1 elif ero_type == 2: - magnitude = ero_option1 + ero_option3 + ero_option5 + magnitude = ero_option1 + ero_option3 + ero_option5 + ero_option7 + ero_option9 # BG: Sum all provided erosion thickness values for extended intervals elif ero_type == 3: magnitude = ero_option1 @@ -807,29 +824,24 @@ def calculate_exhumation_magnitude( elif ero_type == 7: # Initial exhumation phase, if applicable - magnitude1 = myr2sec(ero_option6) * mmyr2ms(ero_option5) + magnitude = myr2sec(ero_option6) * mmyr2ms(ero_option5) # Handle case that ero_option8 is not specified (i.e., second phase of constant exhumation) if abs(ero_option8) <= 1.0e-8: rate_change_time2 = t_total else: rate_change_time2 = myr2sec(ero_option8) # Extensional/compressional tectonics phase - magnitude2 = (rate_change_time2 - myr2sec(ero_option6)) * ( + magnitude += (rate_change_time2 - myr2sec(ero_option6)) * ( ero_option2 * mmyr2ms(abs(ero_option1)) * np.sin(deg2rad(ero_option3)) ) # Final exhumation phase, if applicable - magnitude3 = (t_total - rate_change_time2) * mmyr2ms(ero_option7) - # Make magnitude 2 negative if in hanging wall - fault_magnitude = magnitude1 + magnitude2 + magnitude3 - if fault_magnitude <= kilo2base(ero_option4): - magnitude2 = -magnitude2 - magnitude = magnitude1 + magnitude2 + magnitude3 + magnitude += (t_total - rate_change_time2) * mmyr2ms(ero_option7) + magnitude /= 1000.0 else: raise MissingOption("Bad erosion type. Type should be between 1 and 6.") - # Return values in km - return magnitude / kilo2base(1.0), fault_magnitude / kilo2base(1.0) + return magnitude def calculate_pressure(density, dx, g=9.81): @@ -1190,7 +1202,6 @@ def init_params( misfit_type=1, plot_results=True, display_plots=True, - plot_ma=True, plot_depth_history=False, invert_tt_plot=False, t_plots=[0.1, 1, 5, 10, 20, 30, 50], @@ -1350,8 +1361,6 @@ def init_params( Plot calculated results. display_plots : bool, default=True Display plots on screen. - plot_ma : bool, default=True - Plot time in Ma rather than Myr from start of model. plot_depth_history : bool, default=False Plot depth history on thermal history plot. invert_tt_plot : bool, default=False @@ -1402,7 +1411,6 @@ def init_params( "plot_results": plot_results, "save_plots": save_plots, "display_plots": display_plots, - "plot_ma": plot_ma, "plot_depth_history": plot_depth_history, "invert_tt_plot": invert_tt_plot, # Batch mode not supported when called as a function @@ -1535,6 +1543,8 @@ def prep_model(params): "ero_option6", "ero_option7", "ero_option8", + "ero_option9", # BG: Added for extended intervals + "ero_option10", # BG: Added for extended intervals "mantle_adiabat", "rho_crust", "cp_crust", @@ -1584,8 +1594,13 @@ def prep_model(params): params[key] = params[key][0] run_model(params) else: - # Run in batch mode - batch_run(params, batch_params) + # BG: Choose between NA and MCMC explicitly + if params.get("inverse_mode", "NA").upper() == "MCMC": + batch_run_mcmc(params, batch_params) + elif params.get("inverse_mode", "NA").upper() == "NA": + batch_run_na(params, batch_params) + else: + raise ValueError(f"Unknown inversion mode: {params.get('inverse_mode')}") else: # If called as a function, check for lists and their lengths @@ -1656,6 +1671,7 @@ def log_output(params, batch_mode=False): "Mantle removal end time (Ma),Erosion model type,Erosion model option 1," "Erosion model option 2,Erosion model option 3,Erosion model option 4,Erosion model option 5," "Erosion model option 6,Erosion model option 7,Erosion model option 8," + "Erosion model option 9,Erosion model option 10," # BG: Added new columns in header "Initial Moho depth (km),Initial Moho temperature (C)," "Initial surface heat flow (mW m^-2),Initial surface elevation (km)," "Final Moho depth (km),Final Moho temperature (C),Final surface heat flow (mW m^-2)," @@ -1676,8 +1692,242 @@ def log_output(params, batch_mode=False): return outfile +# ========== MCMC helper functions (for multiprocessing and MPI) ========== BG + +# BG: Global variables used by log_probability +global_bounds = None +global_param_names = None +global_params = None +global_max_exhumation = 35.0 +max_burial = 15.0 + +def log_prior(x): + for val, (low, high) in zip(x, global_bounds): + if not (low <= val <= high): + return -np.inf + + param_dict = dict(zip(global_param_names, x)) + ero1 = param_dict.get("ero_option1", global_params.get("ero_option1", 0.0)) + ero3 = param_dict.get("ero_option3", global_params.get("ero_option3", 0.0)) + ero5 = param_dict.get("ero_option5", global_params.get("ero_option5", 0.0)) + ero7 = param_dict.get("ero_option7", global_params.get("ero_option7", 0.0)) + ero9 = param_dict.get("ero_option9", global_params.get("ero_option9", 0.0)) + + # Interval 3 + if "ero_option3" in param_dict: + upper = global_max_exhumation - ero1 + lower = -max_burial - ero1 + if not (lower <= ero3 <= upper): + return -np.inf + + # Interval 5 + if "ero_option5" in param_dict: + upper = global_max_exhumation - (ero1 + ero3) + lower = -max_burial - (ero1 + ero3) + if not (lower <= ero5 <= upper): + return -np.inf + + # Interval 7 + if "ero_option7" in param_dict: + upper = global_max_exhumation - (ero1 + ero3 + ero5) + lower = -max_burial - (ero1 + ero3 + ero5) + if not (lower <= ero7 <= upper): + return -np.inf + + # Interval 9 + if "ero_option9" in param_dict: + upper = global_max_exhumation - (ero1 + ero3 + ero5 + ero7) + lower = -max_burial - (ero1 + ero3 + ero5 + ero7) + if not (lower <= ero9 <= upper): + return -np.inf + + # Total thickness must remain ≥ 0 km (cannot end above surface) + cumulative = ero1 + ero3 + ero5 + ero7 + ero9 + if cumulative < 0.0: + return -np.inf + + return 0.0 + +def log_likelihood(x): + param_dict = dict(zip(global_param_names, x)) + new_dict = {} + for k, v in param_dict.items(): + try: + new_dict[k] = float(v[0]) if isinstance(v, list) else float(v) + except (ValueError, TypeError): + print(f"[WARNING] Could not convert {k}={v} to float.") + return -np.inf + params_local = copy.deepcopy(global_params) + params_local.update(new_dict) + cleaned_dict = {k: float(v) for k, v in new_dict.items()} + print(f"[MCMC] Testing params: {cleaned_dict}") + try: + misfit = run_model(params_local) + print(f"Misfit: {misfit}") + return -misfit + except Exception as e: + print(f"[ERROR] run_model failed: {e}") + return -np.inf + +def log_probability(x): + lp = log_prior(x) + if not np.isfinite(lp): + return -np.inf + ll = log_likelihood(x) + return lp + ll + +# ========== Batch MCMC Runner ========== BG + +def batch_run_mcmc(params, batch_params): + """Runs TC1D in batch mode""" + + param_list = list(ParameterGrid(batch_params)) + print(f"--- Starting batch processor for {len(param_list)} models ---\n") + + if params.get("inverse_mode", "NA").upper() != "MCMC": # BG: Only run this block if MCMC is selected + print("[INFO] Skipping MCMC block since inverse_mode is not 'MCMC'") + return + + success = 0 + failed = 0 + + print(f"--- Starting MCMC inverse mode ---\n") + log_output(params, batch_mode=True) + + # BG: Extract parameters that are being varied to define bounds of the inversion + filtered_params = {k: v for k, v in batch_params.items() if len(v) > 1} + bounds = list(filtered_params.values()) + param_names = list(filtered_params.keys()) + ndim = len(param_names) + max_exhumation = 35.0 + + model = param_list[0] # BG: Start from the first parameter combination + for key in batch_params: + params[key] = model[key] + + # BG: Set global variables for MPI pickling compatibility + global global_bounds, global_param_names, global_params, global_max_exhumation + global_bounds = bounds + global_param_names = param_names + global_params = params + global_max_exhumation = max_exhumation + + # BG: Initialize walkers and sampler using MPI Pool + nwalkers = 16 + nsteps = 200 + discard = 30 + thin = 3 + + p0 = [ + [np.random.uniform(low, high) for (low, high) in global_bounds] + for _ in range(nwalkers) + ] + + pool = MPIPool() + if not pool.is_master(): + pool.wait() + sys.exit(0) + + sampler = emcee.EnsembleSampler( + nwalkers, + ndim, + log_probability, + pool=pool # BG: Use MPI pool here + ) + + sampler.run_mcmc(p0, nsteps) + pool.close() + + # BG: Post-processing with legacy-compatible attributes + chain = sampler.chain + log_probs = sampler.lnprobability + + # BG: Flatten chains manually (legacy emcee version does not support .get_chain()) + flat_samples = chain[:, discard::thin, :].reshape(-1, ndim) + flat_log_probs = log_probs[:, discard::thin].reshape(-1) + + # BG: Check if any valid samples remain after burn-in + if len(log_probs) == 0: + print("[ERROR] No valid samples after burn-in. Aborting analysis.") + return + + # BG: Identify and print the best parameter set (lowest misfit) + best_idx = np.argmax(flat_log_probs) + best = flat_samples[best_idx] + best_dict = dict(zip(param_names, best)) # Manquait ici + print(f" The best parameters are: { {k: float(v) for k, v in best_dict.items()} }") + + # BG: Plot evolution of misfit values + plt.figure() + neg_log_probs = -flat_log_probs + plt.plot(neg_log_probs, ".", markersize=2) + plt.scatter(best_idx, neg_log_probs[best_idx], c="g", s=10) + plt.xlabel("Sample Index") + plt.ylabel("Misfit") + plt.title("MCMC Misfit Values") + plt.yscale("log") + plt.savefig("mcmc_misfit.png") + + # BG: Plot MCMC chains to assess parameter convergence + plt.figure(figsize=(10, ndim * 2)) + for i in range(ndim): + plt.subplot(ndim, 1, i + 1) + for walker in chain[:, :, i]: + plt.plot(walker, alpha=0.4) + plt.ylabel(param_names[i]) + if i == 0: + plt.title("MCMC Chains for Each Parameter") + plt.xlabel("Step") + plt.tight_layout() + plt.savefig("mcmc_chains.png") + print("[MCMC] Chain plot saved as:", os.path.abspath("mcmc_chains.png")) + + # BG: Use corner plot to visualize marginal distributions and parameter correlations + figure = corner.corner( + flat_samples, + labels=param_names, + truths=best, + show_titles=True, + title_fmt=".2f", + title_kwargs={"fontsize": 10} + ) + corner_plot_path = "mcmc_corner.png" + figure.savefig(corner_plot_path) + print("[MCMC] Corner plot saved as:", os.path.abspath(corner_plot_path)) + + # BG: Generate one scatter plot with marginal histograms per parameter pair + from itertools import combinations + for i, j in combinations(range(ndim), 2): + x, y = flat_samples[:, i], flat_samples[:, j] + fig = plt.figure(constrained_layout=True) + gs = fig.add_gridspec(4, 4) + ax = fig.add_subplot(gs[1:, :-1]) + ax_histx = fig.add_subplot(gs[0, :-1], sharex=ax) + ax_histy = fig.add_subplot(gs[1:, -1], sharey=ax) + + sc = ax.scatter(x, y, c=neg_log_probs, cmap="viridis", marker="x") + ax.scatter(best[i], best[j], color="red", marker="x", label="Best") + ax.set_xlabel(param_names[i]) + ax.set_ylabel(param_names[j]) + fig.colorbar(sc, ax=ax, orientation='horizontal', label='Misfit') + + ax_histx.hist(x, bins=15, color="grey") + ax_histy.hist(y, bins=15, orientation="horizontal", color="grey") + fig.suptitle(f"Scatter: {param_names[i]} vs {param_names[j]}") + plt.savefig(f"mcmc_scatter_{param_names[i]}_vs_{param_names[j]}.png") + plt.close(fig) + + # BG: Summary of saved outputs + success += 1 + print("\n[MCMC] Misfit plot saved as:", os.path.abspath("mcmc_misfit.png")) + print("[MCMC] Chain plot saved as:", os.path.abspath("mcmc_chains.png")) + print("[MCMC] Corner plot saved as:", os.path.abspath("mcmc_corner.png")) + if ndim >= 2: + print("[MCMC] Pairwise scatter plots saved (one per parameter pair).") + print(f"\n--- Execution complete ({success} succeeded, {failed} failed) ---") -def batch_run(params, batch_params): +# ========== Batch NA runner ========== BG +def batch_run_na(params, batch_params): """Runs TC1D in batch mode""" param_list = list(ParameterGrid(batch_params)) @@ -1688,14 +1938,15 @@ def batch_run(params, batch_params): failed = 0 # If inverse mode is enabled, run with the neighbourhood algorithm - if params["inverse_mode"] == True: + if params.get("inverse_mode", "NA").upper() == "NA": # BG: Only run NA block if inverse_mode is 'NA' print(f"--- Starting inverse mode ---\n") log_output(params, batch_mode=True) # Batch params only for testing # batch_params = {'max_depth': [125.0, 130], 'nx': [251], 'temp_surf': [0.0], 'temp_base': [1300.0], 't_total': [50.0], 'dt': [5000.0], 'vx_init': [0.0], 'init_moho_depth': [50.0], 'removal_fraction': [0.0], 'removal_time': [0.0], 'ero_type': [1], 'ero_option1': [10.0, 15.0], 'ero_option2': [0.0], 'ero_option3': [0.0], 'ero_option4': [0.0], 'ero_option5': [0.0], 'mantle_adiabat': [True], 'rho_crust': [2850.0], 'cp_crust': [800.0], 'k_crust': [2.75], 'heat_prod_crust': [0.5], 'alphav_crust': [3e-05], 'rho_mantle': [3250.0], 'cp_mantle': [1000.0], 'k_mantle': [2.5], 'heat_prod_mantle': [0.0], 'alphav_mantle': [3e-05], 'rho_a': [3250.0], 'k_a': [20.0], 'ap_rad': [45.0], 'ap_uranium': [10.0], 'ap_thorium': [40.0], 'zr_rad': [60.0], 'zr_uranium': [100.0], 'zr_thorium': [40.0], 'pad_thist': [False], 'pad_time': [0.0]} - max_ehumation = 35.0 + max_exhumation = 35.0 + max_burial = 15.0 # Starting model model = param_list[0] @@ -1713,39 +1964,57 @@ def batch_run(params, batch_params): # Objective function to be minimised, run for misfit def objective(x): - # Update bounds + # Map sampled values x to the corresponding parameter names for key, value in zip(filtered_params, x): filtered_params[key] = value - # Additional case-by-case rules for params - # Default final values - ero3_final = filtered_params["ero_option3"] - ero5_final = filtered_params["ero_option5"] - - # Ensure ero_option3 does not exceed the available exhumation - ero3_final = max( - 0, - min( - filtered_params["ero_option3"], - max_ehumation - filtered_params["ero_option1"], - ), - ) - # Ensure ero_option5 does not exceed the available exhumation - ero5_final = max( - 0, - min( - filtered_params["ero_option5"], - max_ehumation - (filtered_params["ero_option1"] + ero3_final), - ), - ) - - # Update params only when conditions have been examined - filtered_params["ero_option3"] = ero3_final - filtered_params["ero_option5"] = ero5_final + # BG: Get erosion parameters with default fallback from global params + ero1 = filtered_params.get("ero_option1", params.get("ero_option1", 0.0)) + ero3_final = filtered_params.get("ero_option3", params.get("ero_option3", 0.0)) + ero5_final = filtered_params.get("ero_option5", params.get("ero_option5", 0.0)) + ero7_final = filtered_params.get("ero_option7", params.get("ero_option7", 0.0)) + ero9_final = filtered_params.get("ero_option9", params.get("ero_option9", 0.0)) + + # === BG: Apply physical constraints to erosion parameters (allow burial) === + + # Ero option 3 + if "ero_option3" in filtered_params: + upper_bound = max_exhumation - ero1 # cannot exhume >35 km + lower_bound = -max_burial - ero1 # cannot bury >15 km + ero3_final = max(lower_bound, min(ero3_final, upper_bound)) + filtered_params["ero_option3"] = ero3_final + + # Ero option 5 + if "ero_option5" in filtered_params: + upper_bound = max_exhumation - (ero1 + ero3_final) + lower_bound = -max_burial - (ero1 + ero3_final) + ero5_final = max(lower_bound, min(ero5_final, upper_bound)) + filtered_params["ero_option5"] = ero5_final + + # Ero option 7 + if "ero_option7" in filtered_params: + upper_bound = max_exhumation - (ero1 + ero3_final + ero5_final) + lower_bound = -max_burial - (ero1 + ero3_final + ero5_final) + ero7_final = max(lower_bound, min(ero7_final, upper_bound)) + filtered_params["ero_option7"] = ero7_final + + # Ero option 9 + if "ero_option9" in filtered_params: + upper_bound = max_exhumation - (ero1 + ero3_final + ero5_final + ero7_final) + lower_bound = -max_burial - (ero1 + ero3_final + ero5_final + ero7_final) + ero9_final = max(lower_bound, min(ero9_final, upper_bound)) + filtered_params["ero_option9"] = ero9_final # Add bounds to parameters params.update(filtered_params) - print(f" The current values are: {filtered_params}") + cleaned_filtered = {k: float(v) for k, v in filtered_params.items()} + print(f" The current values are: {cleaned_filtered}") + + # BG: Cumulative constraint to prevent elevation above the surface + cumulative_erosion = ero1 + ero3_final + ero5_final + ero7_final + ero9_final + if cumulative_erosion < 0: + print("Rejected: model would place the sample above surface.") + return 1e10 # Or some large misfit to reject the model misfit = run_model(params) # misfit = x[0]*2 + x[1]*2 + x[2]*2 + 100*2 #lighter test function @@ -1755,20 +2024,52 @@ def objective(x): # Initialize NA searcher searcher = NASearcher( objective, - ns=200, # 16 #100, # number of samples per iteration #10 - nr=100, # 8 #10, # number of cells to resample #1 + ns=40, # 16 #100, # number of samples per iteration #10 + nr=20, # 8 #10, # number of cells to resample #1 ni=100, # 100, # size of initial random search #1 - n=30, # 20, # number of iterations #1 + n=10, # 20, # number of iterations #1 bounds=bounds, ) # Run the direct search phase searcher.run() # results stored in searcher.samples and searcher.objectives - # Optionally adjust the samples for appraiser + # BG: Apply constraints after search using parameter names to avoid index errors for i in searcher.samples: - i[2] = min(max_ehumation - i[0], i[2]) - i[4] = min(max_ehumation - (i[0] + i[2]), i[4]) + param_dict = dict(zip(filtered_params.keys(), i)) + + ero1 = param_dict.get("ero_option1", 0.0) + ero3 = param_dict.get("ero_option3", 0.0) + ero5 = param_dict.get("ero_option5", 0.0) + ero7 = param_dict.get("ero_option7", 0.0) + ero9 = param_dict.get("ero_option9", 0.0) + + if "ero_option3" in param_dict: + upper = max_exhumation - ero1 + lower = -max_burial - ero1 + param_dict["ero_option3"] = max(lower, min(ero3, upper)) + + if "ero_option5" in param_dict: + upper = max_exhumation - (ero1 + param_dict.get("ero_option3", 0.0)) + lower = -max_burial - (ero1 + param_dict.get("ero_option3", 0.0)) + param_dict["ero_option5"] = max(lower, min(ero5, upper)) + + if "ero_option7" in param_dict: + upper = max_exhumation - ( + ero1 + param_dict.get("ero_option3", 0.0) + param_dict.get("ero_option5", 0.0)) + lower = -max_burial - (ero1 + param_dict.get("ero_option3", 0.0) + param_dict.get("ero_option5", 0.0)) + param_dict["ero_option7"] = max(lower, min(ero7, upper)) + + if "ero_option9" in param_dict: + upper = max_exhumation - (ero1 + param_dict.get("ero_option3", 0.0) + param_dict.get("ero_option5", + 0.0) + param_dict.get( + "ero_option7", 0.0)) + lower = -max_burial - (ero1 + param_dict.get("ero_option3", 0.0) + param_dict.get("ero_option5", + 0.0) + param_dict.get( + "ero_option7", 0.0)) + param_dict["ero_option9"] = max(lower, min(ero9, upper)) + + i[:] = [param_dict[k] for k in filtered_params.keys()] appraiser = NAAppraiser( initial_ensemble=searcher.samples, # points of parameter space already sampled @@ -1784,11 +2085,44 @@ def objective(x): print(f"Appraiser covariance: {appraiser.covariance}") print(f"Appraiser covariance error: {appraiser.sample_covariance_error}") - # Best param + # BG: Safely extract best parameter set using param names best = searcher.samples[np.argmin(searcher.objectives)] - # optional param adjustments, MAKE SURE THEY ARE UPDATED - best[2] = min(max_ehumation - best[0], best[2]) - best[4] = min(max_ehumation - (best[0] + best[2]), best[4]) + best_dict = dict(zip(filtered_params.keys(), best)) + ero1 = best_dict.get("ero_option1", 0.0) + ero3 = best_dict.get("ero_option3", 0.0) + + # BG: Apply constraints only to parameters being inverted + ero3 = best_dict.get("ero_option3", 0.0) + ero5 = best_dict.get("ero_option5", 0.0) + ero7 = best_dict.get("ero_option7", 0.0) + ero9 = best_dict.get("ero_option9", 0.0) + + if "ero_option3" in best_dict: + upper = max_exhumation - ero1 + lower = -max_burial - ero1 + best_dict["ero_option3"] = max(lower, min(ero3, upper)) + + if "ero_option5" in best_dict: + upper = max_exhumation - (ero1 + best_dict.get("ero_option3", 0.0)) + lower = -max_burial - (ero1 + best_dict.get("ero_option3", 0.0)) + best_dict["ero_option5"] = max(lower, min(ero5, upper)) + + if "ero_option7" in best_dict: + upper = max_exhumation - (ero1 + best_dict.get("ero_option3", 0.0) + best_dict.get("ero_option5", 0.0)) + lower = -max_burial - (ero1 + best_dict.get("ero_option3", 0.0) + best_dict.get("ero_option5", 0.0)) + best_dict["ero_option7"] = max(lower, min(ero7, upper)) + + if "ero_option9" in best_dict: + upper = max_exhumation - ( + ero1 + best_dict.get("ero_option3", 0.0) + best_dict.get("ero_option5", 0.0) + best_dict.get( + "ero_option7", 0.0)) + lower = -max_burial - ( + ero1 + best_dict.get("ero_option3", 0.0) + best_dict.get("ero_option5", 0.0) + best_dict.get( + "ero_option7", 0.0)) + best_dict["ero_option9"] = max(lower, min(ero9, upper)) + + # BG: Rebuild best list in correct parameter order + best[:] = [best_dict[k] for k in filtered_params.keys()] print(f" The best parameters are: {best}") # Plot for misfit @@ -1805,8 +2139,9 @@ def objective(x): transform=plt.gca().transAxes, ha="right", ) - # plt.show() - plt.savefig("misfit.png") + plt.savefig("NA_misfit.png") + if 'fig' in locals(): + plt.close(fig) # Plot for 2 params if len(bounds) == 2: @@ -1844,14 +2179,13 @@ def objective(x): ax.set_xlabel(list(filtered_params.keys())[0]) ax.set_ylabel(list(filtered_params.keys())[1]) fig.colorbar(scatter1, location="bottom", label="Misfit") - # Scatterplots test - # # Histograms ax_histx.hist(x_appraiser, bins=15, color="grey") ax_histy.hist(y_appraiser, bins=15, color="grey", orientation="horizontal") - plt.show() - # plt.savefig("scatter.png") + plt.savefig("NA_scatter.png") + if 'fig' in locals(): + plt.close(fig) # NA covariance matrix plot paramkeys = list(filtered_params.keys()) @@ -1875,48 +2209,80 @@ def objective(x): ha="center", va="center", ) - # plt.show() - # plt.savefig("matrix.png") + plt.savefig("NA_covariance_matrix.png") + if 'fig' in locals(): + plt.close(fig) - # Voronoi cells plot + # BG: Voronoi plot for 2 or more parameters from scipy.spatial import Voronoi, voronoi_plot_2d - fig, axs = plt.subplots(5, 5, figsize=(10, 10), tight_layout=True) - for i in range(5): - for j in range(5): - if j < i: - vor = Voronoi(searcher.samples[:, [i, j]]) - voronoi_plot_2d( - vor, - ax=axs[i, j], - show_vertices=False, - show_points=False, - line_width=0.5, - ) - axs[i, j].scatter( - best[i], - best[j], - c="g", - marker="x", - s=100, - label="Best model", - zorder=10, - ) - axs[i, j].set_xlim(searcher.bounds[i]) - axs[i, j].set_ylim(searcher.bounds[j]) - axs[i, j].set_xticks([]) - axs[i, j].set_yticks([]) - else: - axs[i, j].set_visible(False) - handles, labels = axs[1, 0].get_legend_handles_labels() - by_label = dict(zip(labels, handles)) - fig.legend( - by_label.values(), - by_label.keys(), - loc="lower left", - bbox_to_anchor=(0.6, 0.25), - ) - fig.savefig("voronoi.png") + samples = searcher.samples + nparams = samples.shape[1] + + if nparams == 2: + # BG: Classic 2D Voronoi plot + vor = Voronoi(samples) + fig, ax = plt.subplots(figsize=(6, 6)) + voronoi_plot_2d(vor, ax=ax, show_vertices=False, show_points=False, line_width=0.5) + ax.scatter(best[0], best[1], c="g", marker="x", s=100, label="Best model", zorder=10) + ax.set_xlim(bounds[0]) + ax.set_ylim(bounds[1]) + ax.set_xlabel(list(filtered_params.keys())[0]) + ax.set_ylabel(list(filtered_params.keys())[1]) + ax.legend(loc="lower right") + plt.tight_layout() + plt.savefig("NA_voronoi.png") + if 'fig' in locals(): + plt.close(fig) + + elif nparams > 2: + # BG: Pairwise Voronoi plots for all parameter pairs (lower triangle) + fig, axs = plt.subplots(nparams, nparams, figsize=(2.5 * nparams, 2.5 * nparams), tight_layout=True) + plt.suptitle("Voronoi pairwise plots") + for i in range(nparams): + for j in range(nparams): + if j < i: + vor = Voronoi(samples[:, [j, i]]) + voronoi_plot_2d(vor, ax=axs[i, j], show_vertices=False, show_points=False, line_width=0.5) + axs[i, j].scatter(best[j], best[i], c="g", marker="x", s=100, label="Best model", zorder=10) + axs[i, j].set_xlim(searcher.bounds[j]) + axs[i, j].set_ylim(searcher.bounds[i]) + axs[i, j].set_xticks([]) + axs[i, j].set_yticks([]) + if i == nparams - 1: + axs[i, j].set_xlabel(paramkeys[j]) + if j == 0: + axs[i, j].set_ylabel(paramkeys[i]) + else: + axs[i, j].set_visible(False) + + handles, labels = axs[1, 0].get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + fig.legend(by_label.values(), by_label.keys(), loc="lower left", bbox_to_anchor=(0.6, 0.25)) + plt.savefig("NA_voronoi.png") + if 'fig' in locals(): + plt.close(fig) + + # BG: Plot univariate histograms for each inverted parameter + fig, axs = plt.subplots(nparams, 1, figsize=(5, 2.5 * nparams), tight_layout=True) + for i in range(nparams): + axs[i].hist(searcher.samples[:, i], bins=15, color="grey", edgecolor="black") + axs[i].set_xlabel(paramkeys[i]) + axs[i].set_ylabel("Count") + axs[i].set_title(f"Histogram of {paramkeys[i]}") + + plt.suptitle("Parameter distributions from NA search", y=1.02) + plt.savefig("NA_histograms.png") + if 'fig' in locals(): + plt.close(fig) + + # BG: Summary of saved outputs + print("\n[NA] Misfit plot saved as:", os.path.abspath("NA_misfit.png")) + print("[NA] Histograms saved as:", os.path.abspath("NA_histograms.png")) + if len(bounds) == 2: + print("[NA] Scatter plot saved as:", os.path.abspath("NA_scatter.png")) + print("[NA] Covariance matrix plot saved as:", os.path.abspath("NA_covariance_matrix.png")) + print("[NA] Voronoi plot saved as:", os.path.abspath("voronoi.png")) print("Inverse mode complete") success += 1 @@ -1946,14 +2312,13 @@ def objective(x): f'{params["rho_crust"]:.4f},{params["removal_fraction"]:.4f},{params["removal_start_time"]:.4f},' f'{params["removal_end_time"]:.4f},' f'{params["ero_type"]},{params["ero_option1"]:.4f},' - f'{params["ero_option2"]:.4f},{params["ero_option3"]:.4f},{params["ero_option4"]:.4f},{params["ero_option5"]:.4f},{params["ero_option6"]:.4f},{params["ero_option7"]:.4f},{params["ero_option8"]:.4f},{params["init_moho_depth"]:.4f},,,,,,,,,{params["ap_rad"]:.4f},{params["ap_uranium"]:.4f},' + f'{params["ero_option2"]:.4f},{params["ero_option3"]:.4f},{params["ero_option4"]:.4f},{params["ero_option5"]:.4f},{params["ero_option6"]:.4f},{params["ero_option7"]:.4f},{params["ero_option8"]:.4f},{params["ero_option9"]:.4f},{params["ero_option10"]:.4f},{params["init_moho_depth"]:.4f},,,,,,,,,{params["ap_rad"]:.4f},{params["ap_uranium"]:.4f},' f'{params["ap_thorium"]:.4f},{params["zr_rad"]:.4f},{params["zr_uranium"]:.4f},{params["zr_thorium"]:.4f},,,,,,,,,,,,,,,\n' ) failed += 1 print(f"\n--- Execution complete ({success} succeeded, {failed} failed) ---") - def run_model(params): # Say hello if not params["batch_mode"]: @@ -1993,7 +2358,7 @@ def run_model(params): vx_hist = np.zeros(nt) # Calculate exhumation magnitude - exhumation_magnitude, fault_exhumation_magnitude = calculate_exhumation_magnitude( + exhumation_magnitude = calculate_exhumation_magnitude( params["ero_type"], params["ero_option1"], params["ero_option2"], @@ -2003,6 +2368,8 @@ def run_model(params): params["ero_option6"], params["ero_option7"], params["ero_option8"], + params["ero_option9"], # BG: Pass additional erosion option 9 + params["ero_option10"], # BG: Pass additional erosion option 10 t_total, ) @@ -2058,9 +2425,7 @@ def run_model(params): # Define final fault depth for erosion model 7 if params["ero_type"] == 7: # Set fault depth for extension - fault_depth = kilo2base(params["ero_option4"]) - kilo2base( - fault_exhumation_magnitude - ) + fault_depth = kilo2base(params["ero_option4"]) - kilo2base(exhumation_magnitude) # if params["ero_option1"] >= 0.0: # fault_depth = kilo2base(params["ero_option4"]) - kilo2base(exhumation_magnitude) ## Set fault depth for convergence @@ -2240,12 +2605,8 @@ def run_model(params): else: colors = plt.cm.viridis_r(np.linspace(0, 1, len(t_plots))) ax1.plot(temp_init, -x / 1000, "k:", label="Initial") - if params["plot_ma"]: - time_label = f"{params["t_total"]:.1f} Ma" - else: - time_label = "0.0 Myr" - ax1.plot(temp_prev, -x / 1000, "k-", label=time_label) - ax2.plot(density_init, -x / 1000, "k-", label=time_label) + ax1.plot(temp_prev, -x / 1000, "k-", label="0 Myr") + ax2.plot(density_init, -x / 1000, "k-", label="0 Myr") # Calculate model times when particles reach surface surface_times = myr2sec(params["t_total"] - surface_times_ma) @@ -2588,7 +2949,7 @@ def run_model(params): temp_hists[i][idx] = interp_temp_new(moho_depth) # Otherwise, record temperature at current depth else: - temp_hists[i][idx] = interp_temp_new(depths[i]) + temp_hists[i][idx] = interp_temp_new(max(0.0, depths[i])) # BG: prevent depth < 0 # Store pressure history # Check whether point is very close to the surface @@ -2597,7 +2958,7 @@ def run_model(params): elif depths[i] > moho_depth and params["fixed_moho"]: pressure_hists[i][idx] = interp_pressure(moho_depth) else: - pressure_hists[i][idx] = interp_pressure(depths[i]) + pressure_hists[i][idx] = interp_pressure(max(0.0, depths[i])) # BG: prevent depth < 0 # Print array values if debugging is on if params["debug"]: @@ -2639,21 +3000,17 @@ def run_model(params): if j == num_pass - 1: if params["plot_results"] and more_plots: if curtime > t_plots[plotidx]: - if params["plot_ma"]: - time_label = f"{(params["t_total"] - t_plots[plotidx] / myr2sec(1)):.1f} Ma" - else: - time_label = f"{t_plots[plotidx] / myr2sec(1):.1f} Myr" ax1.plot( temp_new, -x / 1000, "-", - label=time_label, + label=f"{t_plots[plotidx] / myr2sec(1):.1f} Myr", color=colors[plotidx], ) ax2.plot( density_new, -x / 1000, - label=time_label, + label=f"{t_plots[plotidx] / myr2sec(1):.1f} Myr", color=colors[plotidx], ) if plotidx == len(t_plots) - 1: @@ -2763,7 +3120,6 @@ def run_model(params): else: tt_new = tt_orig.rename(wd / "csv" / tt_orig) ttdp_new = ttdp_orig.rename(wd / "csv" / ttdp_orig) - # FIXME: Commenting this out for now ftl_new = ftl_orig.rename(wd / "csv" / ftl_orig) if params["echo_ages"]: @@ -3066,15 +3422,11 @@ def run_model(params): xmin = params["temp_surf"] # Add 10% to max T and round to nearest 100 xmax = round(1.1 * temp_new.max(), -2) - if params["plot_ma"]: - time_label = "0.0 Ma" - else: - time_label = f"{curtime / myr2sec(1):.1f} Myr" ax1.plot( temp_new, -x / 1000, "-", - label=time_label, + label=f"{curtime / myr2sec(1):.1f} Myr", color=colors[-1], ) ax1.plot( @@ -3218,7 +3570,7 @@ def run_model(params): ax2.plot( density_new, -x / 1000, - label=time_label, + label=f"{t_total / myr2sec(1):.1f} Myr", color=colors[-1], ) ax2.plot( @@ -3253,37 +3605,26 @@ def run_model(params): # Plot elevation history fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8)) # ax1.plot(time_list, elev_list, 'k-') - if params["plot_ma"]: - time_list = [params["t_total"] - time for time in time_list] - time_xlabel = "Time (Ma)" - time_xlim = [params["t_total"], 0.0] - else: - time_xlabel = "Time (Myr)" - time_xlim = [0.0, params["t_total"]] ax1.plot(time_list, elev_list) - ax1.set_xlabel(time_xlabel) + ax1.set_xlabel("Time (Myr)") ax1.set_ylabel("Elevation (m)") - ax1.set_xlim(time_xlim) + ax1.set_xlim(0.0, t_total / myr2sec(1)) ax1.set_title("Elevation history") # plt.axis([0.0, t_total/myr2sec(1), 0, 750]) # ax1.grid() - if params["plot_ma"]: - plot_time = params["t_total"] - time_hists[-1] / myr2sec(1) - else: - plot_time = time_hists[-1] / myr2sec(1) - ax2.plot(plot_time, vx_hist / mmyr2ms(1)) + ax2.plot(time_hists[-1] / myr2sec(1), vx_hist / mmyr2ms(1)) ax2.fill_between( - plot_time, + time_hists[-1] / myr2sec(1), vx_hist / mmyr2ms(1), 0.0, alpha=0.33, color="tab:blue", label=f"Total erosional exhumation: {exhumation_magnitude:.1f} km", ) - ax2.set_xlabel(time_xlabel) + ax2.set_xlabel("Time (Myr)") ax2.set_ylabel("Erosion rate (mm/yr)") - ax2.set_xlim(time_xlim) + ax2.set_xlim(0.0, t_total / myr2sec(1)) # if params["ero_option1"] >= 0.0: # ax2.set_ylim(ymin=0.0) # plt.axis([0.0, t_total/myr2sec(1), 0, 750]) @@ -3988,14 +4329,13 @@ def run_model(params): # Print warnings if there are multiple observed ages to write to the log file age_types = ["AHe", "AFT", "ZHe", "ZFT"] obs_age_nums = [n_obs_ahe, n_obs_aft, n_obs_zhe, n_obs_zft] - if (not params["batch_mode"]) or (not params["inverse_mode"]): - if (n_obs_ahe > 1) or (n_obs_aft > 1) or (n_obs_zhe > 1) or (n_obs_zft > 1): - print("") - for i in range(len(age_types)): - if obs_age_nums[i] > 1: - print( - f"WARNING: More than one measured {age_types[i]} age supplied, only the first was written to the output file!" - ) + if (n_obs_ahe > 1) or (n_obs_aft > 1) or (n_obs_zhe > 1) or (n_obs_zft > 1): + print("") + for i in range(len(age_types)): + if obs_age_nums[i] > 1: + print( + f"WARNING: More than one measured {age_types[i]} age supplied, only the first was written to the output file!" + ) # Open log file for writing with open(outfile, "a+") as f: @@ -4005,7 +4345,9 @@ def run_model(params): f'{params["rho_crust"]:.4f},{params["removal_fraction"]:.4f},{params["removal_start_time"]:.4f},' f'{params["removal_end_time"]:.4f},{params["ero_type"]},{params["ero_option1"]:.4f},' f'{params["ero_option2"]:.4f},{params["ero_option3"]:.4f},{params["ero_option4"]:.4f},' - f'{params["ero_option5"]:.4f},{params["ero_option6"]:.4f},{params["ero_option7"]:.4f},{params["ero_option8"]:.4f},{params["init_moho_depth"]:.4f},{init_moho_temp:.4f},' + f'{params["ero_option5"]:.4f},{params["ero_option6"]:.4f},{params["ero_option7"]:.4f},{params["ero_option8"]:.4f},' + f'{params["ero_option9"]:.4f},{params["ero_option10"]:.4f},'# BG: Added extended interval params + f'{params["init_moho_depth"]:.4f},{init_moho_temp:.4f},' f"{init_heat_flow:.4f},{elev_list[1] / kilo2base(1):.4f},{moho_depth / kilo2base(1):.4f}," f"{final_moho_temp:.4f},{final_heat_flow:.4f},{elev_list[-1] / kilo2base(1):.4f}," f'{exhumation_magnitude:.4f},{params["ap_rad"]:.4f},{params["ap_uranium"]:.4f},' @@ -4108,5 +4450,5 @@ def run_model(params): # Returns misfit for inverse_mode if "misfit" in locals(): - # print(f"- Returning misfit: {misfit}") + # print("- Returning misfit") return misfit diff --git a/tc1d/tc1d_cli.py b/tc1d/tc1d_cli.py index d874827..5f3a112 100755 --- a/tc1d/tc1d_cli.py +++ b/tc1d/tc1d_cli.py @@ -56,9 +56,10 @@ def main(): general.add_argument( "--inverse-mode", dest="inverse_mode", - help="Enable inverse mode", - action="store_true", - default=False, + help="Select inversion method: 'NA' or 'MCMC'", # BG + type=str, + choices=["NA", "MCMC"], + default="NA" ) general.add_argument( "--debug", @@ -356,6 +357,22 @@ def main(): default=[0.0], type=float, ) + erosion.add_argument( + "--ero-option9", + dest="ero_option9", + help="Erosion model option 9 (see GitHub docs)", + nargs="+", + default=[0.0], + type=float, + ) # BG: Added erosion model option 9 for extended intervals + erosion.add_argument( + "--ero-option10", + dest="ero_option10", + help="Erosion model option 10 (see GitHub docs)", + nargs="+", + default=[0.0], + type=float, + ) # BG: Added erosion model option 10 for extended intervals prediction = parser.add_argument_group( "Age prediction options", "Options for age prediction" ) @@ -571,13 +588,6 @@ def main(): action="store_true", default=False, ) - plotting.add_argument( - "--plot-myr", - dest="plot_myr", - help="Plot model time in Myr from start rather than Ma (ago)", - action="store_true", - default=False, - ) plotting.add_argument( "--plot-depth-history", dest="plot_depth_history", @@ -717,14 +727,12 @@ def main(): # - echo_ages = True if thermochronometer ages should be displayed on the screen # - plot_results = True if plots of temperatures and densities should be created # - display_plots = True if plots should be displayed on the screen - # - plot_ma = True if plots should be in millions of years ago (Ma) echo_info = not args.no_echo_info echo_thermal_info = not args.no_echo_thermal_info calc_ages = not args.no_calc_ages echo_ages = not args.no_echo_ages plot_results = not args.no_plot_results display_plots = not args.no_display_plots - plot_ma = not args.plot_myr params = { "cmd_line_call": True, @@ -736,7 +744,6 @@ def main(): "plot_results": plot_results, "save_plots": args.save_plots, "display_plots": display_plots, - "plot_ma": plot_ma, "plot_depth_history": args.plot_depth_history, "invert_tt_plot": args.invert_tt_plot, "batch_mode": args.batch_mode, @@ -770,6 +777,8 @@ def main(): "ero_option6": args.ero_option6, "ero_option7": args.ero_option7, "ero_option8": args.ero_option8, + "ero_option9": args.ero_option9, + "ero_option10": args.ero_option10, "temp_surf": args.temp_surf, "temp_base": args.temp_base, "t_total": args.time,