inference - CDCgov/DynODE GitHub Wiki
This document describes the core inference classes and helper utilities used in the DynODE framework for probabilistic compartmental modeling. These APIs are designed to facilitate model fitting, parameter sampling, and checkpointing of simulation states.
classDiagram
%% Abstract base class
class InferenceProcess {
<<abstract>>
+**numpyro_model**: Callable
+inference_prngkey: Array
+infer(**kwargs)
+get_samples(group_by_chain=False, exclude_deterministic=True)
+to_arviz()
- _inference_complete: bool
- _inferer: Optional[MCMC | SVI]
- _inference_state: Optional[HMCState | SVIRunResult]
- _inferer_kwargs: Optional[dict]
}
class MCMCProcess {
+num_samples: int
+num_warmup: int
+num_chains: int
+nuts_max_tree_depth: int
+nuts_init_strategy: Callable
+nuts_kwargs: dict
}
class SVIProcess {
+num_iterations: int
+num_samples: int
+guide_class: Type[AutoContinuous]
+guide_init_strategy: Callable
+optimizer: _NumPyroOptim
+progress_bar: bool
+guide_kwargs: dict
}
%% Inheritance
InferenceProcess --> MCMCProcess : subclass
InferenceProcess --> SVIProcess : subclass
InferenceProcess is the abstract base class for all inference processes in DynODE. It defines the interface for fitting a numpyro_model to data, retrieving posterior samples, and exporting results to ArviZ for diagnostics and visualization.
Key Methods:
- infer(**kwargs): Abstract. Fit the model to data.
- get_samples(group_by_chain=False, exclude_deterministic=True): Abstract. Retrieve posterior samples.
- to_arviz(): Abstract. Convert results to an arviz.InferenceData object.
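The abstract-interface pattern described here can be sketched with Python's standard `abc` module. This is a minimal illustration only, not DynODE's actual class: `InferenceProcessSketch` and `ConstantProcess` are hypothetical names, the "model" is a plain callable rather than a NumPyro model, and the return types are assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class InferenceProcessSketch(ABC):
    """Hypothetical stand-in for the interface described above."""

    def __init__(self, numpyro_model: Callable, inference_prngkey: Any):
        self.numpyro_model = numpyro_model
        self.inference_prngkey = inference_prngkey
        self._inference_complete = False

    @abstractmethod
    def infer(self, **kwargs) -> None:
        """Fit the model to data."""

    @abstractmethod
    def get_samples(self, group_by_chain=False, exclude_deterministic=True) -> dict:
        """Retrieve posterior samples."""

    @abstractmethod
    def to_arviz(self) -> Any:
        """Export results for diagnostics."""


class ConstantProcess(InferenceProcessSketch):
    """Toy subclass: 'fits' by calling the model once and storing the result."""

    def infer(self, **kwargs) -> None:
        self._samples = {"theta": [self.numpyro_model(**kwargs)]}
        self._inference_complete = True

    def get_samples(self, group_by_chain=False, exclude_deterministic=True) -> dict:
        if not self._inference_complete:
            raise RuntimeError("call infer() before get_samples()")
        return self._samples

    def to_arviz(self) -> Any:
        raise NotImplementedError("ArviZ export is omitted from this sketch")


process = ConstantProcess(numpyro_model=lambda scale=1.0: 2.0 * scale, inference_prngkey=0)
process.infer(scale=3.0)
samples = process.get_samples()
```

Because all three methods are marked abstract, instantiating `InferenceProcessSketch` directly raises `TypeError`; only subclasses that implement every method can be constructed.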
MCMCProcess implements inference using Markov Chain Monte Carlo (MCMC) with the NUTS sampler from NumPyro.
Parameters:
- num_samples, num_warmup, num_chains: Control MCMC sampling.
- nuts_max_tree_depth, nuts_init_strategy, nuts_kwargs: NUTS sampler configuration.
- progress_bar: Show progress during sampling.
Key Methods:
- infer(**kwargs): Runs MCMC and stores the sampler state.
- get_samples(group_by_chain=False, exclude_deterministic=True): Returns posterior samples, optionally grouped by chain and optionally including deterministic sites.
- to_arviz(): Returns an arviz.InferenceData object with posterior, prior, and posterior predictive samples.
SVIProcess implements inference using Stochastic Variational Inference (SVI) with NumPyro's autoguides.
Parameters:
- num_iterations, num_samples: Control SVI fitting and posterior sampling, respectively.
- guide_class, guide_init_strategy, guide_kwargs: Guide configuration.
- optimizer: SVI optimizer (default: Adam).
- progress_bar: Show progress during fitting.
Key Methods:
- infer(**kwargs): Runs SVI and stores the optimizer state.
- get_samples(exclude_deterministic=True): Returns posterior samples from the variational guide. SVI does not use chains, so group_by_chain is not applicable.
- to_arviz(): Returns an arviz.InferenceData object with prior, posterior predictive, and log-likelihood.
- For details on what to put inside numpyro_model, refer to the NumPyro section of the library backend documentation. NumPyro sites are the primary mechanism by which the solver or optimizer of each inference process updates and samples parameters.
- In the event that your sampler/optimizer
sample_distributions(obj, rng_key=None, _prefix="")
Recursively traverses a data structure, sampling any numpyro.distributions.Distribution objects found.
- Handles nested dicts, lists, and Pydantic models.
- Site names are constructed using the _prefix argument for traceability.
Returns: A copy of obj with all distributions replaced by samples.
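The recursive traversal can be illustrated without NumPyro. The sketch below is not DynODE's implementation: `UniformStub` is a hypothetical stand-in for a distribution object, `sample_tree` is a hypothetical name, the underscore-joined site names are an assumed naming scheme, and Pydantic models are not handled.

```python
import random
from dataclasses import dataclass


@dataclass
class UniformStub:
    """Hypothetical stand-in for a distribution object."""
    low: float
    high: float

    def sample(self, rng: random.Random) -> float:
        return rng.uniform(self.low, self.high)


def _join(prefix: str, key) -> str:
    # Build a traceable site name from the path walked so far.
    return f"{prefix}_{key}" if prefix else str(key)


def sample_tree(obj, rng, _prefix="", sites=None):
    """Return a copy of obj with every UniformStub replaced by a draw."""
    if isinstance(obj, UniformStub):
        value = obj.sample(rng)
        if sites is not None:
            sites[_prefix] = value  # record the draw under its site name
        return value
    if isinstance(obj, dict):
        return {k: sample_tree(v, rng, _join(_prefix, k), sites) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(
            sample_tree(v, rng, _join(_prefix, i), sites) for i, v in enumerate(obj)
        )
    return obj  # plain values pass through unchanged


params = {
    "beta": UniformStub(0.1, 0.5),
    "strains": [{"r0": UniformStub(1.0, 3.0)}, {"r0": 2.5}],
    "days": 100,
}
sites = {}
concrete = sample_tree(params, random.Random(0), sites=sites)
```

Here `sites` ends up with the keys `beta` and `strains_0_r0`, while `params` itself is left untouched because every container is rebuilt rather than mutated.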
resolve_deterministic(obj, root_params, _prefix="")
Recursively resolves any DeterministicParameter objects in a data structure, replacing them with their computed values based on root_params.
Returns: A copy of obj with all deterministic parameters resolved.
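A minimal sketch of this resolution step follows. It assumes a hypothetical `DerivedStub` class that names the root parameter it depends on and a transform to apply; the real DeterministicParameter may be structured differently, and `resolve_tree` is an illustrative name.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class DerivedStub:
    """Hypothetical stand-in for a parameter computed from another one."""
    depends_on: str
    transform: Callable[[Any], Any]

    def resolve(self, root_params: dict) -> Any:
        return self.transform(root_params[self.depends_on])


def resolve_tree(obj, root_params):
    """Return a copy of obj with every DerivedStub replaced by its value."""
    if isinstance(obj, DerivedStub):
        return obj.resolve(root_params)
    if isinstance(obj, dict):
        return {k: resolve_tree(v, root_params) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(resolve_tree(v, root_params) for v in obj)
    return obj  # already concrete


params = {
    "r0": 2.0,
    "infectious_period": 5.0,
    "beta": DerivedStub("r0", lambda r0: r0 / 5.0),
    "rates": [DerivedStub("infectious_period", lambda days: 1.0 / days)],
}
resolved = resolve_tree(params, root_params=params)
```

Passing the top-level structure as `root_params` lets a derived value nested anywhere in the tree look up the concrete value it depends on.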
sample_then_resolve(parameters, rng_key=None)
Convenience function that:
- Deep-copies parameters so that parallel chains of inference do not interfere with each other.
- Samples all distributions.
- Resolves all deterministic parameters.
Returns: A fully concrete, JAX-compatible copy of parameters.
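The reason for the deep copy can be shown with a small standard-library sketch. `draw_parameters` is a hypothetical per-chain step, not a DynODE function, and Python's `random` module stands in for JAX random keys.

```python
import copy
import random


def draw_parameters(parameters: dict, seed: int) -> dict:
    """Hypothetical per-chain step: copy first, then fill in the copy."""
    local = copy.deepcopy(parameters)  # nested dicts are duplicated too
    rng = random.Random(seed)
    local["transmission"]["beta"] = rng.uniform(0.1, 0.5)
    return local


shared = {"transmission": {"beta": None, "gamma": 0.2}}
chain_a = draw_parameters(shared, seed=1)
chain_b = draw_parameters(shared, seed=2)
```

Each chain receives its own nested structure, so `shared["transmission"]["beta"]` stays `None` and the value drawn for one chain never leaks into the other. A shallow copy would not give this guarantee, since both copies would still share the inner `transmission` dict.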
checkpoint_compartment_sizes(config, solution, save_final_timesteps=True, compartment_save_dates=[])
Records compartment sizes at specified simulation dates for debugging and analysis.
Parameters:
- config: The SimulationConfig used for the ODE simulation.
- solution: The diffrax.Solution object from ODE integration.
- save_final_timesteps: If True, saves the final value for each compartment.
- compartment_save_dates: List of datetime.date objects to checkpoint.
Behavior:
- Uses numpyro.deterministic to record compartment values at requested dates and/or at the final timestep.
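The date-to-timestep lookup behind this kind of checkpointing can be sketched as follows. This is an illustration under stated assumptions, not DynODE's code: `checkpoint_values` is a hypothetical function, the solution is assumed to be saved once per day starting at `init_date`, a plain dict stands in for the sites that `numpyro.deterministic` would record, and the key naming scheme is invented for the example.

```python
from datetime import date


def checkpoint_values(init_date, series, save_dates, save_final=True):
    """Collect compartment values at requested dates and at the last step.

    series maps a compartment name to a list of daily values, where
    index 0 corresponds to init_date (an assumption of this sketch).
    """
    recorded = {}
    for name, values in series.items():
        for day in save_dates:
            idx = (day - init_date).days  # days elapsed since the start
            if 0 <= idx < len(values):    # skip dates outside the run
                recorded[f"{day.isoformat()}_{name}"] = values[idx]
        if save_final:
            recorded[f"final_timestep_{name}"] = values[-1]
    return recorded


series = {"s": [990, 980, 965, 940], "i": [10, 18, 30, 48]}
recorded = checkpoint_values(
    init_date=date(2022, 2, 11),
    series=series,
    save_dates=[date(2022, 2, 13), date(2023, 1, 1)],
)
```

The first requested date falls two days after the start, so index 2 of each series is recorded; the second date lies beyond the simulated window and is silently skipped in this sketch.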