From 3cd16ae300bdde40e3d05acfbee271f03df5d68b Mon Sep 17 00:00:00 2001 From: asteinbe Date: Thu, 4 Oct 2018 15:32:26 +0200 Subject: [PATCH 1/2] get_result_point to backend, plotting import --- apps/beat | 49 ++++++++++++++++--------------------------- src/backend.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++ src/plotting.py | 55 ------------------------------------------------- 3 files changed, 71 insertions(+), 86 deletions(-) diff --git a/apps/beat b/apps/beat index f21132fe..0fef049f 100644 --- a/apps/beat +++ b/apps/beat @@ -12,13 +12,14 @@ import shutil from optparse import OptionParser -from beat import heart, config, utility, inputf, plotting +from beat import heart, config, utility, inputf from beat.models import load_model, Stage, estimate_hypers, sample -from beat.backend import TextChain +from beat.backend import TextChain from beat.sampler import SamplingHistory from beat.sources import MTSourceWithMagnitude from beat.utility import list2string from numpy import savez, atleast_2d +from beat import backend from pyrocko import model, util from pyrocko.trace import snuffle @@ -445,7 +446,7 @@ def command_import(args): stage.load_results( model=problem.model, stage_number=-1, load='full') - point = plotting.get_result_point(stage, problem.config, 'max') + point = backend.get_result_point(stage, problem.config, 'max') n_sources = problem.config.problem_config.n_sources source_params = problem.config.problem_config.priors.keys() @@ -482,10 +483,10 @@ def command_update(args): parser.add_option( '--parameters', - default=['structure'], type='string', + default=[], type='string', action='callback', callback=list_callback, - help='Parameters to update; "structure, hypers, hierarchicals". ' - 'Default: ["structure"] (config file-structure only)') + help='Parameters to update; "hypers, hierarchicals". ' + 'Default: [] (config file-structure only)') parser.add_option( '--mode', dest='mode', @@ -1387,17 +1388,13 @@ def command_export(args): model=problem.model, stage_number=options.stage_number, load='trace', chains=[-1]) - trace_name = 'chain--1.csv' - results_path = pjoin(problem.outfolder, config.results_dir_name) - logger.info('Saving results to %s' % results_path) - util.ensuredir(results_path) - - results_trace = pjoin(stage.handler.stage_path(-1), trace_name) - shutil.copy(results_trace, pjoin(results_path, trace_name)) - - point = plotting.get_result_point( + point = backend.get_result_point( stage, problem.config, point_llk=options.post_llk) + mpoint = dict() + for var in problem.varnames: + mpoint[var] = atleast_2d(stage.mtrace.get_values(var).mean()) + for datatype, composite in problem.composites.items(): logger.info( 'Exporting "%s" synthetics for "%s" likelihood parameters:' % ( @@ -1407,6 +1404,9 @@ def command_export(args): varname, list2string(value.ravel().tolist()))) results = composite.assemble_results(point) + results_path = pjoin(problem.outfolder, config.results_dir_name) + logger.info('Saving results to %s' % results_path) + util.ensuredir(results_path) if datatype == 'seismic': from pyrocko import io @@ -1435,7 +1435,7 @@ def command_export(args): if hasattr(sc.parameters, 'update_covariances'): if sc.parameters.update_covariances: logger.info('Saving velocity model covariance matrixes...') - composite.update_weights(point) + composite.update_weights(mpoint) for wmap in composite.wavemaps: pcovs = { list2string(dataset.nslc_id): @@ -1444,23 +1444,10 @@ def command_export(args): outname = pjoin( results_path, '%s_C_vm_%s' % ( - datatype, wmap._mapid)) - logger.info('"%s" to: %s' % (wmap._mapid, outname)) + datatype, wmap.name)) + logger.info('"%s" to: %s' % (wmap.name, outname)) savez(outname, **pcovs) - logger.info('Saving data covariance matrixes...') - for wmap in composite.wavemaps: - dcovs = { - list2string(dataset.nslc_id): - dataset.covariance.data - for dataset in wmap.datasets} - - outname = pjoin( - results_path, '%s_C_d_%s' % ( - datatype, wmap._mapid)) - logger.info('"%s" to: %s' % (wmap._mapid, outname)) - savez(outname, **dcovs) - elif datatype == 'geodetic': for ifgs, attribute in heart.results_for_export( results, datatype=datatype): diff --git a/src/backend.py b/src/backend.py index dcf631a6..cf2f9682 100644 --- a/src/backend.py +++ b/src/backend.py @@ -38,6 +38,7 @@ from beat.utility import load_objects, dump_objects, \ ListArrayOrdering, ListToArrayBijection from beat.covariance import calc_sample_covariance +from beat import utility from pyrocko import util from time import time @@ -830,6 +831,58 @@ def load_sampler_params(project_dir, stage_number, mode): sample_p_outname) return load_objects(stage_path) +def get_result_point(stage, config, point_llk='max'): + """ + Return point of a given stage result. + + Parameters + ---------- + stage : :class:`models.Stage` + config : :class:`config.BEATConfig` + point_llk : str + with specified llk(max, mean, min). + + Returns + ------- + dict + """ + if config.sampler_config.name == 'Metropolis': + if stage.step is None: + raise AttributeError( + 'Loading Metropolis results requires' + ' sampler parameters to be loaded!') + + sc = config.sampler_config.parameters + from beat.sampler.metropolis import get_trace_stats + pdict, _ = get_trace_stats( + stage.mtrace, stage.step, sc.burn, sc.thin) + point = pdict[point_llk] + + elif config.sampler_config.name == 'SMC': + llk = stage.mtrace.get_values( + varname='like', + combine=True) + + posterior_idxs = utility.get_fit_indexes(llk) + + point = stage.mtrace.point(idx=posterior_idxs[point_llk]) + + elif config.sampler_config.name == 'PT': + params = config.sampler_config.parameters + llk = stage.mtrace.get_values( + varname='like', + burn=int(params.n_samples * params.burn), + thin=params.thin) + + posterior_idxs = utility.get_fit_indexes(llk) + + point = stage.mtrace.point(idx=posterior_idxs[point_llk]) + + else: + raise NotImplementedError( + 'Sampler "%s" is not supported!' % config.sampler_config.name) + + return point def concatenate_traces(mtraces): """ diff --git a/src/plotting.py b/src/plotting.py index 30986e65..b6113ff4 100644 --- a/src/plotting.py +++ b/src/plotting.py @@ -10,7 +10,6 @@ from beat import utility from beat.models import Stage -from beat.sampler.metropolis import get_trace_stats from beat.heart import init_seismic_targets, init_geodetic_targets from beat.colormap import slip_colormap @@ -473,60 +472,6 @@ def plot_log_cov(cov_mat): plt.colorbar(im) plt.show() - -def get_result_point(stage, config, point_llk='max'): - """ - Return point of a given stage result. - - Parameters - ---------- - stage : :class:`models.Stage` - config : :class:`config.BEATConfig` - point_llk : str - with specified llk(max, mean, min). - - Returns - ------- - dict - """ - if config.sampler_config.name == 'Metropolis': - if stage.step is None: - raise AttributeError( - 'Loading Metropolis results requires' - ' sampler parameters to be loaded!') - - sc = config.sampler_config.parameters - pdict, _ = get_trace_stats( - stage.mtrace, stage.step, sc.burn, sc.thin) - point = pdict[point_llk] - - elif config.sampler_config.name == 'SMC': - llk = stage.mtrace.get_values( - varname='like', - combine=True) - - posterior_idxs = utility.get_fit_indexes(llk) - - point = stage.mtrace.point(idx=posterior_idxs[point_llk]) - - elif config.sampler_config.name == 'PT': - params = config.sampler_config.parameters - llk = stage.mtrace.get_values( - varname='like', - burn=int(params.n_samples * params.burn), - thin=params.thin) - - posterior_idxs = utility.get_fit_indexes(llk) - - point = stage.mtrace.point(idx=posterior_idxs[point_llk]) - - else: - raise NotImplementedError( - 'Sampler "%s" is not supported!' % config.sampler_config.name) - - return point - - def plot_quadtree(ax, data, target, cmap, colim, alpha=0.8): """ Plot UnwrappedIFG displacements on the respective quadtree rectangle. From f6349491dcf2841e782a3756d256f399cea1a7d7 Mon Sep 17 00:00:00 2001 From: asteinbe Date: Thu, 4 Oct 2018 15:33:33 +0200 Subject: [PATCH 2/2] get_result_point to backend, plotting import --- apps/beat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/beat b/apps/beat index 0fef049f..196ea4f2 100644 --- a/apps/beat +++ b/apps/beat @@ -1060,6 +1060,7 @@ def command_build_gfs(args): def command_plot(args): + from beat import plotting command_str = 'plot' def setup(parser): @@ -1152,7 +1153,6 @@ def command_plot(args): dest='build', action='store_true', help='Build models during problem loading.') - plots_avail = plotting.available_plots() details = '''Available are: %s or "all". Multiple plots can be