Backpropagation Through Time
Motivation
A recurrent neural network shares parameters across time steps:
\[ h_t = f_\theta(h_{t-1}, x_t), \qquad \hat y_t = g_\theta(h_t), \]
with the same \(\theta\) at every \(t\). To train it by gradient descent we need \(\partial L / \partial \theta\) for the loss \(L = \sum_t \ell(\hat y_t, y_t)\). Backpropagation through time (BPTT; Werbos 1990) is the algorithm that computes this gradient by unfolding the recurrence into a feedforward graph with \(T\) time steps and running backpropagation on the unfolded graph.
The whole construction is just reverse-mode automatic differentiation applied to the unrolled computation; “BPTT” is the name for the specialization to recurrent networks.
Problem
Given an RNN \(h_t = f_\theta(h_{t-1}, x_t)\), \(\hat{y}_t = g_\theta(h_t)\) with shared parameters \(\theta\) across time, and a loss \(L = \sum_{t=1}^T \ell(\hat{y}_t, y_t)\), compute \(\partial L / \partial \theta\) for use in gradient descent.
Key Ideas
Unroll the recurrence into a feedforward graph
Treat the RNN as a feedforward graph by replicating \(f_\theta\) at each time step:
\[ h_0 \xrightarrow{x_1} h_1 \xrightarrow{x_2} h_2 \xrightarrow{x_3} \cdots \xrightarrow{x_T} h_T. \]
Each \(h_t = f_\theta(h_{t-1}, x_t)\) uses the same \(\theta\). Outputs \(\hat y_t = g_\theta(h_t)\) are produced at each step (or only at the final step, depending on the task), and the loss is summed over time. The unrolled graph is a depth-\(T\) feedforward network — and ordinary backpropagation applies.
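In code, the unrolling is literally a loop that reapplies one cell. A minimal sketch, assuming NumPy and a dictionary of weights (the names `W_hh`, `W_xh`, `b` match the vanilla RNN used below but are otherwise illustrative):

```python
import numpy as np

def cell(theta, h, x):
    # one application of f_theta; the same weights at every time step
    return np.tanh(theta["W_hh"] @ h + theta["W_xh"] @ x + theta["b"])

def unroll(theta, h0, xs):
    # the unrolled graph is a depth-T feedforward network: h_0 -> h_1 -> ... -> h_T
    hs = [h0]
    for x in xs:
        hs.append(cell(theta, hs[-1], x))
    return hs
```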
The Jacobian product is the source of trouble
Define the backpropagated adjoint \(\delta_t = \partial L / \partial h_t\). The reverse-mode recurrence is
\[ \delta_t = \frac{\partial \ell_t}{\partial h_t} + \delta_{t+1} \frac{\partial h_{t+1}}{\partial h_t}, \]
with \(\delta_{T+1} = 0\). Unrolling gives a product:
\[ \delta_t = \frac{\partial \ell_t}{\partial h_t} + \sum_{s > t} \frac{\partial \ell_s}{\partial h_s} \, \frac{\partial h_s}{\partial h_{s-1}} \frac{\partial h_{s-1}}{\partial h_{s-2}} \cdots \frac{\partial h_{t+1}}{\partial h_t}. \]
For a vanilla RNN with \(h_t = \tanh(z_t)\), where \(z_t = W_{hh} h_{t-1} + W_{xh} x_t + b\), the Jacobian is \(\partial h_t / \partial h_{t-1} = \operatorname{diag}(\tanh'(z_t))\, W_{hh}\). When \(T\) is large, this product of up to \(T\) Jacobians either contracts to zero (vanishing) or blows up (exploding) at a geometric rate governed by the spectral radii of the Jacobians. This is the central optimization difficulty for RNNs and motivates LSTM, GRU, and gradient clipping.
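The geometric behavior is easy to observe numerically. A minimal sketch, assuming NumPy; the dimension, the scaled-orthogonal \(W_{hh}\), and the small pre-activations (so \(\tanh' \approx 1\)) are illustrative choices, not from the text above:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64
for scale in (0.8, 1.2):  # contractive vs. expansive recurrent weights
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    W_hh = scale * Q                                # spectral norm equals `scale`
    prod = np.eye(d)
    for _ in range(100):                            # a product of T = 100 Jacobians
        z = 0.1 * rng.standard_normal(d)            # small pre-activations: tanh'(z) near 1
        J = np.diag(1.0 - np.tanh(z) ** 2) @ W_hh   # dh_t/dh_{t-1} for the tanh RNN
        prod = J @ prod
    print(scale, np.linalg.norm(prod))              # vanishes at 0.8, explodes at 1.2
```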
Algorithm
```
forward(x_1, ..., x_T):
    h_0 = initial state
    for t = 1, ..., T:
        h_t = f_θ(h_{t-1}, x_t)
        ŷ_t = g_θ(h_t)
        ℓ_t = ℓ(ŷ_t, y_t)
    L = Σ_t ℓ_t

backward():
    grad_θ = 0
    δ = 0                                  # δ holds ∂L/∂h from later steps
    for t = T down to 1:
        δ_local = ∂ℓ_t/∂h_t                # local loss gradient
        δ = δ_local + δ · ∂h_{t+1}/∂h_t    # propagate from later steps
        grad_θ += δ · ∂h_t/∂θ|_local       # accumulate this step's use of θ
    return grad_θ
```
The forward pass is the RNN executed normally, with all hidden states saved. The backward pass walks time in reverse, propagating \(\delta\) and accumulating \(\partial L / \partial \theta\) by summing the per-step contributions.
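Concretely, here is a minimal NumPy sketch of both passes for the vanilla tanh RNN from Key Ideas, with an assumed linear readout \(\hat y_t = W_{hy} h_t\) and squared-error loss (the readout and all variable names are illustrative, not prescribed by the text):

```python
import numpy as np

def bptt(params, xs, ys, h0):
    """Full BPTT for h_t = tanh(W_hh h_{t-1} + W_xh x_t + b), y_hat_t = W_hy h_t,
    with loss sum_t 0.5 * ||y_hat_t - y_t||^2. Returns (loss, gradients)."""
    W_hh, W_xh, W_hy, b = params
    T = len(xs)
    hs, loss = [h0], 0.0
    # forward pass: run the recurrence, saving every hidden state
    for t in range(T):
        h = np.tanh(W_hh @ hs[-1] + W_xh @ xs[t] + b)
        hs.append(h)
        loss += 0.5 * np.sum((W_hy @ h - ys[t]) ** 2)
    # backward pass: walk time in reverse, propagating delta = dL/dh_t
    dW_hh, dW_xh, dW_hy, db = [np.zeros_like(p) for p in params]
    delta = np.zeros_like(h0)
    for t in reversed(range(T)):
        err = W_hy @ hs[t + 1] - ys[t]          # d loss_t / d y_hat_t
        dW_hy += np.outer(err, hs[t + 1])
        delta = W_hy.T @ err + delta            # local term + term from step t+1
        dz = (1.0 - hs[t + 1] ** 2) * delta     # through tanh: tanh'(z) = 1 - tanh(z)^2
        dW_hh += np.outer(dz, hs[t])            # this step's use of the shared weights
        dW_xh += np.outer(dz, xs[t])
        db += dz
        delta = W_hh.T @ dz                     # becomes the later-steps term at t-1
    return loss, (dW_hh, dW_xh, dW_hy, db)
```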
Walkthrough
A three-step RNN
Take a scalar RNN with \(h_t = \theta \cdot h_{t-1} + x_t\), \(\hat{y}_t = h_t\), and \(\ell_t = (\hat{y}_t - y_t)^2\). Initial state \(h_0 = 0\); inputs \(x_1 = 1, x_2 = 0, x_3 = 0\); targets \(y_1 = 0, y_2 = 0, y_3 = 1\); and \(\theta = 0.8\).
Forward.
| \(t\) | \(h_t\) | \(\hat{y}_t\) | \(\ell_t\) |
|---|---|---|---|
| 1 | \(0.8 \cdot 0 + 1 = 1\) | \(1\) | \((1-0)^2 = 1\) |
| 2 | \(0.8 \cdot 1 + 0 = 0.8\) | \(0.8\) | \((0.8-0)^2 = 0.64\) |
| 3 | \(0.8 \cdot 0.8 + 0 = 0.64\) | \(0.64\) | \((0.64-1)^2 \approx 0.13\) |
Total loss \(L \approx 1.77\).
Backward. \(\partial h_t / \partial h_{t-1} = \theta = 0.8\) and \(\partial h_t / \partial \theta\big|_\text{local} = h_{t-1}\). Local loss gradients: \(\partial \ell_t / \partial h_t = 2(\hat{y}_t - y_t)\).
| \(t\) | \(\partial \ell_t / \partial h_t\) | \(\delta_t = \partial L / \partial h_t\) | \(\partial h_t / \partial \theta\big|_\text{local}\) | Contribution to \(\partial L / \partial \theta\) |
|---|---|---|---|---|
| 3 | \(2(0.64 - 1) = -0.72\) | \(-0.72\) | \(h_2 = 0.8\) | \(-0.72 \cdot 0.8 = -0.576\) |
| 2 | \(2(0.8 - 0) = 1.6\) | \(1.6 + 0.8 \cdot (-0.72) = 1.024\) | \(h_1 = 1.0\) | \(1.024 \cdot 1.0 = 1.024\) |
| 1 | \(2(1 - 0) = 2\) | \(2 + 0.8 \cdot 1.024 = 2.819\) | \(h_0 = 0\) | \(2.819 \cdot 0 = 0\) |
Total \(\partial L / \partial \theta = -0.576 + 1.024 + 0 = 0.448\). Increasing \(\theta\) would raise the loss — the optimizer would decrease it.
Notice \(\delta_1 = 2.819\) even though the local loss gradient at \(t = 1\) is only \(2\): errors from later steps have propagated backward through \(\partial h / \partial h\) factors of \(0.8\). With a longer sequence or larger \(|\theta|\), these factors multiply geometrically — exploding or vanishing depending on whether \(|\theta|\) exceeds \(1\) or not.
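The example is small enough to check mechanically. A few lines of plain Python reproduce both tables; the central-difference check at the end is an added sanity test, not part of the derivation above:

```python
def loss(theta, xs=(1.0, 0.0, 0.0), ys=(0.0, 0.0, 1.0)):
    h, L = 0.0, 0.0
    for x, y in zip(xs, ys):
        h = theta * h + x        # h_t = theta * h_{t-1} + x_t, y_hat_t = h_t
        L += (h - y) ** 2
    return L

def grad(theta, xs=(1.0, 0.0, 0.0), ys=(0.0, 0.0, 1.0)):
    hs = [0.0]                   # forward, saving states
    for x in xs:
        hs.append(theta * hs[-1] + x)
    g, delta = 0.0, 0.0
    for t in range(len(xs), 0, -1):                       # backward in time
        delta = 2 * (hs[t] - ys[t - 1]) + theta * delta   # delta_t
        g += delta * hs[t - 1]                            # delta_t * h_{t-1}
    return g

print(grad(0.8))                                          # 0.448, matching the table
eps = 1e-6
print((loss(0.8 + eps) - loss(0.8 - eps)) / (2 * eps))    # ~0.448 by central differences
```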
Correctness
BPTT is reverse-mode autodiff applied to the unrolled recurrence — exact up to floating-point precision. Each backward step is the multivariable chain rule applied locally; the parameter-gradient accumulation correctly sums over all the places \(\theta\) appears in the unrolled graph.
In a modern framework, BPTT is simply what `loss.backward()` does on a graph built dynamically over time steps. The math doesn’t change; only the engineering attention given to memory and gradient pathologies does.
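For instance, handing the scalar walkthrough above to an autodiff framework (PyTorch here, purely as an illustration) recovers the same \(0.448\):

```python
import torch

theta = torch.tensor(0.8, requires_grad=True)
h = torch.tensor(0.0)
loss = torch.tensor(0.0)
for x, y in [(1.0, 0.0), (0.0, 0.0), (0.0, 1.0)]:   # (x_t, y_t) pairs
    h = theta * h + x            # the graph grows by one step per iteration
    loss = loss + (h - y) ** 2
loss.backward()                  # reverse mode over the unrolled graph = BPTT
print(theta.grad)                # tensor(0.4480)
```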
Complexity and Tradeoffs
Time. \(O(T)\) — each time step contributes a constant (in \(T\)) amount of forward and backward work, typically \(O(d_h^2)\) for the matrix multiplies.
Memory. \(O(T \cdot d_h)\) — all hidden states (plus inputs and pre-activations) must be stored for the backward pass; e.g., \(T = 1000\) steps of a \(d_h = 1024\) float32 hidden state is about 4 MB per sequence, times the batch size. This is the dominant cost for long sequences and the main motivation for truncation.
Truncated BPTT
Truncated BPTT chops the sequence into windows of length \(K\) (say \(K = 50\) or \(100\)). The forward pass still runs over the full sequence, but the backward pass stops at each window boundary: the hidden state is carried into the next window while its graph history is discarded, so no gradient flows across windows. Parameter updates therefore see only dependencies within \(K\) steps; longer-range structure must be learned indirectly through the hidden state passed across windows. A sketch of one windowed update follows the list below.
The trade-off:
- \(K\) small: cheap, but cannot learn dependencies longer than \(K\).
- \(K\) large: closer to the true gradient, but memory and compute scale linearly.
Standard practice is to choose \(K\) based on the longest dependency the task plausibly requires. Language modeling typically uses \(K \in [128, 512]\).
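A minimal PyTorch-style sketch of one windowed update; `model(x, h)` returning \((\hat y_t, h_t)\), and the other names, are assumptions for illustration:

```python
import torch

def truncated_bptt_step(model, h, window_xs, window_ys, loss_fn, opt):
    """One truncated-BPTT update over a window of K steps."""
    h = h.detach()               # cut the graph: no gradient flows into earlier windows
    loss = 0.0
    for x, y in zip(window_xs, window_ys):
        y_hat, h = model(x, h)
        loss = loss + loss_fn(y_hat, y)
    opt.zero_grad()
    loss.backward()              # BPTT over at most K steps
    opt.step()
    return h                     # carried across windows (detached on the next call)
```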
Vanishing and exploding gradients
The Jacobian product makes BPTT optimization fragile. Mitigations:
- Architecture. LSTM and GRU gate the state path so \(\partial h_t / \partial h_{t-1}\) is closer to identity in the relevant regime.
- Gradient clipping. Cap the gradient norm \(\|\partial L / \partial \theta\|\) at a fixed threshold so that one exploding batch cannot produce a destructive update.
- Initialization. Spectral-norm-aware initialization (orthogonal or identity) keeps the Jacobians near unit norm at the start of training. Both mitigations are sketched below.
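A sketch of both mitigations in PyTorch; the layer sizes, the random input, and the placeholder loss are illustrative:

```python
import torch

rnn = torch.nn.RNN(input_size=32, hidden_size=128)     # vanilla tanh RNN
# orthogonal initialization keeps the recurrent Jacobian near unit norm at the start
for name, p in rnn.named_parameters():
    if "weight_hh" in name:
        torch.nn.init.orthogonal_(p)

opt = torch.optim.SGD(rnn.parameters(), lr=1e-2)
x = torch.randn(100, 1, 32)                            # (T, batch, features)
out, h_T = rnn(x)
loss = out.pow(2).mean()                               # placeholder loss for the sketch
loss.backward()
# clip the global gradient norm before the update to guard against explosions
torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
opt.step()
```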
When to Use It
| Situation | Approach |
|---|---|
| Train any RNN over a sequence | BPTT — the universal default. |
| Very long sequence, memory-limited | Truncated BPTT with window \(K\). |
| Streaming / online learning over indefinitely long sequence | Truncated BPTT, or online learning algorithms (RTRL — see Variants). |
| Want full true gradient regardless of length | Full BPTT with gradient checkpointing. |
| Long-range dependencies a plain RNN cannot learn | LSTM/GRU, attention, or a transformer instead. |
| Continuous-time RNN | Adjoint method (Neural ODE) — solve the adjoint ODE backward. |
Variants
- Truncated BPTT. Window-based gradient that ignores dependencies beyond \(K\) steps. Standard for long-sequence training.
- Real-Time Recurrent Learning (RTRL). Online algorithm that maintains the full Jacobian \(\partial h_t / \partial \theta\) explicitly, updating it forward in time. Avoids BPTT's \(O(T)\) memory but needs \(O(d_h^3)\) memory and \(O(d_h^4)\) compute per step — generally too expensive in practice.
- Synthetic gradients / Decoupled Neural Interfaces. Predict the gradient locally instead of waiting for the global backward pass. Approximate but enables pipeline parallelism across time.
- Adjoint method (Neural ODE). For continuous-time models, solve an adjoint ODE backward in time to compute gradients without storing every intermediate state. Trades memory for an extra ODE solve.
- Gradient checkpointing in time. Store only every \(k\)-th hidden state and recompute the rest during the backward pass — the same trick as in feedforward backprop, applied along the time axis (see the sketch below).
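As an illustration, a checkpointed version of the scalar walkthrough gradient: store only every \(k\)-th state, then recompute each segment during the reverse sweep. The segment bookkeeping here is one possible arrangement, not canonical:

```python
def checkpointed_grad(theta, xs, ys, k=2):
    """BPTT for the walkthrough RNN h_t = theta*h_{t-1} + x_t, storing only every
    k-th hidden state and recomputing the rest segment by segment."""
    h, ckpts = 0.0, {0: 0.0}
    for t, x in enumerate(xs, start=1):        # forward: keep checkpoints h_0, h_k, ...
        h = theta * h + x
        if t % k == 0:
            ckpts[t] = h
    g, delta, T = 0.0, 0.0, len(xs)
    for seg_start in range(((T - 1) // k) * k, -1, -k):   # segments in reverse order
        seg_end = min(seg_start + k, T)
        hs = [ckpts[seg_start]]                # recompute this segment's hidden states
        for t in range(seg_start + 1, seg_end + 1):
            hs.append(theta * hs[-1] + xs[t - 1])
        for t in range(seg_end, seg_start, -1):           # ordinary BPTT within it
            delta = 2 * (hs[t - seg_start] - ys[t - 1]) + theta * delta
            g += delta * hs[t - seg_start - 1]
    return g

print(checkpointed_grad(0.8, (1.0, 0.0, 0.0), (0.0, 0.0, 1.0)))  # 0.448, as before
```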