Source code for branchpro.apps._simulation

#
# SimulationApp
#
# This file is part of BRANCHPRO
# (https://github.com/SABS-R3-Epidemiology/branchpro.git) which is released
# under the BSD 3-clause license. See accompanying LICENSE.md for copyright
# notice and full license details.
#

import copy
import numpy as np
import pandas as pd
import dash
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html

import branchpro as bp
from branchpro.apps import BranchProDashApp


[docs] class IncidenceNumberSimulationApp(BranchProDashApp): """IncidenceNumberSimulationApp Class: Class for the simulation dash app with figure and sliders for the BranchPro models. """ def __init__(self): super().__init__() self.session_data = { 'data_storage': None, 'interval_storage': None} self.app = dash.Dash(__name__, external_stylesheets=self.css) self.app.title = 'BranchproSim' button_style = { 'width': '100%', 'height': '60px', 'lineHeight': '60px', 'borderWidth': '1px', 'borderStyle': 'dashed', 'borderRadius': '5px', 'textAlign': 'center', 'margin': '10px' } self.app.layout = \ html.Div([ dbc.Container([ html.H1('Branching Processes'), html.Div([]), # Empty div for top explanation texts dbc.Row([ dbc.Col([ html.Button( 'Add new simulation', id='sim-button', n_clicks=0), dcc.Graph( figure=bp.IncidenceNumberPlot().figure, id='myfig') ]), dbc.Col( self.update_sliders(), id='all-sliders') ], align='center'), dbc.Row( [ dbc.Col( children=[ html.H4([ 'You can upload your own ', html.Span( 'incidence data', id='inc-tooltip', style={ 'textDecoration': 'underline', 'cursor': 'pointer'}, ), ' here.' ]), dbc.Modal( self._inc_modal, id='inc_modal', size='xl', ), html.Div([ 'It will appear as bars, while' ' the simulation will be a line.' ' You can upload both local and ' '/ or imported incidence data.' ]), dcc.Upload( id='upload-data', children=html.Div([ 'Drag and Drop or ', html.A( 'Select Files', style={ 'text-decoration': 'underline'}), ' to upload your Incidence Number ' 'data.' ]), style=button_style, # Allow multiple files to be uploaded multiple=True ), html.Div(id='incidence-data-upload')]), dbc.Col( children=[ html.H4([ 'You can upload your own ', html.Span( 'serial interval', id='si-tooltip', style={ 'textDecoration': 'underline', 'cursor': 'pointer'} ), ' here.' ]), dbc.Modal( self._si_modal, id='si_modal', size='lg', ), html.Div([ 'Data must contain one serial ' 'interval to be used for simulation' ' displayed as a column. If multiple ' 'serial intervals are uploaded, the ' 'first one will be used.']), dcc.Upload( id='upload-interval', children=html.Div( [ 'Drag and Drop or ', html.A( 'Select Files', style={ 'text-decoration': '\ underline'}), ' to upload your Serial \ Interval.' ]), style=button_style, # Allow multiple files to be uploaded multiple=True ), html.Div(id='ser-interval-upload')]) ], align='center', ), html.Div([]), # Empty div for bottom text html.Div(id='data_storage', style={'display': 'none'}), html.Div(id='interval_storage', style={'display': 'none'}), dcc.ConfirmDialog( id='confirm', message='Simulation failed due to overflow!', ), ], fluid=True), self.mathjax_script ]) # Set the app index string for mathjax self.app.index_string = self.mathjax_html # Save the locations of texts from the layout self.main_text = self.app.layout.children[0].children[1].children self.collapsed_text = self.app.layout.children[0].children[-4].children
[docs] def update_sliders(self, init_cond=10.0, r0=1.0, r1=0.5, magnitude_init_cond=None): """Generate sliders for the app. This method tunes the bounds of the sliders to the time period and magnitude of the data. Parameters ---------- init_cond : int start position on the slider for the number of initial cases for the Branch Pro model in the simulator. r0 : float start position on the slider for the initial reproduction number for the Branch Pro model in the simulator. r1 : float start position on the slider for the second reproduction number for the Branch Pro model in the simulator. magnitude_init_cond : int maximal start position on the slider for the number of initial cases for the Branch Pro model in the simulator. By default, it will be set to the maximum value observed in the data. Returns ------- html.Div A dash html component containing the sliders """ data = self.session_data.get('data_storage') # Calculate slider values that depend on the data if data is not None: time_label, inc_label = data.columns[:2] if magnitude_init_cond is None: magnitude_init_cond = max(data[inc_label]) bounds = (1, max(data[time_label])) else: # choose values to use if there is no data if magnitude_init_cond is None: magnitude_init_cond = 1000 bounds = (1, 30) mid_point = round(sum(bounds) / 2) # Make new sliders sliders = bp._SliderComponent() if (data is not None) and ('Imported Cases' in data.columns): # Add slider for epsilon only when imported cases are detected # in the data with default assuming equal R numbers for local # and imported cases sliders.add_slider( 'Epsilon', 'epsilon', 1.0, 0.0, 3.0, 0.01) else: sliders.add_slider( 'Epsilon', 'epsilon', 1.0, 0.0, 3.0, 0.01, invisible=True) sliders.add_slider( 'Initial Cases', 'init_cond', init_cond, 0.0, magnitude_init_cond, 1, as_integer=True) sliders.add_slider('Initial R', 'r0', r0, 0.1, 10.0, 0.01) sliders.add_slider('Second R', 'r1', r1, 0.1, 10.0, 0.01) sliders.add_slider( 'Time of change', 't1', mid_point, bounds[0], bounds[1], 1, as_integer=True) return sliders.get_sliders_div()
[docs] def update_figure(self, fig=None, simulations=None, source=None): """Generate a plotly figure of incidence numbers and simulated cases. By default, this method uses the information saved in self.session_data to populate the figure with data. If a current figure and dash callback source are passed, it will try to just update the existing figure for speed improvements. Parameters ---------- fig : dict Current copy of the figure simulations : pd.DataFrame Simulation trajectories to add to the figure. source : str Dash callback source Returns ------- plotly.Figure Figure with updated data and simulations """ data = self.session_data.get('data_storage') if data is None: raise dash.exceptions.PreventUpdate() if fig is not None and simulations is not None: # Check if there is a faster way to update the figure if len(fig['data']) > 0 and source in ['epsilon', 'init_cond', 'r0', 'r1', 't1']: # Clear all traces except one simulation and the data if ('Imported Cases' in data.columns) and ( 'Incidence Number' in data.columns): fig['data'] = [fig['data'][0], fig['data'][1], fig['data'][-1]] else: fig['data'] = [fig['data'][0], fig['data'][-1]] # Set the y values of that trace equal to an updated simulation fig['data'][-1]['y'] = simulations.iloc[:, -1] return fig elif len(fig['data']) > 0 and source == 'sim-button': # Add one extra simulation, and set its y values fig['data'].append(copy.deepcopy(fig['data'][-1])) fig['data'][-1]['y'] = simulations.iloc[:, -1] if ('Imported Cases' in data.columns) and ( 'Incidence Number' in data.columns): sim_tuple = range(1, len(fig['data'])-2) else: sim_tuple = range(len(fig['data'])-2) for i in sim_tuple: # Change opacity of all traces in the figure but for the # first - the barplot of incidences # last - the latest simulation fig['data'][i+1]['line']['color'] = 'rgba(255,0,0,0.25)' fig['data'][i+1]['showlegend'] = False return fig time_label, inc_label = (data.columns[0], 'Incidence Number') num_simulations = len(simulations.columns) - 1 # Make a new figure plot = bp.IncidenceNumberPlot() if 'Imported Cases' in data.columns: # Separate data into local and imported cases imported_data = pd.DataFrame({ time_label: data[time_label], inc_label: data['Imported Cases'] }) if 'Incidence Number' in data.columns: # Bar plot of local cases plot.add_data( data.iloc[:, :2], time_key=time_label, inc_key=inc_label, name='Local Cases') # Bar plot of imported cases plot.add_data( imported_data, time_key=time_label, inc_key=inc_label, name='Imported Cases') else: # If no imported cases are present plot.add_data(data, time_key=time_label, inc_key=inc_label) # Keeps traces visibility states fixed when changing sliders plot.figure['layout']['legend']['uirevision'] = True for sim in range(num_simulations): df = simulations.iloc[:, [0, sim+1]] df.columns = [time_label, inc_label] plot.add_simulation(df, time_key=time_label, inc_key=inc_label) # Unless it is the most recent simulation, decrease the opacity to # 25% and remove it from the legend if sim < num_simulations - 1: plot.figure['data'][-1]['line'].color = 'rgba(255,0,0,0.25)' plot.figure['data'][-1]['showlegend'] = False return plot.figure
[docs] def update_simulation( self, new_init_cond, new_r0, new_r1, new_t1, new_epsilon): """Run a simulation of the branchpro model at the given slider values. Parameters ---------- new_init_cond (int) updated position on the slider for the number of initial cases for the Branch Pro model in the simulator. new_r0 (float) updated position on the slider for the initial reproduction number for the Branch Pro model in the simulator. new_r1 (float) updated position on the slider for the second reproduction number for the Branch Pro model in the simulator. new_t1 (float) updated position on the slider for the time change in reproduction numbers for the Branch Pro model in the simulator. new_epsilon (float) updated position on the slider for the constant of proportionality between local and imported cases for the Branch Pro model in the posterior. Returns ------- pandas.DataFrame Simulations storage dataframe """ data = self.session_data.get('data_storage') serial_interval = self.session_data.get( 'interval_storage').iloc[:, 0].values if data is None: raise dash.exceptions.PreventUpdate() time_label, inc_label = (data.columns[0], 'Incidence Number') times = data[time_label] # Make a new dataframe to save the simulation result simulations = data[[time_label]] # Add the correct R profile to the branchpro model if 'Imported Cases' in data.columns: br_pro_model = bp.LocImpBranchProModel( new_r0, serial_interval, new_epsilon) br_pro_model.set_imported_cases( times, data.loc[:, ['Imported Cases']].squeeze().tolist()) else: br_pro_model = bp.BranchProModel(new_r0, serial_interval) br_pro_model.set_r_profile([new_r0, new_r1], [0, new_t1]) # Generate one simulation trajectory from this model simulation_controller = bp.SimulationController( br_pro_model, min(times), max(times)) try: sim_data = simulation_controller.run(new_init_cond) except ValueError: sim_data = -np.ones(max(times)) # Add data to simulations storage sim_times = simulation_controller.get_regime() simulations = pd.DataFrame({ time_label: sim_times, inc_label: sim_data}) return simulations