The penalty system in Catalax provides regularization mechanisms for enforcing biological constraints and improving neural ODE training. Rather than relying solely on data fitting, penalties incorporate biochemical knowledge, conservation laws, and structural principles into the learning process. This helps keep discovered models biologically plausible while they continue to fit the observed data well.
Neural networks excel at pattern recognition but can learn solutions that violate fundamental biochemical principles. The penalty framework addresses this by adding constraint terms to the training objective:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{data}} + \sum_i \alpha_i \cdot P_i(\text{model})$$

where:

- $\mathcal{L}_{\text{data}}$ is the standard data-fitting loss
- $P_i$ are the individual penalty functions
- $\alpha_i$ are the penalty strength coefficients
Each coefficient $\alpha_i$ sets how strongly its constraint is weighed against the data fit, so domain knowledge and data-driven learning are combined in a single objective.
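As a quick worked example of the weighted sum, the sketch below evaluates the objective for made-up numbers. The function name `total_loss` and all values are hypothetical and chosen only to show how each $\alpha_i$ scales its penalty; they are not Catalax outputs.

```python
def total_loss(data_loss, penalty_values, alphas):
    """Hypothetical helper: L_total = L_data + sum_i alpha_i * P_i."""
    if len(penalty_values) != len(alphas):
        raise ValueError("each penalty needs exactly one alpha")
    # Weighted penalty terms are added on top of the data-fitting loss
    return data_loss + sum(a * p for a, p in zip(alphas, penalty_values))


# Illustrative values: data loss 0.50, a conservation penalty of 2.0 with
# alpha = 0.1, and a sparsity penalty of 4.0 with alpha = 0.01
loss = total_loss(0.50, [2.0, 4.0], [0.1, 0.01])
print(round(loss, 6))  # 0.74
```

Here the conservation term contributes 0.2 and the sparsity term 0.04, so raising an $\alpha_i$ directly increases how much that constraint can pull training away from the pure data fit.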
The penalty system is designed around two core components:

- **Individual Penalty Functions**: each penalty targets a specific biological or mathematical constraint (mass conservation, sparsity, smoothness).
- **Penalty Collections**: the `Penalties` class manages multiple penalty functions, enabling complex constraint combinations and adaptive penalty scheduling.
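To make the three constraint types named above concrete, here is a framework-free sketch of what such penalty functions compute. These are hypothetical illustrations operating on plain Python lists, not Catalax's penalty functions (which operate on a neural ODE model); in particular, `mass_imbalance` is only one possible way to measure violations of a conserved total.

```python
# Hypothetical penalty functions on plain lists, for illustration only.

def l1_sparsity(weights):
    """Sparsity: sum of absolute weights, which pushes small weights toward zero."""
    return sum(abs(w) for w in weights)


def smoothness(trajectory):
    """Smoothness: sum of squared differences between consecutive states."""
    return sum((b - a) ** 2 for a, b in zip(trajectory, trajectory[1:]))


def mass_imbalance(totals):
    """Conservation: squared drift of a conserved total from its initial value."""
    return sum((t - totals[0]) ** 2 for t in totals)


print(l1_sparsity([0.5, -0.25, 0.0]))   # 0.75
print(smoothness([1.0, 2.0, 4.0]))      # 5.0
print(mass_imbalance([3.0, 3.0, 2.5]))  # 0.25
```

Each function returns zero when its constraint is perfectly satisfied and grows as the violation grows, which is the property that lets it be added to the loss with a coefficient $\alpha_i$.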
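```python
from catalax.neural.penalties import Penalties, Penalty

# Create individual penalty
# (penalize_non_conservative must already be defined or imported;
#  its import is not shown in this example)
mass_penalty = Penalty(
    name="mass_conservation",
    fun=penalize_non_conservative,
    alpha=0.1,
)

# Create penalty collection
penalties = Penalties([mass_penalty])

# Apply to model during training (neural_model is your neural ODE model)
penalty_value = penalties(neural_model)
```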
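Conceptually, calling a penalty collection on a model reduces to a weighted sum of its members. The sketch below is a hypothetical stand-in (`SketchPenalty`, `SketchPenalties`, and the dict-based toy model are invented for illustration) and is not Catalax's implementation; in particular, it assumes each penalty applies its own `alpha` when called, which may differ from how Catalax's `Penalty` is structured internally.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class SketchPenalty:
    """Hypothetical named, weighted penalty (not Catalax's Penalty class)."""
    name: str
    fun: Callable[[Any], float]
    alpha: float

    def __call__(self, model: Any) -> float:
        # Assumption for this sketch: the call returns alpha * P(model)
        return self.alpha * self.fun(model)


@dataclass
class SketchPenalties:
    """Hypothetical collection: total penalty is the sum of weighted terms."""
    penalties: List[SketchPenalty]

    def __call__(self, model: Any) -> float:
        return sum(p(model) for p in self.penalties)


def sparsity(model):
    # Sum of absolute weights in the toy model
    return sum(abs(w) for w in model["weights"])


def roughness(model):
    # Sum of squared differences between consecutive trajectory points
    traj = model["trajectory"]
    return sum((b - a) ** 2 for a, b in zip(traj, traj[1:]))


# Toy "model" used purely to exercise the sketch
toy_model = {"weights": [0.5, -0.25], "trajectory": [1.0, 2.0, 4.0]}
collection = SketchPenalties([
    SketchPenalty("sparsity", sparsity, alpha=0.1),
    SketchPenalty("smoothness", roughness, alpha=0.01),
])
print(collection(toy_model))
```

Keeping each term named and separately weighted is what makes per-penalty diagnostics and later changes to individual coefficients straightforward.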
Track individual penalty contributions during training:
```python
def monitor_penalties(model, penalties):
    """Monitor individual penalty contributions for training diagnostics."""
    penalty_values = {}
    total_penalty = 0.0

    for penalty in penalties.penalties:
        value = penalty(model)
        penalty_values[penalty.name] = float(value)
        total_penalty += value

    penalty_values["total"] = float(total_penalty)
    return penalty_values


# Use during training monitoring. This block belongs inside the training
# loop, where `step` and `current_model` are defined.
if step % 100 == 0:  # Every 100 steps
    penalty_breakdown = monitor_penalties(current_model, penalties)
    print("Penalty contributions:")
    for name, value in penalty_breakdown.items():
        print(f"  {name}: {value:.6f}")
```
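Printing snapshots shows the current state; keeping a history and each penalty's share of the total makes it easier to see whether one constraint dominates and its `alpha` needs adjusting. The helpers `record_breakdown` and `penalty_shares` below are hypothetical, work on dictionaries in the `{name: value, "total": value}` format that `monitor_penalties` returns, and assume non-negative penalty values. The breakdown values are illustrative only.

```python
from collections import defaultdict


def record_breakdown(history, step, breakdown):
    """Append one breakdown (name -> value) to a per-name list of (step, value)."""
    for name, value in breakdown.items():
        history[name].append((step, value))
    return history


def penalty_shares(breakdown):
    """Fraction of the total penalty contributed by each named penalty."""
    total = breakdown["total"]
    names = [name for name in breakdown if name != "total"]
    if total == 0:
        return {name: 0.0 for name in names}
    return {name: breakdown[name] / total for name in names}


history = defaultdict(list)
# Illustrative breakdowns at two training steps
record_breakdown(history, 0, {"mass_conservation": 6.0, "sparsity": 2.0, "total": 8.0})
record_breakdown(history, 100, {"mass_conservation": 3.0, "sparsity": 1.0, "total": 4.0})

print(penalty_shares({"mass_conservation": 3.0, "sparsity": 1.0, "total": 4.0}))
# {'mass_conservation': 0.75, 'sparsity': 0.25}
```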