Differentiable Simulation
Krishna Kumar
The University of Texas at Austin, Chishiki-AI
08/2025 (original)
In the previous MLP topic, we introduced Automatic Differentiation (AD) as the core engine that enables the training of neural networks by computing gradients of a loss function with respect to network parameters.
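As a quick refresher, here is a minimal sketch (a toy example of our own, with made-up values) of that mechanism in PyTorch: a one-parameter-pair "network" with a squared-error loss, where reverse-mode AD fills in the gradients.

```python
import torch

# A tiny "network": y = w * x + b, with a squared-error loss.
x = torch.tensor(2.0)
y_true = torch.tensor(7.0)

w = torch.tensor(1.5, requires_grad=True)  # trainable parameters
b = torch.tensor(0.5, requires_grad=True)

y_pred = w * x + b
loss = (y_pred - y_true) ** 2

loss.backward()  # reverse-mode AD populates .grad on the leaf tensors

print(w.grad)  # dL/dw = 2 * (y_pred - y_true) * x = -14
print(b.grad)  # dL/db = 2 * (y_pred - y_true)     = -7
```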
However, the power of AD extends far beyond neural networks: it allows us to make entire physical simulations differentiable. This paradigm, often called Differentiable Simulation or Differentiable Physics, involves implementing a simulator (e.g., a PDE solver) in a framework that supports AD, such as PyTorch or JAX. Doing so lets us automatically compute the gradient of a final quantity (such as a measurement or a loss function) with respect to any initial parameter of the simulation.
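To make this concrete, here is a minimal sketch (a toy problem of our own, not the wave-equation simulator built later in this topic) of differentiating through a time-stepping loop in JAX: we integrate a mass on a spring with explicit Euler steps and ask for the gradient of the final position with respect to the stiffness k.

```python
import jax
import jax.numpy as jnp

def simulate(k, n_steps=100, dt=0.01):
    """Toy 'simulation': a mass on a spring integrated with explicit Euler.

    k is the spring stiffness -- the physical parameter we will
    differentiate the final measurement with respect to.
    """
    x, v = 1.0, 0.0  # initial position and velocity
    for _ in range(n_steps):
        a = -k * x        # acceleration from Hooke's law
        v = v + dt * a
        x = x + dt * v
    return x  # final position: the "measurement"

# jax.grad differentiates through the entire time-stepping loop.
dxdk = jax.grad(simulate)(2.0)
print(dxdk)
```

Notice that the simulation code itself did not change; jax.grad simply traces the whole computation. This is exactly the mechanism we will exploit for the 1D acoustic wave equation and Full Waveform Inversion below.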
This topic demonstrates the concept in practice. We will:
- Briefly recall how gradients are computed in PyTorch.
- Introduce the JAX framework for high-performance differentiable programming.
- Build a differentiable simulator for the 1D acoustic wave equation.
- Use this simulator to solve a challenging inverse problem: Full Waveform Inversion (FWI).
The materials in this topic are collected in a Jupyter notebook that can be run to reproduce the results shown in the topic pages. Access to the notebook is described in the Lab page at the end of this topic, Lab: Differentiable Simulation; consult that page if you would like to run the code as you work through the materials.
CVW material development is supported by NSF OAC awards 1854828, 2321040, 2323116 (UT Austin) and 2005506 (Indiana University)