RateFlowODE combines neural networks with stoichiometric principles to discover reaction mechanisms from experimental time-series data. Unlike traditional Neural ODEs that learn species dynamics directly, RateFlowODE learns individual reaction rates and combines them through a stoichiometric matrix. This makes it possible to discover biochemical reaction networks when the species are known but the reactions linking them are not.

Understanding the RateFlowODE Architecture

Stoichiometric Decomposition of Biochemical Systems

The fundamental innovation of RateFlowODE lies in its decomposition of species dynamics into individual reaction contributions. Rather than learning the net rate of change for each species directly, the approach models the system as:

$$\frac{dy}{dt} = S \cdot r(y, t, \theta)$$

where:
  • $y$ represents species concentrations
  • $S$ is the stoichiometric matrix ($n_{species} \times n_{reactions}$)
  • $r(y, t, \theta)$ are the individual reaction rates predicted by the neural network with parameters $\theta$
This structure builds biochemical realism into the model: every change in concentration must arise from a small number of reactions acting through stoichiometric coefficients, so the learned $S$ and $r$ can be read directly as a candidate reaction network.
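As a minimal illustration of $dy/dt = S \cdot r$, the JAX snippet below uses a hypothetical two-step chain A → B → C with hand-picked rates standing in for the network's output. Each species' net rate is the sum of the reaction rates weighted by its row of $S$.

```python
import jax.numpy as jnp

# Hypothetical chain A -> B -> C: rows are species (A, B, C), columns are reactions
S = jnp.array([
    [-1.0,  0.0],  # A: consumed by reaction 1
    [ 1.0, -1.0],  # B: produced by reaction 1, consumed by reaction 2
    [ 0.0,  1.0],  # C: produced by reaction 2
])

# Illustrative rates at one instant; in RateFlowODE these would come from the network
r = jnp.array([0.5, 0.2])

# Net rate of change per species: dy/dt = S · r
dydt = S @ r
print(dydt)  # A falls at 0.5, B rises at 0.5 - 0.2, C rises at 0.2
```

Because the rates enter only through $S$, one reaction can affect several species at once, which a per-species Neural ODE cannot express explicitly.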

Neural Network Rate Prediction

The neural network component predicts reaction rates from the current species concentrations and time:

$$r(y, t, \theta) = \text{ReLU}(\text{MLP}(y, t, \theta))$$

The ReLU activation keeps every reaction rate non-negative, consistent with the physical interpretation of a reaction rate as the speed at which a reaction proceeds in its written direction. The MLP (Multi-Layer Perceptron) learns the concentration dependencies that govern the reaction kinetics.
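A bare-bones version of such a rate network, written in plain JAX rather than with catalax's own classes, could look like the sketch below. The layer sizes, the initialisation, and the split between a softplus hidden activation and a ReLU output layer are assumptions for illustration; the ReLU on the last layer is what guarantees non-negative rates.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialise (weight, bias) pairs for an MLP with the given layer sizes (illustrative)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def reaction_rates(params, y, t):
    """Map concentrations y and time t to non-negative reaction rates."""
    h = jnp.concatenate([y, jnp.array([t])])  # network input is (y, t)
    for w, b in params[:-1]:
        h = jax.nn.softplus(w @ h + b)        # smooth hidden activation (assumed)
    w, b = params[-1]
    return jax.nn.relu(w @ h + b)             # ReLU output keeps every rate >= 0


key = jax.random.PRNGKey(0)
n_species, n_reactions = 4, 3
params = init_mlp(key, [n_species + 1, 32, 32, n_reactions])
rates = reaction_rates(params, jnp.array([1.0, 0.5, 0.0, 0.0]), 0.0)
print(rates.shape)  # (3,): one rate per reaction
```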

Stoichiometric Matrix Learning

The stoichiometric matrix $S$ can be handled in three distinct modes:
  1. Fully learnable: SS is initialized randomly and optimized during training
  2. Fixed structure: SS is provided based on known reaction mechanisms
  3. Constrained learning: SS starts from a known structure but can be refined
This flexibility enables applications ranging from complete reaction discovery to refinement of existing mechanistic models.
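One way to picture the three modes is as different choices of which matrix enters the forward pass and whether gradients reach it. The helper `effective_stoich` and the "known matrix plus learned correction" reading of constrained learning are hypothetical, not catalax internals; the sketch only shows that with a fixed $S$ no gradient flows into the stoichiometry, while in the other two modes it does.

```python
import jax
import jax.numpy as jnp

# Known structure for a hypothetical chain A -> B -> C (rows: A, B, C; columns: reactions)
S_known = jnp.array([[-1.0, 0.0], [1.0, -1.0], [0.0, 1.0]])
r = jnp.array([0.5, 0.2])               # illustrative reaction rates
target = jnp.array([-0.4, 0.3, 0.1])    # illustrative observed dy/dt


def effective_stoich(S_param, mode):
    """Hypothetical helper: the stoichiometric matrix used in the forward pass."""
    if mode == "learnable":
        return S_param                          # every entry is trained
    if mode == "fixed":
        return jax.lax.stop_gradient(S_known)   # known S, never updated
    if mode == "constrained":
        return S_known + S_param                # known S plus a learned correction
    raise ValueError(f"unknown mode: {mode}")


def loss(S_param, mode):
    S = effective_stoich(S_param, mode)
    return jnp.sum((S @ r - target) ** 2)


# Zeros keep the example deterministic; a fully learnable S would normally start random
S_init = jnp.zeros_like(S_known)
grad_norms = {
    mode: float(jnp.abs(jax.grad(lambda s: loss(s, mode))(S_init)).sum())
    for mode in ("learnable", "fixed", "constrained")
}
print(grad_norms)  # "fixed" is 0.0: no gradient reaches the stoichiometry
```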

Core Implementation and Usage

Basic RateFlowODE Construction

Create a RateFlowODE instance for reaction network discovery:
import jax
import catalax as ctx
import catalax.neural as ctn
import jax.random as jrandom

# Define system structure (species must be known)
species_order = ["S", "E", "ES", "P"]  # Substrate, Enzyme, Complex, Product
observable_indices = [0, 3]  # Only S and P are measurable

# Create RateFlowODE with learnable stoichiometry
key = jrandom.PRNGKey(42)
rateflow_ode = ctn.RateFlowODE(
    data_size=len(species_order),      # Number of species
    reaction_size=3,                   # Number of reactions to discover
    width_size=64,                     # Neural network width
    depth=3,                           # Neural network depth
    species_order=species_order,
    observable_indices=observable_indices,
    learn_stoich=True,                 # Enable stoichiometry learning
    activation=jax.nn.softplus,        # Smooth activation for rates
    key=key
)
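Because only S and P are measured, a fit can only compare those columns of the predicted trajectory with data; the unobserved species still shape the dynamics through $S \cdot r$. The sketch below shows this masking with made-up numbers. The arrays and the mean-squared-error loss are illustrative assumptions, not catalax's training objective, and indices are looked up by name so the example does not depend on a particular species order.

```python
import jax.numpy as jnp

# Illustrative species order; observable indices are derived from names
species_order = ["S", "E", "ES", "P"]
observable_indices = jnp.array([species_order.index(name) for name in ("S", "P")])

# Hypothetical model prediction for all species at two time points: (time, species)
y_pred = jnp.array([
    [1.0, 0.5, 0.0, 0.0],
    [0.8, 0.4, 0.1, 0.1],
])
# Hypothetical measurements, available only for S and P: (time, observed species)
y_meas = jnp.array([
    [1.0, 0.0],
    [0.7, 0.2],
])

# E and ES influence the predicted trajectory, but only S and P enter the loss
residual = y_pred[:, observable_indices] - y_meas
mse = jnp.mean(residual ** 2)
print(mse)
```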

Supplied Stoichiometric Matrix

When the reaction mechanism is known but the rate laws are not, supply the stoichiometric structure and let the network learn only the rates:
import jax.numpy as jnp

# Define known reaction mechanism for Michaelis-Menten kinetics
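# Reactions: E + S → ES (binding), ES → E + P (catalysis)
# Unbinding (ES → E + S) is omitted: with non-negative rates it would need its own column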
stoich_matrix = jnp.array([
    [-1,  0],  # S: consumed in reaction 1, not involved in reaction 2
    [-1,  1],  # E: consumed in reaction 1, produced in reaction 2  
    [ 1, -1],  # ES: produced in reaction 1, consumed in reaction 2
    [ 0,  1]   # P: not involved in reaction 1, produced in reaction 2
])

# Create RateFlowODE with fixed stoichiometry
rateflow_ode_fixed = ctn.RateFlowODE(
    data_size=4,
    reaction_size=2,
    width_size=32,
    depth=2,
    species_order=species_order,
    observable_indices=observable_indices,
    learn_stoich=False,               # Fix stoichiometry
    stoich_matrix=stoich_matrix,      # Provide known structure
    key=key
)
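Each column of the stoichiometric matrix encodes one reaction: negative entries are reactants and positive entries are products. The hypothetical helper `describe_reactions` below turns columns back into readable equations, which is also a convenient way to inspect a learned matrix after rounding it (for example with `jnp.round`).

```python
import jax.numpy as jnp


def describe_reactions(stoich, species, tol=1e-6):
    """Hypothetical helper: render each column of a stoichiometric matrix as a reaction."""
    reactions = []
    for j in range(stoich.shape[1]):
        lhs, rhs = [], []
        for name, coeff in zip(species, stoich[:, j].tolist()):
            if abs(coeff) < tol:
                continue                      # species not involved in this reaction
            n = abs(coeff)
            term = name if abs(n - 1) < tol else f"{n:g} {name}"
            (lhs if coeff < 0 else rhs).append(term)
        reactions.append(" + ".join(lhs) + " -> " + " + ".join(rhs))
    return reactions


species = ["S", "E", "ES", "P"]  # row order of the matrix
stoich_matrix = jnp.array([
    [-1,  0],
    [-1,  1],
    [ 1, -1],
    [ 0,  1],
])
for rxn in describe_reactions(stoich_matrix, species):
    print(rxn)  # "S + E -> ES", then "ES -> E + P"
```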

Mass Conservation Constraints

Mass conservation is a fundamental constraint in biochemical systems: the total amount of certain molecular species remains constant throughout the reaction. This matters particularly for enzyme systems, where the total enzyme concentration must remain unchanged, and for metabolic pathways, where specific atomic groups are conserved. The mass constraint is written as:

$$\mathbf{M} \cdot \mathbf{y}(t) = \mathbf{c}$$

where $\mathbf{M}$ is the mass constraint matrix, $\mathbf{y}(t)$ is the vector of species concentrations, and $\mathbf{c}$ is the vector of conserved quantities. For enzyme conservation in Michaelis-Menten kinetics, this ensures that $E(t) + ES(t) = E_{total}$ at all times during the reaction.
# Define conservation constraints
# Example: Total enzyme conservation (E + ES = constant)
mass_constraint = jnp.array([
    [0, 1, 1, 0]  # E + ES conservation constraint
])

# Create RateFlowODE with conservation constraints
rateflow_ode_conserved = ctn.RateFlowODE(
    data_size=4,
    reaction_size=3,
    width_size=64,
    depth=3,
    species_order=species_order,
    observable_indices=observable_indices,
    learn_stoich=True,
    mass_constraint=mass_constraint,  # Enforce conservation
    key=key
)
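Conservation follows directly from the model structure: since $\frac{d}{dt}(\mathbf{M} \mathbf{y}) = \mathbf{M} S \, r$, the quantity $\mathbf{M}\mathbf{y}$ stays constant for every choice of rates exactly when $\mathbf{M} S = 0$. The sketch checks this for the Michaelis-Menten stoichiometry and then shows one generic way to impose it on an arbitrary matrix, an orthogonal projection onto the null space of $\mathbf{M}$. How catalax itself enforces `mass_constraint` is not described here, so treat the projection as an assumption rather than the library's method.

```python
import jax
import jax.numpy as jnp

# Rows: S, E, ES, P; columns: E + S -> ES, ES -> E + P
stoich_matrix = jnp.array([
    [-1.0,  0.0],
    [-1.0,  1.0],
    [ 1.0, -1.0],
    [ 0.0,  1.0],
])
mass_constraint = jnp.array([[0.0, 1.0, 1.0, 0.0]])  # E + ES

# d(M·y)/dt = M·S·r, so M·y is conserved for all rates iff M·S = 0
print(mass_constraint @ stoich_matrix)  # all zeros: this mechanism conserves E + ES


def project_onto_conserving(S_raw, M):
    """Hypothetical enforcement: project the columns of S onto the null space of M."""
    P = jnp.eye(M.shape[1]) - M.T @ jnp.linalg.solve(M @ M.T, M)
    return P @ S_raw


# A random, unconstrained 4 x 3 matrix, standing in for an untrained learnable S
S_raw = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
S_proj = project_onto_conserving(S_raw, mass_constraint)
print(mass_constraint @ S_proj)  # zero up to floating-point error
```

The projection generally produces fractional coefficients, so a projected matrix still needs rounding or inspection before it reads as a clean mechanism.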

Analysis and Visualization

Learned Reaction Visualization

Analyze the discovered reaction network using built-in visualization tools:
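# Assumes `trained_rateflow` is a trained RateFlowODE, `experimental_data` holds the
# measured time series, and `original_model` is the catalax model defining the species
# Visualize learned reactions and stoichiometry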
fig = trained_rateflow.plot_learned_rates(
    dataset=experimental_data,
    model=original_model,  # For species labels
    show=True,
    save_path="./figures/discovered_reactions.png",
    round_stoich=True      # Round coefficients for clarity
)
The visualization provides three panels:
  • Stoichiometric matrix heatmap: Shows discovered reaction coefficients
  • Reaction rates over time: Displays individual reaction dynamics
  • Model fit comparison: Compares predictions with experimental data
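The reaction-rate panel amounts to evaluating the rate function at each point of a trajectory. A vectorised sketch with `jax.vmap` follows; the mass-action `rates` function and the trajectory are hypothetical stand-ins for a trained network and a solved ODE.

```python
import jax
import jax.numpy as jnp


def rates(y, t):
    """Stand-in for a trained rate network: E + S -> ES and ES -> E + P (illustrative)."""
    S, E, ES, P = y
    return jnp.array([2.0 * S * E, 0.5 * ES])


# Hypothetical trajectory: 5 time points, columns S, E, ES, P
ts = jnp.linspace(0.0, 1.0, 5)
ys = jnp.stack([jnp.array([1.0 - 0.1 * k, 0.5, 0.05 * k, 0.05 * k]) for k in range(5)])

# Evaluate every reaction rate at every time point in one vectorised call
rate_traj = jax.vmap(rates)(ys, ts)  # shape (time points, reactions)
print(rate_traj[:, 0])  # time course of reaction 1
```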

Rate Dependency Analysis

Examine how reaction rates depend on species concentrations:
# Create phase plots showing rate dependencies
rate_grid_fig = trained_rateflow.plot_rate_grid(
    dataset=experimental_data,
    model=original_model,
    rate_indices=[0, 1, 2],        # Analyze first three reactions
    species_pairs=[("S", "E"), ("ES", "P")],  # Focus on key species pairs
    representative_time=0.0,        # Time point for analysis
    grid_resolution=50,            # Resolution of concentration grid
    figsize_per_subplot=(6, 5),
    range_extension=0.3,           # Extend beyond data range
    show=True
)
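Conceptually, a rate grid varies two species over a range spanning (and, with `range_extension`, slightly exceeding) their observed values, holds the remaining species at reference values, and evaluates one reaction rate at every grid point. The sketch below is an illustrative reimplementation of that idea, not `plot_rate_grid` itself: the reference state, the clipping of the grid at zero concentration, and the stand-in `rates` function are all assumptions.

```python
import jax
import jax.numpy as jnp


def rates(y, t):
    """Stand-in for a trained rate network (illustrative mass-action rates)."""
    S, E, ES, P = y
    return jnp.array([2.0 * S * E, 0.5 * ES])


def rate_grid(rate_fn, ys, y_ref, t, i, j, rate_index, resolution=50, range_extension=0.3):
    """Hypothetical: one rate over a grid of species i and j, other species fixed at y_ref."""
    def axis(k):
        lo, hi = ys[:, k].min(), ys[:, k].max()
        pad = range_extension * (hi - lo)  # extend beyond the observed range
        # clip at zero so the grid never contains negative concentrations (assumed)
        return jnp.linspace(jnp.maximum(lo - pad, 0.0), hi + pad, resolution)

    xi, xj = axis(i), axis(j)
    gi, gj = jnp.meshgrid(xi, xj, indexing="ij")
    points = jnp.broadcast_to(y_ref, gi.shape + y_ref.shape)
    points = points.at[..., i].set(gi).at[..., j].set(gj)
    flat = points.reshape(-1, y_ref.shape[0])
    values = jax.vmap(lambda y: rate_fn(y, t)[rate_index])(flat)
    return xi, xj, values.reshape(gi.shape)


# Hypothetical observed states (rows: time points; columns: S, E, ES, P)
ys = jnp.array([
    [1.0, 0.5, 0.0, 0.0],
    [0.6, 0.3, 0.2, 0.2],
    [0.3, 0.4, 0.1, 0.5],
])
xs, xe, grid = rate_grid(rates, ys, y_ref=ys[0], t=0.0, i=0, j=1, rate_index=0)
print(grid.shape)  # (50, 50): reaction 1 rate over the (S, E) plane
```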