Source code for covid19_inference.sampling

import random
import logging
import warnings
from collections import Counter
import pickle
import glob
import os
import pymc3 as pm
import arviz as az
import matplotlib.pyplot as plt
import numpy as np


log = logging.getLogger(__name__)


def get_start_points(trace, trace_az, frames_start=None, SD_chain_logl=2.5):
    r"""
    Returns the starting points such that the chains deviate at most
    SD_chain_logl standard deviations from the chain with the highest
    likelihood.

    Parameters
    ----------
    trace : multitrace object
    trace_az : arviz trace object
    frames_start : int
        Which frames to use for calculating the mean likelihood and its
        standard deviation. By default it is set to the last quarter of the
        tuning samples.
    SD_chain_logl : None or float
        The number of standard deviations, 2.5 by default. If None, all
        chains are kept.

    Returns
    -------
    start_points : list
        A list of starting points
    logl_mean : array
        The mean log-likelihood of the starting points
    """
    logl = trace_az.warmup_sample_stats["lp"]
    n_tune = logl.shape[1]
    n_chains = logl.shape[0]
    if frames_start is None:
        frames_start = 3 * n_tune // 4
    logl_mean = np.array(logl[:, frames_start:].mean(axis=1))
    logl_std = np.array(logl[:, frames_start:].std(axis=1))

    max_idx = np.argmax(logl_mean)

    if SD_chain_logl is not None:
        logl_thr = logl_mean[max_idx] - logl_std[max_idx] * SD_chain_logl
        keep_chains = logl_mean >= logl_thr
        log.info(f"Num chains kept: {np.sum(keep_chains)}/{n_chains}")
    else:
        keep_chains = np.ones_like(logl_mean).astype("bool")

    start_points = []
    for i, keep_chain in enumerate(keep_chains):
        if keep_chain:
            start_points.append(trace.point(-1, chain=i))

    return start_points, logl_mean[keep_chains]
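# A hypothetical usage sketch for ``get_start_points`` (``model`` is an
# assumed, pre-built PyMC3 model, not part of this module):
#
#     multitrace = pm.sample(
#         model=model, tune=500, draws=0, chains=8,
#         return_inferencedata=False, discard_tuned_samples=False,
#     )
#     idata = az.from_pymc3(multitrace, model=model, save_warmup=True)
#     start_points, logl_mean = get_start_points(multitrace, idata)
#     # start_points can then be passed as ``start=`` to a follow-up
#     # pm.sample call, so that only well-converged chains are continued.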
class Callback:
    """
    Simple callback to save the trace every n iterations and plot the logp.

    Parameters
    ----------
    path : str
        Path to save the trace
    name : str
        Name of the model, should be used when running multiple models in
        parallel (default: "model")
    n : int
        Save the trace every n iterations
    """

    def __init__(self, path="/temp", name="model", n=100):
        self.path = path
        self.name = name
        self.n = n
        self.lengths = Counter()

        # Set up plotting of the logp
        self.fig, self.ax = plt.subplots(1, 1, figsize=(10, 5))
        self.ax.set_xlabel("Iteration")
        self.ax.set_ylabel("Logp")
        self.ax.set_title(name)

    def __call__(self, trace, draw):
        """This function is called by pymc3 every iteration."""
        # Update values
        self.lengths[draw.chain] += 1

        # Save the trace every n iterations
        if self.lengths[draw.chain] % self.n == 0:
            self.save(trace, draw.chain)
            self.plot_logp(trace, draw.chain)

    def plot_logp(self, trace, chain):
        pm_trace = pm.backends.base.MultiTrace({chain: trace}.values())
        masked_logp = pm_trace["model_logp"][pm_trace["model_logp"] != 0]
        try:
            self.ax.plot(
                np.arange(
                    self.lengths[chain] - len(masked_logp), self.lengths[chain]
                ),
                masked_logp,
            )
            self.fig.savefig(os.path.join(self.path, self.name) + "_logp.png")
        except Exception:
            log.warning(
                f"Could not save {os.path.join(self.path, self.name)}_logp.png"
            )

    def save(self, trace, chain):
        try:
            with open(
                os.path.join(self.path, self.name) + f"_{chain}.pkl", "wb"
            ) as f:
                trace.chain = chain
                pickle.dump(trace, f)
        except Exception:
            log.warning(
                f"Could not save {os.path.join(self.path, self.name)}_{chain}.pkl"
            )

    def load_all(self):
        files = glob.glob(f"{self.path}/{self.name}_*.pkl")
        traces = {}
        for file in files:
            with open(file, "rb") as f:
                trace = pickle.load(f)
                traces[trace.chain] = trace
        return az.from_pymc3(pm.backends.base.MultiTrace(traces.values()))


def burn_in(
    model,
    n_tune,
    n_chains_burn_in,
    n_chains_final,
    step_method,
    start_points=None,
    callback=None,
    args_start_points=None,
    sample_kwargs=None,
):
    """Run a tuning-only burn-in and select starting points for the next stage."""
    if sample_kwargs is None:
        sample_kwargs = {}

    i = 0
    while i < 50:
        try:
            trace_tuning = pm.sample(
                model=model,
                tune=n_tune,
                draws=0,
                start=start_points,
                chains=n_chains_burn_in,
                return_inferencedata=False,
                discard_tuned_samples=False,
                step=step_method,
                callback=callback,
                **sample_kwargs,
            )
        except RuntimeError as error:
            if i < 10:
                i += 1
                log.warning(
                    f"Tuning led to a NaN error in one chain, "
                    f"trying again (try no. {i})."
                )
                continue
            else:
                raise error
        break  # tuning succeeded, leave the retry loop

    trace_tuning_az = az.from_pymc3(trace_tuning, model=model, save_warmup=True)

    if args_start_points is None:
        args_start_points = {}
    start_points, logl_starting_points = get_start_points(
        trace_tuning, trace_tuning_az, SD_chain_logl=None, **args_start_points
    )
    num_start_points = len(start_points)

    if num_start_points < n_chains_final:
        log.warning(
            "Not enough chains converged to the minimum, we recommend "
            "increasing the number of tuning chains"
        )
        start_points = random.choices(start_points, k=n_chains_final)
    elif num_start_points > n_chains_final:
        # Choose the starting points randomly, weighted by their likelihood
        p = np.exp(logl_starting_points - max(logl_starting_points))
        start_points = np.random.choice(
            start_points,
            size=n_chains_final,
            p=p / np.sum(p),
            replace=False,
        )

    return start_points, trace_tuning_az
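# A hypothetical usage sketch for ``Callback`` together with ``burn_in``
# (``my_model`` and the path are assumptions, not part of this module):
#
#     callback = Callback(path="./traces", name="my_model", n=100)
#     step = pm.NUTS(model=my_model)
#     start_points, trace_tuning_az = burn_in(
#         my_model,
#         n_tune=500,
#         n_chains_burn_in=12,
#         n_chains_final=4,
#         step_method=step,
#         callback=callback,
#     )
#     # If a run crashes, the chains saved by the callback can be recovered
#     # with callback.load_all().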
def robust_sample(
    model,
    tune,
    draws,
    final_chains,
    burnin_chains,
    burnin_draws=None,
    burnin_chains_2nd=None,
    burnin_draws_2nd=None,
    args_start_points=None,
    callback=None,
    sample_kwargs=None,
    **kwargs,
):
    r"""
    Samples the model by starting more chains than needed (burn-in chains)
    and using only a reduced number (final_chains) for the final sampling.
    The final chains are chosen randomly (without replacement), weighted by
    their likelihood.

    Parameters
    ----------
    model : :class:`Cov19Model`
        The model
    tune : int
        Number of tuning samples
    draws : int
        Number of final samples
    final_chains : int
        Number of chains used for the final sampling
    burnin_chains : int
        Number of chains used during burn-in; we recommend about 2-3 times
        more than final_chains
    burnin_draws : int
        Length of the burn-in period, which can be fairly short, on the
        order of a few hundred draws. By default it is set to tune//2
    burnin_chains_2nd : int
        If not None, use a two-stage burn-in period, reducing the number of
        chains at each stage. Therefore, it should be less than
        burnin_chains and more than final_chains:
        burnin_chains > burnin_chains_2nd > final_chains
    burnin_draws_2nd : int
        Length of the second burn-in period. By default it is set to
        burnin_draws
    args_start_points : dict
        Arguments passed to `get_start_points`
    callback : function
        Callback forwarded to the final pm.sample call
    sample_kwargs : dict
        Arguments passed to pm.sample
    **kwargs :
        Arguments passed to the NUTS step function

    Returns
    -------
    trace : trace as multitrace object
    trace_az : trace as arviz object
    """
    burnin_chains_2nd_to_compare = (
        burnin_chains_2nd if burnin_chains_2nd is not None else final_chains + 0.5
    )
    if not burnin_chains > burnin_chains_2nd_to_compare > final_chains:
        raise RuntimeError(
            "The number of chains should decrease for good sampling: "
            "burnin_chains > burnin_chains_2nd > final_chains"
        )

    if burnin_draws is None:
        burnin_draws = tune // 2
    if burnin_draws_2nd is None and burnin_chains_2nd is not None:
        burnin_draws_2nd = burnin_draws

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".*invalid value encountered in double_scalars.*"
        )
        warnings.filterwarnings(
            "ignore", message=".*Tuning samples will be included in the returned.*"
        )
        warnings.filterwarnings(
            "ignore", message=".*Tuning was enabled throughout the whole trace.*"
        )
        warnings.filterwarnings("ignore", message=".*Mean of empty slice.*")
        warnings.filterwarnings(
            "ignore",
            message=".*The number of samples is too small to check convergence reliably.*",
        )

        # Create the NUTS step method once, to reuse it for the burn-in and
        # the later sampling
        default_nuts = pm.NUTS(model=model, **kwargs)

        if sample_kwargs is None:
            sample_kwargs = {}

        ## First burn-in stage
        start_points, trace_burn_in_1_az = burn_in(
            model,
            n_tune=burnin_draws,
            n_chains_burn_in=burnin_chains,
            n_chains_final=final_chains
            if burnin_chains_2nd is None
            else burnin_chains_2nd,
            step_method=default_nuts,
            args_start_points=args_start_points,
            sample_kwargs=sample_kwargs,
        )

        ## Optional second burn-in stage with fewer chains
        if burnin_chains_2nd is not None:
            start_points, trace_burn_in_2_az = burn_in(
                model,
                n_tune=burnin_draws_2nd,
                n_chains_burn_in=burnin_chains_2nd,
                n_chains_final=final_chains,
                start_points=start_points,
                step_method=default_nuts,
                args_start_points=args_start_points,
                sample_kwargs=sample_kwargs,
            )

        ## Final sampling, starting from the selected chains
        trace = pm.sample(
            model=model,
            tune=tune,
            draws=draws,
            chains=final_chains,
            start=start_points,
            return_inferencedata=False,
            discard_tuned_samples=False,
            step=default_nuts,
            callback=callback,
            **sample_kwargs,
        )
        trace_az = az.from_pymc3(trace, model=model, save_warmup=True)
    def append_burn_in_samples(trace, trace_burn_in, num):
        # Keep the burn-in samples as extra groups on the arviz object,
        # e.g. "burn-in-1_posterior", so they are not lost
        new_names = [
            f"burn-in-{num}_" + n.replace("warmup_", "")
            for n in trace_burn_in.groups()
        ]
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The group .* is not defined in the .* scheme",
            )
            trace.add_groups(
                {n: o for n, o in zip(new_names, trace_burn_in.values())}
            )
        return trace

    trace_az = append_burn_in_samples(trace_az, trace_burn_in_1_az, 1)
    if burnin_chains_2nd is not None:
        trace_az = append_burn_in_samples(trace_az, trace_burn_in_2_az, 2)

    return trace, trace_az
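# A hedged end-to-end sketch of ``robust_sample`` (``my_model`` and all the
# numbers are illustrative assumptions, not defaults of this module):
#
#     trace, trace_az = robust_sample(
#         my_model,
#         tune=1000,
#         draws=1000,
#         final_chains=4,
#         burnin_chains=12,      # ~3x final_chains, as recommended above
#         burnin_chains_2nd=8,   # optional second stage: 12 > 8 > 4
#     )
#     # The burn-in samples stay available as extra groups on trace_az,
#     # e.g. ``burn-in-1_posterior`` and ``burn-in-1_sample_stats``.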