Bayesian models in numpyro


This page provides guidance for the models in the bayesian-envhealth-models repo. Users should be able to build up models from these building blocks.

Likelihoods

Binomial

latent_rate = ...  # effects making up the logit-transformed rate

with numpyro.plate("N", size=N):
    mu_logit = latent_rate[...]
    numpyro.sample(
        "deaths",
        # no need for a logit transform with numpyro
        # it can cope with inputting the logits directly
        dist.Binomial(total_count=population, logits=mu_logit),
        # this is where the model sees the data - the number of deaths
        obs=deaths,
    )

Poisson

latent_rate = ...  # effects making up the log-transformed rate

with numpyro.plate("N", size=N):
    # offset for population
    # equivalent in other PPLs of:
    # log(mu[i]) <- log(n[i]) + lograte[...]
    mu = jnp.exp(jnp.log(population) + latent_rate[...])
    numpyro.sample(
        "deaths",
        dist.Poisson(rate=mu),
        obs=deaths,
    )

Normal (not appropriate for count data, but suitable for continuous environmental variables like temperature)

mean = ...  # effects making up the untransformed (identity-link) mean
sigma = numpyro.sample("sigma", dist.HalfNormal(5.0))

with numpyro.plate("N", size=N):
    mu = mean[...]
    numpyro.sample("outcome", dist.Normal(loc=mu, scale=sigma), obs=outcome)

Age effects

We want to enforce similarity over adjacent age groups. One way of doing this is with random walk priors $A_a \sim \mathcal{N}(A_{a-1}, \sigma_A^2)$. In numpyro, these are specified as follows.

# hyperparameter for the random walk step size
sigma_age = numpyro.sample("sigma_age", dist.HalfNormal(1.0))

# use a plate to create 1d array `age_drift` of size `N_age - 1`
with numpyro.plate("age_groups", size=(N_age - 1)):
    age_drift = numpyro.sample("age_drift", dist.Normal(0, sigma_age))
    # a random walk is just a cumulative sum of random effects https://observablehq.com/@observablehq/plot-random-walk
    # we pad the array with a leading 0 for identifiability
    # so the array now has length `N_age` (`[0, ...]`)
    # and the effect for the first age group is 0
    age_effect = jnp.pad(jnp.cumsum(age_drift, -1), (1, 0))

Spatial effects

In small-area studies, it is common to smooth data using models with explicit spatial dependence, which are designed to give more weight to nearby areas than those further away. There are three main categories for modelling spatial effects. First, we can treat space as a continuous surface using Gaussian processes or splines. Second, we can use areal models, which make use of the spatial neighbourhood structure of the units. Third, we can build models that exploit a nested hierarchy of geographical units, for example between state, county and census tract in the US. Each of these methods rely on assumptions which may make them more or less appropriate in different applications.

Here, we will focus on areal models, which are the most common in disease mapping studies. The most popular choice of prior is the conditional autoregressive (CAR) prior, also known as a Gaussian Markov random field (GMRF). This forms a joint distribution whose covariance is usually defined in terms of the precision matrix $\mathbf{P} = \pmb{\Sigma}^{-1} = \tau(\mathbf{D} - \rho \mathbf{A}),$ where $\tau$ controls the overall precision of the effects, $\mathbf{A}$ is the spatial adjacency matrix formed by the small areas, $\mathbf{D}$ is a diagonal matrix with entries equal to the number of neighbours for each spatial unit, and the autocorrelation parameter $\rho$ describes the amount of correlation. This can be seen as tuning the degree of spatial dependence, where $\rho = 0$ implies independence between areas and $\rho = 1$ full dependence. The case with $\rho = 1$ is called the intrinsic conditional autoregressive (ICAR) model.

The ICAR prior is specified as

# scale hyperparameter for the spatial effect
spatial_scale = numpyro.sample("spatial_scale", dist.HalfNormal(1.0))

spatial_effect_raw = numpyro.sample(
    "spatial_effect_raw",
    dist.CAR(
        loc=0.0,
        # effectively ICAR – at exactly 1.0 the precision matrix
        # is singular, so the distribution is improper
        correlation=0.99,
        conditional_precision=1.0,
        # `adj` is a matrix of 1s and 0s specifying the neighbourhood adjacency
        adj_matrix=adj,
        is_sparse=True,
    ),
)
spatial_effect = spatial_scale * spatial_effect_raw

Temporal effects

For annual data, we would consider linear slopes for trends ($\beta t$) and random walk priors (as above for age) for nonlinearities; a sketch combining the two is below. Seasonality is more complicated, but there are examples in the numpyro documentation.
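
A minimal sketch of this combination, assuming hyperpriors slope and sigma_time analogous to those in the full example later on:

# linear trend: the slope is added once per time step
time_trend = slope * jnp.arange(N_t)

# random walk for the nonlinear part, as for age above
with numpyro.plate("time", size=(N_t - 1)):
    time_drift = numpyro.sample("time_drift", dist.Normal(0, sigma_time))
    # cumulative sum padded with a leading 0 for identifiability
    time_rw = jnp.pad(jnp.cumsum(time_drift, -1), (1, 0))

time_effect = time_trend + time_rw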

Interactions

Space-time interactions (the same ideas apply to age-time or race-time interactions) could range from fully independent, to each spatial unit having an independent temporal pattern, to inseparable space-time variation where interactions borrow strength across neighbouring spatial units and neighbouring time periods. The most common types are fully independent (type I) and each spatial unit having an independent temporal pattern (type II).

A type I interaction is specified as

with age_plate, space_plate:
    age_space_interaction = numpyro.sample("age_space_interaction", dist.Normal(0, sigma_age_space))

A type II interaction, where the temporal effects are separate random walks for each age group, can be adapted from the random walk implementation above

with age_plate, time_plate:
    # two plates make `age_time_drift` a 2d array with size `(N_age, N_t - 1)`
    age_time_drift = numpyro.sample(
        "age_time_drift", dist.Normal(0, sigma_rw_age_time)
    )
    # we pad the array with a 0 for identifiability
    age_time_effect = jnp.pad(jnp.cumsum(age_time_drift, -1), [(0, 0), (1, 0)])

Categorical and linear covariates

There are two ways to encode categorical variables. Firstly, using random effects. This works the same as the age group variable above, but without the cumulative sum, i.e. just a mean-zero normal

with numpyro.plate("race", size=N_race):
    race_effect = numpyro.sample("race_effect", dist.Normal(0, sigma_race))

Secondly, we can use fixed effects, which I would propose for a small number of categories. These can be incorporated into a model by one-hot encoding the variables and passing in a matrix X. Note that fixed effects are relative, so you need a reference category (e.g. measuring the effect of Black relative to the reference category White). The matrix X can also include other fixed effects, including continuous variables, where the effect is a linear slope. In the model, it will look like

with numpyro.plate("covariates", size=N_covariates):
    beta_covariates = numpyro.sample("beta_covariates", dist.Normal(0, 1)) # prior for each independent covariate effect
covariate_effects = jnp.dot(X, beta_covariates)
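
One way to build X is sketched below, assuming a pandas DataFrame df with a categorical race column and a continuous income column (both names are illustrative):

import pandas as pd

# one-hot encode the categorical variable, dropping the first level
# so that it acts as the reference category (e.g. White)
race_dummies = pd.get_dummies(df["race"], prefix="race", drop_first=True)

# continuous covariates enter as-is and get a linear slope
X = pd.concat([race_dummies, df[["income"]]], axis=1).to_numpy(dtype=float)
N_covariates = X.shape[1]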

An example: age-time interaction model

Here is an example of a full model with:

  • intercept (first term of random walk over age effect)
  • slope over time
  • random walk over age effect
  • age-specific random walk over time (type II interaction)
  • Binomial likelihood

The code below is annotated to explain how each effect is built up.

def model_age_time_interaction(age_id, time_id, population, deaths):
    N = len(population)            # size of the data
    N_age = len(np.unique(age_id)) # number of age groups
    N_t = len(np.unique(time_id))  # number of time steps

    # plates control the `size` of the effect, and replace `for i in 1:N_age` in other PPLs
    # the argument `dim` helps us keep track of shapes and allows for clever broadcasting
    # the `dim=-2` argument means any `numpyro.sample` statement within `age_plate` will
    # have shape `(N_age, 1)` rather than just `(N_age,)` – the extra dimension means
    # all age effects live in the second-rightmost dim and all time effects live
    # in the rightmost (`-1`) dim.
    # There is more information on shapes in the docs https://pyro.ai/examples/tensor_shapes.html
    # and in this post by Eric Ma https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/
    age_plate = numpyro.plate("age_groups", size=N_age, dim=-2)
    time_plate = numpyro.plate("time", size=(N_t - 1), dim=-1)

    # hyperparameters
    slope = numpyro.sample("slope", dist.Normal(loc=0.0, scale=1.0))
    sigma_rw_age = numpyro.sample("sigma_rw_age", dist.HalfNormal(1.0)) # Half-Normal is a good prior for positive sd effects
    sigma_rw_age_time = numpyro.sample("sigma_rw_age_time", dist.HalfNormal(1.0))

    # slope over time is the same as adding slope at each timestep
    slope_cum = slope * jnp.arange(N_t)  # jnp.arange(N_t) is [0, 1, ..., N_t - 1], so this becomes [0, slope, ..., (N_t - 1) * slope]

    # random walk over age
    with age_plate:
        age_drift_scale = jnp.pad(
            jnp.broadcast_to(sigma_rw_age, N_age - 1),
            (1, 0),
            # pad so first term is the intercept, prior N(0, 10)
            # this is a bit fancy really
            # we could also just have an effect of size `N_age - 1`
            # do the cumsum and then pad the first term with zero for identifiability (as above)
            constant_values=10.0,
        )[:, jnp.newaxis] # manually turn this from shape `(N_age,)` to `(N_age, 1)` using `jnp.newaxis`
        # `numpyro.sample` statement within `age_plate` will have shape `(N_age, 1)`
        age_drift = numpyro.sample("age_drift", dist.Normal(0, age_drift_scale))
        # needs to be applied over the age dimension, i.e. `dim=-2`
        age_effect = jnp.cumsum(age_drift, -2)

    # age-time random walk (type II) interaction
    with age_plate, time_plate:
        # random sample of shape `(N_age, N_t - 1)`
        age_time_drift = numpyro.sample(
            "age_time_drift", dist.Normal(0, sigma_rw_age_time)
        )
        # cumulative sum over the time dimension (`dim=-1`) and then pad the time dimension with 0 for identifiability
        # so shape of effect becomes `(N_age, N_t)`
        age_time_effect = jnp.pad(jnp.cumsum(age_time_drift, -1), [(0, 0), (1, 0)])

    # this is where the shape magic happens
    # `slope_cum` has shape `(N_t,)`
    # `age_effect` has shape `(N_age, 1)`
    # `age_time_effect` has shape `(N_age, N_t)`
    # we add these things with different shapes
    # the age effect will be "broadcasted" (repeated) over each time step
    # i.e. the age effect is the same in each time step
    # the same happens for the `slope_cum` over age groups
    # `latent_rate` has shape `(N_age, N_t)`
    latent_rate = slope_cum + age_effect + age_time_effect

    # likelihood
    with numpyro.plate("N", size=N):
        # this line plucks out the right `latent_rate` according to the dataset
        # for example, if the row in the data was age group 3 and time step 17
        # it would pick out `latent_rate[3, 17]`
        mu_logit = latent_rate[age_id, time_id]
        numpyro.sample(
            "deaths",
            # no need for a logit transform with numpyro
            # it can cope with inputting the logits directly
            dist.Binomial(total_count=population, logits=mu_logit),
            # this is where the model sees the data - the number of deaths
            obs=deaths,
        )
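
To fit this model with numpyro's default NUTS sampler (discussed further below), a minimal sketch, assuming the data arrays have already been loaded, looks like:

from jax import random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model_age_time_interaction), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), age_id=age_id, time_id=time_id, population=population, deaths=deaths)
mcmc.print_summary()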

Further reading

Bigger examples

numpyro docs example:

  • Binomial likelihood
  • Global intercept and slope
  • Age group-specific intercepts and slopes (random walk)
  • Two-tier nested hierarchy of random effects over space, intercepts and slopes
  • Temporal non-linear random walk

mortality-statsmodel:

  • Binomial likelihood
  • Global intercept and slope
  • Age group-specific intercepts and slopes (random walk)
  • Spatial intercept and slopes, either three-tier nested hierarchy of random effects or ICAR
  • Age-space IID (type I) interaction
  • Age-time type II interaction

Notes on reparametrisation

The default sampler is NUTS, which is a variant of HMC. NUTS samplers often work better with a non-centred parametrisation, i.e. they prefer $\beta = \mu + \sigma \cdot \mathcal{N}(0, 1)$ to the equivalent $\beta \sim \mathcal{N}(\mu, \sigma^2)$. However, we usually write models (and the maths) using the centred parametrisation.
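
Concretely, the two parametrisations look like this (a minimal sketch; mu, sigma and the site names are illustrative):

# centred: sample beta directly from N(mu, sigma^2)
beta = numpyro.sample("beta", dist.Normal(mu, sigma))

# non-centred: sample a standard normal, then shift and scale
beta_raw = numpyro.sample("beta_raw", dist.Normal(0.0, 1.0))
beta = numpyro.deterministic("beta", mu + sigma * beta_raw)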

numpyro lets us write models in the centred parametrisation, then add a decorator to the model to tell numpyro to evaluate the relevant parameters as non-centred. In the case below, we are telling the model to use the non-centred parametrisations for the age_drift and age_time_drift parameters.

from numpyro.infer.reparam import LocScaleReparam

@numpyro.handlers.reparam(
    config={
        k: LocScaleReparam(0)
        for k in [
            "age_drift",
            "age_time_drift",
        ]
    }
)
def model_age_time_interaction(
    age_id: Int[Array, "data"],
    time_id: Int[Array, "data"],
    population: Int[Array, "data"],
    deaths: Optional[Int[Array, "data"]] = None,
) -> None:

Notes on typing

Typing is another way to make sure users pass the correct type of data to the model, and it helps prevent errors.

Python is a dynamically typed language. Recent versions of Python have allowed for static typing in the form of type hints.

In this way, we can convert the basic function definition

def model_age_time_interaction(age_id, time_id, population, deaths=None):

to the equivalent, but strongly typed

from typing import Optional

from beartype import beartype
from jaxtyping import Array, Int, jaxtyped

@jaxtyped(typechecker=beartype)
def model_age_time_interaction(
    age_id: Int[Array, "data"],
    time_id: Int[Array, "data"],
    population: Int[Array, "data"],
    deaths: Optional[Int[Array, "data"]] = None,
) -> None:

In the first basic model, there is nothing to stop the user passing the arguments age_id="r", population=True, deaths=12.34, which of course makes no sense. Although these are just type hints, using the decorator @jaxtyped(typechecker=beartype) enforces them, and the model throws an error if the user passes something faulty. Read more in the beartype project's documentation.

Further, jaxtyping makes sure the shapes and dtypes of the jax arrays match (here, all four arguments must share the dimension "data").
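
For example, since all four arguments are annotated with the same dimension name "data", a call with inconsistent lengths fails the type check before the model body runs (a sketch):

import jax.numpy as jnp

# `age_id` and `time_id` have 10 elements but `population` only 5;
# the shared "data" dimension no longer matches, so an error is raised
model_age_time_interaction(
    age_id=jnp.zeros(10, dtype=jnp.int32),
    time_id=jnp.zeros(10, dtype=jnp.int32),
    population=jnp.zeros(5, dtype=jnp.int32),
)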

Check convergence

Running a model is great, but you need to check that the inference has converged. arviz has inbuilt methods for this. The most important check is that the r_hat values in the summary are all below 1.01 (although 1.05 is also a reasonable threshold).

import arviz as az
import xarray as xr

ds = xr.open_dataset("../../output/model_age_time_interaction_samples.nc")
az.summary(ds)
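
To flag problem parameters programmatically, a sketch assuming the default az.summary output with its r_hat column:

summary = az.summary(ds)
# list any parameters whose r_hat exceeds the threshold
not_converged = summary[summary["r_hat"] > 1.01]
print(not_converged if len(not_converged) else "all r_hat values below 1.01")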

Preprocessing the data

Although csv files are human-readable, when the dataset gets large, the most efficient way of holding data is in binary files. Also, when dealing with multiple causes of death, the population column often contains repeated data, and we might only want to hold one version of it.

Below is a code snippet to convert the csv to the required npy binary files for loading into the modelling framework.

import pandas as pd
import numpy as np

# Read the CSV file into a pandas DataFrame
df = pd.read_csv('simulated_deaths.csv')

# convert columns into indicator variables for the hierarchical model
df = df.assign(
    year_id=lambda x: x.year.astype("category").cat.codes,
    age_id=lambda x: x.age.astype("category").cat.codes,
)

# Save each column as an array using np.save()
for column in df.columns:
    np.save(f'{column}.npy', df[column].values)
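
The arrays can then be loaded back with np.load when fitting the model. A sketch, with filenames following from the column names above and assuming the csv contained population and deaths columns:

age_id = np.load("age_id.npy")
year_id = np.load("year_id.npy")
population = np.load("population.npy")
deaths = np.load("deaths.npy")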