# ------------------------------------------------------------------------------ #
# Old plotting helpers that are still used by some of Jonas' examples.
# These are being moved to plot.py.
# ------------------------------------------------------------------------------ #
import logging
import datetime
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def get_all_free_RVs_names(model):
    """
    Collect the names of all free random variables of ``model``.

    Every occurrence of the substring ``"_log__"`` is removed from each name
    (presumably the marker PyMC3 adds to log-transformed variables).

    Parameters
    ----------
    model: pm.Model instance

    Returns
    -------
    : list
        all variable names, in the order of ``model.free_RVs``
    """
    log_marker = "_log__"
    return [name.replace(log_marker, "") for name in map(str, model.free_RVs)]
def get_prior_distribution(model, x, varname):
    """
    Evaluate the prior density that ``model`` uses for ``varname``.

    Parameters
    ----------
    model: pm.Model instance
    x: list or array
        points at which the prior is evaluated
    varname: string
        name of the variable, looked up as ``model[varname]``

    Returns
    -------
    : array
        ``exp(logp(x))`` of the variable's distribution, i.e. the prior
        density at ``x``
    """
    distribution = model[varname].distribution
    # ``.eval()`` turns the (presumably symbolic) logp expression into values.
    log_density = distribution.logp(x).eval()
    return np.exp(log_density)
def plot_hist(model, trace, ax, varname, colors=("tab:blue", "tab:orange"), bins=50):
    """
    Plot the posterior histogram of ``varname`` with its prior overlaid.

    Parameters
    ----------
    model: pm.Model instance
        model whose prior for ``varname`` is drawn (via ``get_prior_distribution``)
    trace: trace of the model
        indexed with ``varname``; the result must have a ``.shape``
    ax: matplotlib.axes instance
        axes to draw into
    varname: string
        name of the variable to plot
    colors: sequence of 2 color names
        ``colors[0]`` is used for the prior curve, ``colors[1]`` for the
        posterior histogram
    bins: number or array
        passed to ``ax.hist``

    Returns
    -------
    None
        If the samples of ``varname`` have two or more dimensions, a warning is
        logged, the axes is hidden and nothing is drawn.
    """
    samples = trace[varname]
    if len(samples.shape) >= 2:
        log.warning("Dimension of %s larger than one, skipping", varname)
        ax.set_visible(False)
        return

    ax.hist(samples, bins=bins, density=True, color=colors[1], label="Posterior")

    # Evaluate the prior on the x-range chosen for the posterior histogram.
    limits = ax.get_xlim()
    x = np.linspace(*limits, num=100)
    try:
        ax.plot(
            x,
            get_prior_distribution(model, x, varname),
            label="Prior",
            color=colors[0],
            linewidth=3,
        )
    except Exception:
        # Best effort: drawing the prior overlay is optional, so failures are
        # skipped, but recorded in the log instead of being silently discarded.
        # (A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.)
        log.debug("Could not plot prior of %s", varname, exc_info=True)

    # Restore the histogram's x-range in case the prior curve changed it.
    ax.set_xlim(*limits)
    ax.set_ylabel("Density")
    ax.set_xlabel(varname)
def plot_cases(
    trace,
    new_cases_obs,
    date_begin_sim,
    diff_data_sim,
    start_date_plot=None,
    end_date_plot=None,
    ylim=None,
    week_interval=None,
    colors=("tab:blue", "tab:orange"),
    country="Germany",
):
    """
    Plot the new cases, the fit, the forecast and the lambda_t evolution.

    The figure has three visible panels:

    * bottom left: observed data and fit (median, 95% CI) on a log scale,
    * bottom right: observed data, fit and forecast (75% and 95% CI) on a
      linear scale,
    * top right: ``lambda_t - mu`` over the whole simulation, with a vertical
      marker at the end of the data minus the 75th percentile of the sampled
      ``delay``.

    The top-left panel is hidden.

    Parameters
    ----------
    trace : trace returned by model
        Must provide ``"new_cases"``, ``"lambda_t"``, ``"mu"`` and ``"delay"``.
        For ``"new_cases"`` and ``"lambda_t"`` the first axis indexes samples
        and the second axis indexes simulated days.
    new_cases_obs : array
        Observed daily new cases. The data is placed ``diff_data_sim`` days
        after ``date_begin_sim``.
    date_begin_sim : datetime.datetime
        Date of the first simulated day.
    diff_data_sim : float
        Difference in days between the begin of the simulation and the data.
    start_date_plot : datetime.datetime, optional
        Left x-limit of all panels. Defaults to the first day of the data.
    end_date_plot : datetime.datetime, optional
        Right x-limit of the right-hand panels; also bounds the forecast.
        Defaults to the last simulated day.
    ylim : float, optional
        The maximal y value of the bottom-right panel. Defaults to 1.6 times
        the maximum of ``new_cases_obs``.
    week_interval : int, optional
        The interval in weeks of the major (Sunday) ticks on the x-axes. If
        None, it is chosen separately for the left and the right panels.
    colors : sequence of 2 color names
        ``colors[0]`` for the data, ``colors[1]`` for fit, forecast and
        growth rate.
    country : str
        Used in the y-label of the bottom-right panel.

    Returns
    -------
    figure, axes
        The matplotlib figure and its 2x2 array of axes.
    """

    def conv_time_to_mpl_dates(arr):
        # Day offsets relative to ``date_begin_sim`` -> matplotlib date numbers.
        return matplotlib.dates.date2num(
            [datetime.timedelta(days=float(date)) + date_begin_sim for date in arr]
        )

    def format_date_axis(ax, interval):
        # Major ticks on Sundays every ``interval`` weeks, minor ticks daily.
        ax.xaxis.set_major_locator(
            matplotlib.dates.WeekdayLocator(
                interval=interval, byweekday=matplotlib.dates.SU
            )
        )
        ax.xaxis.set_minor_locator(matplotlib.dates.DayLocator())
        ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%m/%d"))

    new_cases_sim = trace["new_cases"]
    len_sim = trace["lambda_t"].shape[1]

    if start_date_plot is None:
        start_date_plot = date_begin_sim + datetime.timedelta(days=diff_data_sim)
    if end_date_plot is None:
        end_date_plot = date_begin_sim + datetime.timedelta(days=len_sim)
    if ylim is None:
        ylim = 1.6 * np.max(new_cases_obs)

    num_days_data = len(new_cases_obs)
    # Days between the begin of the simulation and the end of the data.
    diff_to_0 = num_days_data + diff_data_sim
    date_data_end = date_begin_sim + datetime.timedelta(
        days=diff_data_sim + num_days_data
    )
    num_days_future = (end_date_plot - date_data_end).days
    start_date_mpl, end_date_mpl = matplotlib.dates.date2num(
        [start_date_plot, end_date_plot]
    )

    if week_interval is None:
        # Roughly 5 (left) / 6 (right) major ticks; the interval must be a
        # positive number of weeks even for degenerate ranges.
        week_inter_left = max(1, int(np.ceil(num_days_data / 7 / 5)))
        week_inter_right = max(
            1, int(np.ceil((end_date_mpl - start_date_mpl) / 7 / 6))
        )
    else:
        week_inter_left = week_interval
        week_inter_right = week_interval

    fig, axes = plt.subplots(
        2,
        2,
        figsize=(9, 5),
        gridspec_kw={"height_ratios": [1, 3], "width_ratios": [2, 3]},
    )

    # Shared by both bottom panels: dates of the observed data and the
    # posterior of the fitted (past) new cases.
    time_data = np.arange(-num_days_data, 0)
    mpl_dates_data = conv_time_to_mpl_dates(time_data) + diff_data_sim + num_days_data
    new_cases_past = new_cases_sim[:, :num_days_data]
    median_past = np.median(new_cases_past, axis=0)
    ci95_past = (
        np.percentile(new_cases_past, q=2.5, axis=0),
        np.percentile(new_cases_past, q=97.5, axis=0),
    )

    # Bottom left: data and fit on a log scale.
    ax = axes[1][0]
    ax.plot(
        mpl_dates_data,
        new_cases_obs,
        "d",
        markersize=6,
        label="Data",
        zorder=5,
        color=colors[0],
    )
    ax.plot(
        mpl_dates_data,
        median_past,
        color=colors[1],
        label="Fit (with 95% CI)",
    )
    ax.fill_between(mpl_dates_data, *ci95_past, alpha=0.3, color=colors[1])
    ax.set_yscale("log")
    ax.set_ylabel("Number of new cases")
    ax.set_xlabel("Date")
    ax.legend()
    format_date_axis(ax, week_inter_left)
    ax.set_xlim(start_date_mpl)

    # Bottom right: data, fit and forecast on a linear scale.
    ax = axes[1][1]
    ax.plot(
        mpl_dates_data,
        new_cases_obs,
        "d",
        label="Data",
        markersize=4,
        color=colors[0],
        zorder=5,
    )
    ax.plot(
        mpl_dates_data,
        median_past,
        "--",
        color=colors[1],
        linewidth=1.5,
        label="Fit with 95% CI",
    )
    ax.fill_between(mpl_dates_data, *ci95_past, alpha=0.2, color=colors[1])

    time_future = np.arange(0, num_days_future)
    mpl_dates_fut = conv_time_to_mpl_dates(time_future) + diff_data_sim + num_days_data
    # Shape (days, samples): statistics are taken over the last axis.
    cases_future = new_cases_sim[:, num_days_data : num_days_data + num_days_future].T
    ax.plot(
        mpl_dates_fut,
        np.median(cases_future, axis=-1),
        color=colors[1],
        linewidth=3,
        label="forecast with 75% and 95% CI",
    )
    ax.fill_between(
        mpl_dates_fut,
        np.percentile(cases_future, q=2.5, axis=-1),
        np.percentile(cases_future, q=97.5, axis=-1),
        alpha=0.1,
        color=colors[1],
    )
    ax.fill_between(
        mpl_dates_fut,
        np.percentile(cases_future, q=12.5, axis=-1),
        np.percentile(cases_future, q=87.5, axis=-1),
        alpha=0.2,
        color=colors[1],
    )
    ax.set_xlabel("Date")
    ax.set_ylabel(f"New confirmed cases in {country}")
    ax.legend(loc="upper left")
    ax.set_ylim(0, ylim)
    # Raw string: "\," is a LaTeX thin space, not a Python escape sequence.
    ax.yaxis.set_major_formatter(
        matplotlib.ticker.FuncFormatter(
            lambda num, _: r"${:.0f}\,$k".format(num / 1_000)
        )
    )
    ax.set_xlim(start_date_mpl, end_date_mpl)
    format_date_axis(ax, week_inter_right)

    # Top right: effective growth rate lambda_t - mu over the whole simulation.
    ax = axes[0][1]
    time_sim = np.arange(-diff_to_0, -diff_to_0 + len_sim)
    mpl_dates_sim = conv_time_to_mpl_dates(time_sim) + diff_data_sim + num_days_data
    growth_rate = trace["lambda_t"] - trace["mu"][:, None]
    ax.plot(
        mpl_dates_sim, np.median(growth_rate, axis=0), color=colors[1], linewidth=2
    )
    ax.fill_between(
        mpl_dates_sim,
        np.percentile(growth_rate, q=2.5, axis=0),
        np.percentile(growth_rate, q=97.5, axis=0),
        alpha=0.15,
        color=colors[1],
    )
    # Escaped backslash: "\l" is not a valid Python escape sequence.
    ax.set_ylabel("effective\ngrowth rate $\\lambda_t^*$")
    # Freeze the data-driven y-range before adding the reference lines.
    ylims = ax.get_ylim()
    ax.hlines(0, start_date_mpl, end_date_mpl, linestyles=":")
    delay = matplotlib.dates.date2num(date_data_end) - np.percentile(
        trace["delay"], q=75
    )
    ax.vlines(delay, ylims[0], ylims[1], linestyles="-", colors=["tab:red"])
    ax.set_ylim(*ylims)
    # Scalar y position (np.diff would yield a 1-element array).
    text_y = ylims[1] - 0.04 * (ylims[1] - ylims[0])
    ax.text(
        delay + 0.5,
        text_y,
        "unconstrained because\nof reporting delay",
        color="tab:red",
        verticalalignment="top",
    )
    ax.text(
        delay - 0.5,
        text_y,
        "constrained\nby data",
        color="tab:red",
        horizontalalignment="right",
        verticalalignment="top",
    )
    format_date_axis(ax, week_inter_right)
    ax.set_xlim(start_date_mpl, end_date_mpl)

    axes[0][0].set_visible(False)
    plt.subplots_adjust(wspace=0.4, hspace=0.3)
    return fig, axes