JAX Roll Implementation of Finite Differences

See how jnp.roll implements the finite difference stencil for the wave equation, including how arrays are shifted and how boundary conditions are applied.

Wave Equation Implementation

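import jax.numpy as jnp

# Central difference scheme using jnp.roll
# u0, u1: wave at the previous and current time steps; n = len(u1); C2 = (c*dt/dx)**2
u1p = jnp.roll(u1, 1).at[0].set(0)  # u[j-1]; wrapped-around entry replaced by 0
u1n = jnp.roll(u1, -1).at[n-1].set(0)  # u[j+1]; wrapped-around entry replaced by 0
u2 = 2 * u1 - u0 + C2 * (u1p - 2 * u1 + u1n)  # wave at the next time step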
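To see the scheme above in context, here is a minimal, self-contained sketch of a full time loop. The names wave_step and simulate, the Gaussian initial pulse, the 101-point grid, C2 = 0.25, and starting from rest by setting u0 equal to u1 are all assumptions for illustration, not taken from the original course code. Note that zeroing the wrapped entries treats the values just outside the array as 0; the end points u[0] and u[n-1] themselves are still updated by the stencil.

```python
import jax
import jax.numpy as jnp


@jax.jit
def wave_step(u0, u1, C2):
    """One explicit step of u_tt = c^2 u_xx (hypothetical helper).

    u0, u1: solution at the previous and current time levels.
    C2: assumed to be the squared Courant number (c*dt/dx)**2.
    Values just outside the array are treated as 0.
    """
    n = u1.shape[0]
    u1p = jnp.roll(u1, 1).at[0].set(0.0)       # u[j-1]; left neighbor of u[0] is 0
    u1n = jnp.roll(u1, -1).at[n - 1].set(0.0)  # u[j+1]; right neighbor of u[n-1] is 0
    return 2 * u1 - u0 + C2 * (u1p - 2 * u1 + u1n)


def simulate(u_init, C2, nsteps):
    """Hypothetical driver: advance a pulse released from rest."""
    def body(carry, _):
        u0, u1 = carry
        return (u1, wave_step(u0, u1, C2)), None

    # Starting with u0 == u1 is a simple (first-order) way to encode zero
    # initial velocity; an assumption of this sketch.
    (_, u_final), _ = jax.lax.scan(body, (u_init, u_init), None, length=nsteps)
    return u_final


x = jnp.linspace(0.0, 1.0, 101)             # dx = 0.01
u_init = jnp.exp(-200.0 * (x - 0.5) ** 2)   # Gaussian pulse centered at x = 0.5
u_final = simulate(u_init, 0.25, 40)        # C2 = 0.25, i.e. Courant number 0.5
```

With a Courant number of 0.5 the explicit scheme stays within its stability limit, and the pulse splits into two half-height pulses moving in opposite directions, as d'Alembert's solution predicts.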

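Key insight: instead of explicit indexing, we use jnp.roll to shift the entire array, then apply boundary conditions with .at[].set(). Because jnp.roll wraps values around (the last element of u1 lands in slot 0 of jnp.roll(u1, 1)), those wrapped entries must be overwritten with the boundary value. Since JAX arrays are immutable, .at[].set() returns a new array rather than modifying u1 in place.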
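The short sketch below illustrates the wraparound on a five-element array and checks the roll-based neighbors against an explicit slicing formulation. The array values are arbitrary and chosen only to make the shifts easy to read.

```python
import jax.numpy as jnp

u1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
n = u1.shape[0]

# jnp.roll(u1, 1) moves every element one slot to the right, so slot j now
# holds u1[j-1]; the last element wraps around into slot 0.
rolled = jnp.roll(u1, 1)                     # [5., 1., 2., 3., 4.]

# .at[0].set(0.0) returns a new array with the wrapped value replaced by the
# boundary value 0; u1 itself is left unchanged.
u1p = rolled.at[0].set(0.0)                  # [0., 1., 2., 3., 4.]
u1n = jnp.roll(u1, -1).at[n - 1].set(0.0)    # [2., 3., 4., 5., 0.]

# Equivalent explicit-slicing formulation, for comparison.
u1p_sliced = jnp.concatenate([jnp.zeros(1), u1[:-1]])
u1n_sliced = jnp.concatenate([u1[1:], jnp.zeros(1)])
```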

Finite Difference Stencil:
∂²u/∂x² ≈ (u[j-1] - 2u[j] + u[j+1]) / Δx²
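As a quick numerical check of this stencil, here is a sketch that applies it with jnp.roll to samples of sin(x), whose exact second derivative is -sin(x), and measures the error at the interior points. The grid size n = 101 is an arbitrary choice; in JAX's default float32, rounding limits how small the error can get.

```python
import jax.numpy as jnp

# Sample u(x) = sin(x) on [0, 2*pi]; the exact second derivative is -sin(x).
n = 101
x = jnp.linspace(0.0, 2.0 * jnp.pi, n)
dx = x[1] - x[0]
u = jnp.sin(x)

# Central difference (u[j-1] - 2u[j] + u[j+1]) / dx^2 built with jnp.roll.
# Only interior points are compared below, so the wrapped-around end values
# do not affect the check.
d2u = (jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1)) / dx**2
err = jnp.max(jnp.abs(d2u[1:-1] - (-jnp.sin(x[1:-1]))))
```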
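1. Original array u1: the current wave state.
2. Roll by +1: u1p = jnp.roll(u1, +1) shifts every element one slot to the right, so u1p[j] holds the left neighbor u1[j-1]. The value that wraps around into u1p[0] is replaced by the boundary value 0.
3. Roll by -1: u1n = jnp.roll(u1, -1) shifts every element one slot to the left, so u1n[j] holds the right neighbor u1[j+1]. The value that wraps around into u1n[n-1] is replaced by 0.
4. Finite difference: u1p[j] - 2×u1[j] + u1n[j] is the numerator of the second-derivative approximation; dividing by Δx² gives ∂²u/∂x². In the wave update, this 1/Δx² factor is absorbed into C2.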
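Steps 1 to 4 together amount to convolving u1 with the kernel [1, -2, 1], with the values outside the array taken as zero. The sketch below cross-checks this using jnp.convolve with mode="same"; the convolution route and the helper names second_difference_roll and second_difference_conv are our own illustration, not the page's method.

```python
import jax.numpy as jnp


def second_difference_roll(u1):
    """Steps 1-4: roll, zero the wrapped entries, combine (not divided by dx^2)."""
    n = u1.shape[0]
    u1p = jnp.roll(u1, 1).at[0].set(0.0)       # left neighbors
    u1n = jnp.roll(u1, -1).at[n - 1].set(0.0)  # right neighbors
    return u1p - 2.0 * u1 + u1n


def second_difference_conv(u1):
    """Same stencil as a zero-padded convolution with [1, -2, 1]."""
    kernel = jnp.array([1.0, -2.0, 1.0])
    return jnp.convolve(u1, kernel, mode="same")


u1 = jnp.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0])  # u[j] = j**2
fd_roll = second_difference_roll(u1)               # [1., 2., 2., 2., 2., -34.]
fd_conv = second_difference_conv(u1)
```

For u[j] = j², the interior entries equal 2, the exact second difference of j²; the two end entries differ because the missing neighbors are treated as 0.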

© | Cornell University | Center for Advanced Computing | Copyright Statement | Access Statement
CVW material development is supported by NSF OAC awards 1854828, 2321040, 2323116 (UT Austin) and 2005506 (Indiana University)