The PreModel and PostModel decorators provide mechanisms for customizing MCMC inference workflows in Catalax. They insert custom transformations before and after model simulation, handling scenarios such as uncertain initial conditions, parameter transformations, and algebraic observables. The decorators integrate with NumPyro’s probabilistic programming framework while keeping the interface simple through a context-mutation approach.

Understanding the Decorator-Based Architecture

The Transformation Pipeline

MCMC inference in Catalax follows a structured pipeline where decorated functions insert custom transformations (sketched in the toy model after this list):
  1. Parameter sampling: NumPyro samples parameters from their prior distributions
  2. PreModel transformation: Custom preprocessing via @pre_model decorated functions
  3. Model simulation: Standard ODE integration using transformed inputs
  4. PostModel transformation: Custom postprocessing via @post_model decorated functions
  5. Likelihood evaluation: Comparison of transformed outputs with experimental data
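
To make the ordering concrete, the following is a minimal, self-contained NumPyro model for a single exponentially decaying species that walks through the same five stages by hand. It is a toy sketch, not Catalax’s internal implementation, and every name in it is illustrative:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def toy_inference_model(times, y_obs, y0_measured):
    # 1. Parameter sampling: draw the decay rate from its prior
    k = numpyro.sample("k", dist.LogNormal(0.0, 1.0))

    # 2. Pre-model step: treat the measured initial value as uncertain
    y0 = numpyro.sample("y0", dist.Normal(y0_measured, 0.1))

    # 3. Model simulation: closed-form solution of dy/dt = -k * y
    states = y0 * jnp.exp(-k * times)

    # 4. Post-model step: clip states to keep the observable positive
    observable = numpyro.deterministic("observable", jnp.maximum(states, 1e-6))

    # 5. Likelihood evaluation against the experimental data
    numpyro.sample("y", dist.Normal(observable, 0.05), obs=y_obs)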

Decorator Pattern and Context Mutation

The @pre_model and @post_model decorators convert plain user functions into protocol-compliant transformations. Instead of requiring complex return-value management, the decorators provide mutable context objects that can be modified directly:
  • Simple decorator syntax: @pre_model and @post_model decorators handle protocol conversion
  • Context mutation: Direct modification of ctx.y0s, ctx.theta, ctx.states attributes
  • Type safety: Full IDE support with proper type inference for context attributes
  • NumPyro integration: Seamless use of numpyro.sample() and numpyro.deterministic()
Note: The underlying PreModel and PostModel protocols are used internally by Catalax and are not intended for direct user implementation.

PreModel Decorator: Input and Parameter Transformations

The @pre_model decorator registers custom transformations that run after parameter sampling but before model simulation. Common use cases include parameter space transformations, initial condition inference, and experimental condition modeling.

Basic Usage and Context Mutation

from catalax.mcmc.protocols import pre_model, PreModelContext
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp

@pre_model
def estimate_uncertain_initials(ctx: PreModelContext):
    """Estimate true initial conditions when measurements are uncertain."""
    
    # Sample measurement uncertainty for initial conditions
    y0_uncertainty = numpyro.sample("y0_uncertainty", dist.HalfNormal(0.1))
    
    # Sample true initial conditions around measured values
    true_y0s = numpyro.sample(
        "true_initial_conditions",
        dist.Normal(ctx.y0s, y0_uncertainty * ctx.y0s)
    )
    
    # Update context with inferred initial conditions
    ctx.y0s = numpyro.deterministic(
        "positive_initial_conditions",
        jnp.maximum(true_y0s, 1e-6)
    )
    # No return statement - context is mutated in place

This example demonstrates the core pattern: the decorated function receives a PreModelContext with mutable attributes (ctx.y0s, ctx.theta, ctx.constants, etc.) that can be modified directly. The ctx.shapes attribute provides dimension information for proper broadcasting operations.
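
The same mechanism covers the parameter space transformations mentioned above. The following sketch (reusing the imports from the previous example; the sample-site names are arbitrary) applies a global multiplicative correction to the sampled parameters in log space:

@pre_model
def log_scale_parameters(ctx: PreModelContext):
    """Rescale sampled parameters with a log-space correction factor."""

    # Sample the correction in log space so the resulting scale is positive
    log_scale = numpyro.sample("log_theta_scale", dist.Normal(0.0, 0.5))

    # Track the transformed parameters and write them back to the context
    ctx.theta = numpyro.deterministic(
        "scaled_theta",
        ctx.theta * jnp.exp(log_scale),
    )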

Shape Management for Broadcasting

@pre_model
def measurement_specific_parameters(ctx: PreModelContext):
    """Apply measurement-specific parameter modifications using shape information."""
    
    # Access shape information for proper broadcasting
    n_measurements, n_species = ctx.shapes.y0s
    # The trailing axis of theta gives the parameter count
    n_parameters = ctx.theta.shape[-1]
    
    # Sample measurement-specific modifiers
    with numpyro.plate("measurements", n_measurements):
        modifiers = numpyro.sample("measurement_modifiers", dist.Normal(1.0, 0.1))
    
    # Apply modifiers to parameters with proper broadcasting
    ctx.theta = ctx.theta * modifiers[:, None]  # Broadcast over parameter dimension

The ctx.shapes object exposes the dimensions of y0s, data, constants, and times, enabling correct array operations and broadcasting across measurements.

PostModel Decorator: Output and Observable Transformations

The @post_model decorator enables custom transformations applied after model simulation but before likelihood evaluation. This is essential for converting model states to experimentally measurable quantities when observables don’t directly correspond to individual model species.

Basic Observable Construction

from catalax.mcmc.protocols import post_model, PostModelContext

@post_model
def total_protein_observable(ctx: PostModelContext):
    """Convert individual protein states to total measurable protein concentration."""
    
    # Access simulated states: [time_points, species]
    free_protein = ctx.states[:, 0]
    bound_protein = ctx.states[:, 1]
    
    # Create observable: total protein concentration
    total_protein = numpyro.deterministic(
        "total_protein_concentration",
        free_protein + bound_protein
    )
    
    # Update context with observable (reshape to maintain dimensions)
    ctx.states = total_protein[:, None]  # Shape: [time_points, 1]

This example shows the core pattern: the decorated function receives a PostModelContext with ctx.states containing simulation results, and can modify it to match experimental observables.
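
Observables that involve ratios also need a safeguard against division by zero, one of the best practices listed later. A sketch of a bound-fraction observable (the names are illustrative) looks like this:

@post_model
def bound_fraction_observable(ctx: PostModelContext):
    """Report the fraction of total protein that is in the bound state."""

    free_protein = ctx.states[:, 0]
    bound_protein = ctx.states[:, 1]

    # Floor the denominator to guard against division by zero
    fraction = numpyro.deterministic(
        "bound_fraction",
        bound_protein / jnp.maximum(free_protein + bound_protein, 1e-9),
    )

    ctx.states = fraction[:, None]  # Shape: [time_points, 1]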

Integration with MCMC Workflows

Using Decorators in MCMC Inference

The decorated functions integrate seamlessly with standard MCMC workflows:

# Define transformations using decorators
@pre_model
def handle_uncertain_initials(ctx: PreModelContext):
    uncertainty = numpyro.sample("init_uncertainty", dist.HalfNormal(0.1))
    true_initials = numpyro.sample("true_initials", dist.Normal(ctx.y0s, uncertainty))
    ctx.y0s = jnp.maximum(true_initials, 1e-6)

@post_model
def total_concentration_observable(ctx: PostModelContext):
    total_conc = numpyro.deterministic("total_conc", jnp.sum(ctx.states, axis=1))
    ctx.states = total_conc[:, None]

# Use in MCMC inference
import catalax.mcmc as cmc

hmc = cmc.HMC(num_warmup=1000, num_samples=2000)
results = hmc.run(
    model=model,
    dataset=dataset,
    yerrs=0.05,
    pre_model=handle_uncertain_initials,
    post_model=total_concentration_observable
)

Built-in PreModel Functions

Catalax provides pre-built transformation functions for common scenarios:

from catalax.mcmc.models import estimate_initials
import numpyro.distributions as dist

# Use built-in initial condition estimator
pre_model_func = estimate_initials(y0_sigma_dist=dist.HalfNormal(5.0))

results = hmc.run(
    model=model,
    dataset=dataset,
    yerrs=0.1,
    pre_model=pre_model_func
)

Key Features and Capabilities

Shape Information Access

The ctx.shapes object provides essential dimension information for proper array operations (see the sketch after this list):
  • ctx.shapes.y0s: Initial conditions dimensions (n_measurements, n_species)
  • ctx.shapes.data: Observed data dimensions (n_measurements, n_timepoints, n_observables)
  • ctx.shapes.constants: Constants dimensions (n_measurements, n_constants)
  • ctx.shapes.times: Time points dimensions (n_measurements, n_timepoints)
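
For example, a pre-model hook can unpack these tuples to size its sample sites. The following sketch (reusing the imports shown earlier; the site names are illustrative) draws one scale factor per measurement and broadcasts it over the species dimension:

@pre_model
def per_measurement_scaling(ctx: PreModelContext):
    """Scale initial conditions with one factor per measurement."""

    n_measurements, n_species = ctx.shapes.y0s

    # One independent, positive scale factor per measurement
    with numpyro.plate("measurements", n_measurements):
        scale = numpyro.sample("y0_scale", dist.LogNormal(0.0, 0.1))

    # Broadcast the per-measurement scale over the species dimension
    ctx.y0s = ctx.y0s * scale[:, None]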

NumPyro Compatibility

The decorators are fully compatible with NumPyro’s probabilistic programming primitives:
  • Use numpyro.sample() to introduce new random variables
  • Use numpyro.deterministic() to track transformations for model interpretation
  • Use numpyro.plate() for vectorized operations across measurements or species
  • All JAX operations maintain automatic differentiation compatibility

Best Practices

  1. Context mutation: Always modify context attributes (ctx.y0s, ctx.theta, ctx.states) directly rather than returning values
  2. Numerical stability: Include safeguards against division by zero and negative concentrations
  3. Shape consistency: Use ctx.shapes information to ensure proper broadcasting
  4. Meaningful names: Use descriptive names for numpyro.deterministic() variables
  5. JAX operations: Use jax.numpy operations (rather than NumPy) so transformations remain compatible with automatic differentiation

The @pre_model and @post_model decorators provide a flexible framework for handling complex experimental scenarios in Bayesian inference while maintaining mathematical rigor and full compatibility with NumPyro’s probabilistic programming capabilities.