From PyTorch to JAX: A Quick Review
We introduced automatic differentiation in the MLP module.
As we saw previously, frameworks like PyTorch record every operation performed on tensors, building up a computation graph. When we call .backward() on a final scalar output (such as a loss), PyTorch uses reverse-mode AD (backpropagation) to walk that graph backward and compute the gradient of the output with respect to every tensor created with requires_grad=True. Here's a quick recap of computing gradients in PyTorch:
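Below is a minimal sketch of that workflow; the particular function, y = sum(x**2), is just an illustration and is not taken from the MLP module.

```python
import torch

# Leaf tensor with requires_grad=True: autograd will track operations on it.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Build a small computation graph ending in a scalar: y = sum(x**2).
y = (x ** 2).sum()

# Reverse-mode AD: populates x.grad with dy/dx = 2*x.
y.backward()

print(x.grad)   # tensor([2., 4., 6.])
```

For comparison, JAX handles the same task functionally: jax.grad transforms a scalar-valued function into a new function that returns the gradient, rather than mutating a .grad attribute on the inputs.

```python
import jax
import jax.numpy as jnp

# jax.grad returns a function that computes the gradient of f
# with respect to its first argument.
f = lambda x: jnp.sum(x ** 2)
grad_f = jax.grad(f)

print(grad_f(jnp.array([1.0, 2.0, 3.0])))   # [2. 4. 6.]
```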