Source code for branchpro.apps._inference

#
# InferenceApp
#
# This file is part of BRANCHPRO
# (https://github.com/SABS-R3-Epidemiology/branchpro.git) which is released
# under the BSD 3-clause license. See accompanying LICENSE.md for copyright
# notice and full license details.
#
from math import floor
import pandas as pd

import dash
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html

import branchpro as bp
from branchpro.apps import BranchProDashApp


[docs] class BranchProInferenceApp(BranchProDashApp): """BranchProInferenceApp Class: Class for the inference dash app with figure and sliders for the BranchPro models. """ def __init__(self, long_callback_manager=None): """ Parameters ---------- long_callback_manager Optional callback manager for long callbacks. See https://dash.plotly.com/long-callbacks """ super(BranchProInferenceApp, self).__init__() self.app = dash.Dash(__name__, external_stylesheets=self.css, long_callback_manager=long_callback_manager) self.app.title = 'BranchproInf' self.session_data = { 'data_storage': None, 'interval_storage': None, 'posterior_storage': None} button_style = { 'width': '100%', 'height': '60px', 'lineHeight': '60px', 'borderWidth': '1px', 'borderStyle': 'dashed', 'borderRadius': '5px', 'textAlign': 'center', 'margin': '10px' } self.app.layout = html.Div([ dbc.Container( [ html.H1('Branching Processes', id='page-title'), html.Div([]), # Empty div for top explanation texts html.H2('Incidence Data'), dbc.Row( dbc.Col(dcc.Graph( figure=bp.IncidenceNumberPlot().figure, id='data-fig')) ), dbc.Row( [ dbc.Col( children=[ html.H6([ 'You can upload your own ', html.Span( 'incidence data', id='inc-tooltip', style={ 'textDecoration': 'underline', 'cursor': 'pointer'}, ), ' here. It will appear as bars.']), dbc.Modal( self._inc_modal, id='inc_modal', size='xl', ), html.Div([ 'Data must be in the following column ' 'format: `Time`, `Incidence number`, ' '`Imported Cases` (optional), ' '`R_t` (true value of R, optional).']), dcc.Upload( id='upload-data', children=html.Div( [ 'Drag and Drop or ', html.A( 'Select Files', style={ 'text-decoration': 'underline'}), ' to upload your Incidence \ Number data.' ]), style=button_style, # Allow multiple files to be uploaded multiple=True ), html.Div(id='incidence-data-upload')]), dbc.Col( children=[ html.H6([ 'You can upload your own ', html.Span( 'serial interval', id='si-tooltip', style={ 'textDecoration': 'underline', 'cursor': 'pointer'} ), ' here.' ]), dbc.Modal( self._si_modal, id='si_modal', size='lg', ), html.Div([ 'Data must contain one or more serial ' 'intervals to be used for constructing' ' the posterior distributions each ' 'included as a column.']), dcc.Upload( id='upload-interval', children=html.Div( [ 'Drag and Drop or ', html.A( 'Select Files', style={ 'text-decoration': '\ underline'}), ' to upload your Serial \ Interval.' ]), style=button_style, # Allow multiple files to be uploaded multiple=True ), html.Div(id='ser-interval-upload')]) ], align='center', ), html.H2('Plot of R values'), html.Progress(id='progress_bar'), html.Div(id='first_run', # see flip_first_run() in the app children='True', style={'display': 'none'}), dbc.Row( [ dbc.Col( children=dcc.Graph( figure=bp.ReproductionNumberPlot( ).figure, id='posterior-fig', style={'display': 'block'})), dbc.Col(self.update_sliders(), id='all-sliders') ], align='center', ), html.Div([]), # Empty div for bottom text html.Div(id='data_storage', style={'display': 'none'}), html.Div(id='interval_storage', style={'display': 'none'}), html.Div(id='posterior_storage', style={'display': 'none'}) ], fluid=True), self.mathjax_script]) # Set the app index string for mathjax self.app.index_string = self.mathjax_html # Save the locations of texts from the layout self.main_text = self.app.layout.children[0].children[1].children self.collapsed_text = self.app.layout.children[0].children[-4].children
[docs] def update_sliders(self, mean=5.0, stdev=5.0, tau=6, central_prob=.95): """Generate sliders for the app. Parameters ---------- mean (float) start position on the slider for the mean of the prior for the Branch Pro model in the posterior. stdev (float) start position on the slider for the standard deviation of the prior for the Branch Pro model in the posterior. tau (int) start position on the slider for the tau window used in the running of the inference of the reproduction numbers of the Branch Pro model in the posterior. central_prob (float) start position on the slider for the level of the computed credible interval of the estimated R number values. Returns ------- html.Div A dash html component containing the sliders """ data = self.session_data.get('data_storage') if data is not None: time_label, inc_label = data.columns[:2] times = data[time_label] max_tau = floor((times.max() - times.min() + 1)/3) if tau > max_tau: # If default value of tau exceeds maximum accepted # choose tau to be this maximum value tau = max_tau else: max_tau = 7 sliders = bp._SliderComponent() if (data is not None) and ('Imported Cases' in data.columns): # Add slider for epsilon only when imported cases are detected # in the data with default assuming equal R numbers for local # and imported cases sliders.add_slider( 'Epsilon', 'epsilon', 1.0, 0.0, 3.0, 0.01) else: sliders.add_slider( 'Epsilon', 'epsilon', 1.0, 0.0, 3.0, 0.01, invisible=True) sliders.add_slider( 'Prior Mean', 'mean', mean, 0.1, 10.0, 0.01) sliders.add_slider( 'Prior Standard Deviation', 'stdev', stdev, 0.1, 10.0, 0.01) sliders.add_slider( 'Inference Sliding Window', 'tau', tau, 0, max_tau, 1, as_integer=True) sliders.add_slider( 'Central Posterior Probability', 'central_prob', central_prob, 0.1, 0.99, 0.01) return sliders.get_sliders_div()
[docs] def update_posterior(self, mean, stdev, tau, central_prob, epsilon=None, progress_fn=None): """Update the posterior distribution based on slider values. Parameters ---------- mean (float) updated position on the slider for the mean of the prior for the Branch Pro model in the posterior. stdev (float) updated position on the slider for the standard deviation of the prior for the Branch Pro model in the posterior. tau (int) updated position on the slider for the tau window used in the running of the inference of the reproduction numbers of the Branch Pro model in the posterior. central_prob (float) updated position on the slider for the level of the computed credible interval of the estimated R number values. epsilon (float) updated position on the slider for the constant of proportionality between local and imported cases for the Branch Pro model in the posterior. progress_fn Function of integer argument to send to posterior run_inference. It can be used for dash callbacks set_progress (see update_posterior_storage in the app script) Returns ------- pandas.DataFrame The posterior distribution, summarized in a dataframe with the following columns: 'Time Points', 'Mean', 'Lower bound CI' and 'Upper bound CI' """ new_alpha = (mean / stdev) ** 2 new_beta = mean / (stdev ** 2) data = self.session_data.get('data_storage') if data is None: raise dash.exceptions.PreventUpdate() time_label, inc_label = data.columns[:2] num_cols = len(self.session_data.get('interval_storage').columns) prior_params = (new_alpha, new_beta) labels = {'time_key': time_label, 'inc_key': inc_label} if num_cols == 1: serial_interval = self.session_data.get( 'interval_storage').iloc[:, 0].values if 'Imported Cases' in data.columns: # Separate data into local and imported cases imported_data = pd.DataFrame({ time_label: data[time_label], inc_label: data['Imported Cases'] }) # Posterior follows the LocImp behaviour posterior = bp.LocImpBranchProPosterior( data, imported_data, epsilon, serial_interval, *prior_params, **labels) else: # Posterior follows the simple behaviour posterior = bp.BranchProPosterior( data, serial_interval, *prior_params, **labels) posterior.run_inference(tau) else: serial_intervals = self.session_data.get( 'interval_storage').values.T if 'Imported Cases' in data.columns: # Separate data into local and imported cases imported_data = pd.DataFrame({ time_label: data[time_label], inc_label: data['Imported Cases'] }) # Posterior follows the LocImp behaviour posterior = bp.LocImpBranchProPosteriorMultSI( data, imported_data, epsilon, serial_intervals, *prior_params, **labels) else: # Posterior follows the simple behaviour posterior = bp.BranchProPosteriorMultSI( data, serial_intervals, *prior_params, **labels) posterior.run_inference(tau, progress_fn=progress_fn) return posterior.get_intervals(central_prob)
[docs] def update_inference_figure(self, source=None): """Update the inference figure based on currently stored information. Parameters ---------- source : str Dash callback source Returns ------- plotly.Figure Figure with updated posterior distribution """ data = self.session_data.get('data_storage') posterior = self.session_data.get('posterior_storage') if data is None or posterior is None: raise dash.exceptions.PreventUpdate() time_label, inc_label = data.columns[:2] plot = bp.ReproductionNumberPlot() plot.add_interval_rt(posterior) if 'R_t' in data.columns: plot.add_ground_truth_rt( data[[time_label, 'R_t']], time_key=time_label, r_key='R_t') # Keeps traces visibility states fixed when changing sliders plot.figure['layout']['legend']['uirevision'] = True return plot.figure
[docs] def update_data_figure(self): """Update the data figure based on currently stored information. Returns ------- plotly.Figure Figure with updated data """ data = self.session_data.get('data_storage') if data is None: raise dash.exceptions.PreventUpdate() time_label, inc_label = data.columns[:2] plot = bp.IncidenceNumberPlot() if 'Imported Cases' in data.columns: # Separate data into local and imported cases imported_data = pd.DataFrame({ time_label: data[time_label], inc_label: data['Imported Cases'] }) # Bar plot of local cases plot.add_data( data, time_key=time_label, inc_key=inc_label, name='Local Cases') # Bar plot of imported cases plot.add_data( imported_data, time_key=time_label, inc_key=inc_label, name='Imported Cases') else: # If no imported cases are present plot.add_data(data, time_key=time_label, inc_key=inc_label) # Keeps traces visibility states fixed when changing sliders plot.figure['layout']['legend']['uirevision'] = True return plot.figure