```python
def compute_grad(V, graph, grad_table):
    """Recursively compute the gradient of the output with respect to V.

    graph[V] lists (child, operation) pairs: the consumers of V and the
    operation through which V feeds into each consumer. grad_table caches
    gradients that have already been computed.
    """
    if V in grad_table:
        return grad_table[V]
    grad_sum = 0
    for child, operation in graph[V]:
        grad_input = compute_grad(child, graph, grad_table)
        grad_sum += operation.bprop(child, grad_input)
    grad_table[V] = grad_sum
    return grad_sum
```
Backpropagation Example: A Step-by-Step Guide
Understanding Gradient Computation in Computational Graphs
Introduction to Backpropagation
Backpropagation is a key algorithm for training neural networks and computing gradients in computational graphs. This guide walks through the gradient computation of a simple problem using a computational graph and backpropagation.
Partial Differential Calculus Rules
Addition Rule:
For a sum of variables, used for \(u^{(3)}=u^{(1)} + u^{(2)}\) below.
\[\frac{\partial (x + y)}{\partial x} = 1, \quad \frac{\partial (x + y)}{\partial y} = 1\]
Power Rule:
For a variable raised to a power, used for \(u^{(4)} = \left(u^{(3)}\right)^2\) with \(n=2\) below.
\[\frac{\partial (x^n)}{\partial x} = n \cdot x^{n-1}\]
Sine Rule:
For the sine of a variable, used for \(u^{(5)} = \sin(u^{(4)})\) below.
\[\frac{\partial (\sin(x))}{\partial x} = \cos(x)\]
Chain Rule:
For a composition of functions, used to propagate gradients backward through the computational graph.
\[\frac{\partial f(g(x))}{\partial x} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial x}\]
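The four rules above can be sanity-checked numerically with central finite differences; a minimal sketch (the test points and step size are arbitrary choices):

```python
import math

def numerical_derivative(f, x, h=1e-6):
    """Central-difference approximation of df/dx at x."""
    return (f(x + h) - f(x - h)) / (2 * h)

# Power rule: d(x^2)/dx = 2x, checked at x = 3
assert abs(numerical_derivative(lambda x: x**2, 3.0) - 6.0) < 1e-4

# Sine rule: d(sin x)/dx = cos x, checked at x = 1
assert abs(numerical_derivative(math.sin, 1.0) - math.cos(1.0)) < 1e-4

# Chain rule: d(sin(x^2))/dx = cos(x^2) * 2x, checked at x = 1
assert abs(numerical_derivative(lambda x: math.sin(x**2), 1.0)
           - math.cos(1.0) * 2.0) < 1e-4
```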
Problem Overview
Objective: Compute the gradient \(\frac{\partial u^{(5)}}{\partial u^{(1)}}\) using backpropagation.
Computational Graph:
Nodes: \(u^{(1)}, u^{(2)}, u^{(3)}, u^{(4)}, u^{(5)}\)
Operations:
\(u^{(3)}=u^{(1)} + u^{(2)}\)
\(u^{(4)} = \left(u^{(3)}\right)^2\)
\(u^{(5)} = \sin(u^{(4)})\)
Computational Graph Structure
- Additive Components:
- \(u^{(3)} = u^{(1)} + u^{(2)}\): This operation is purely additive, combining the values of \(u^{(1)}\) and \(u^{(2)}\).
- Multiplicative Components:
- \(u^{(4)} = \left(u^{(3)}\right)^2\): This operation introduces a multiplication by squaring the value of \(u^{(3)}\), which can be interpreted as \(u^{(3)} \cdot u^{(3)}\).
- During backpropagation, the gradient computation for \(u^{(4)}\) multiplies \(2\) by \(u^{(3)}\), per the derivative rule \(\frac{\partial u^{(4)}}{\partial u^{(3)}} = 2 \cdot u^{(3)}\).
- Combination During Backpropagation:
- The gradients for \(u^{(4)}\) and \(u^{(3)}\) are combined multiplicatively with their respective upstream gradients as the backpropagation traverses the graph.
Defining Operations and Gradients
Operation 1: \(u^{(3)} = u^{(1)} + u^{(2)}\)
Forward: \(u^{(3)} = u^{(1)} + u^{(2)}\)
Gradient:
\(\frac{\partial u^{(3)}}{\partial u^{(1)}} = 1\)
\(\frac{\partial u^{(3)}}{\partial u^{(2)}} = 1\)
Operation 2: \(u^{(4)} = \left(u^{(3)}\right)^2\)
Forward: \(u^{(4)} = \left(u^{(3)}\right)^2\)
Gradient:
- \(\frac{\partial u^{(4)}}{\partial u^{(3)}} = 2 \cdot u^{(3)}\)
Operation 3: \(u^{(5)} = \sin(u^{(4)})\)
Forward: \(u^{(5)} = \sin(u^{(4)})\)
Gradient:
- \(\frac{\partial u^{(5)}}{\partial u^{(4)}} = \cos(u^{(4)})\)
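Each operation's backward rule can be captured as a small function that multiplies the local derivative by the gradient arriving from upstream (the function names here are illustrative, not from the original):

```python
import math

def add_bprop(upstream):
    # d(a + b)/da = 1, so the upstream gradient passes through unchanged
    return 1.0 * upstream

def square_bprop(x, upstream):
    # d(x^2)/dx = 2x
    return 2.0 * x * upstream

def sin_bprop(x, upstream):
    # d(sin x)/dx = cos x
    return math.cos(x) * upstream
```

Chaining them reproduces this guide's gradient: `add_bprop(square_bprop(6.0, sin_bprop(36.0, 1.0)))` evaluates to \(12 \cdot \cos(36)\).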
Step-by-Step Gradient Composition
Step 1: Initialization
Start with the forward pass:
\(u^{(1)} = 4\)
\(u^{(2)} = 2\)
\(u^{(3)} = u^{(1)} + u^{(2)} = 6\)
\(u^{(4)} = \left(u^{(3)}\right)^2 = 36\)
\(u^{(5)} = \sin(u^{(4)}) = \sin(36)\)
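The forward pass can be checked directly in Python (note that \(\sin(36)\) is evaluated in radians):

```python
import math

u1, u2 = 4, 2
u3 = u1 + u2        # 6
u4 = u3 ** 2        # 36
u5 = math.sin(u4)   # sin(36) in radians, approximately -0.9918
```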
Initialize the seed gradient:
grad_table\([u^{(5)}] = 1\)
Step 2: Backpropagation
At \(u^{(5)}\): grad_table\([u^{(4)}] = \cos(u^{(4)}) \cdot\) grad_table\([u^{(5)}]\)
At \(u^{(4)}\): grad_table\([u^{(3)}] = (2 \cdot u^{(3)}) \cdot\) grad_table\([u^{(4)}]\)
At \(u^{(3)}\): grad_table\([u^{(1)}] = 1 \cdot\) grad_table\([u^{(3)}]\)
Final Gradient Computation
Results
grad_table\([u^{(5)}] = 1\)
grad_table\([u^{(4)}] = \cos(36) \cdot 1 = \cos(36)\)
grad_table\([u^{(3)}] = (2 \cdot 6) \cdot \cos(36) = 12 \cdot \cos(36)\)
grad_table\([u^{(1)}] = 1 \cdot (12 \cdot \cos(36)) = 12 \cdot \cos(36)\)
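These table entries can be reproduced step by step and the result cross-checked against a finite-difference approximation of \(f(u^{(1)}) = \sin((u^{(1)} + u^{(2)})^2)\); the step size below is an arbitrary choice:

```python
import math

grad_table = {}
grad_table["u5"] = 1.0
grad_table["u4"] = math.cos(36) * grad_table["u5"]   # cos(36)
grad_table["u3"] = (2 * 6) * grad_table["u4"]        # 12 * cos(36)
grad_table["u1"] = 1.0 * grad_table["u3"]            # 12 * cos(36)

# Cross-check with a central finite difference of f(u1) = sin((u1 + u2)^2)
def f(u1, u2=2.0):
    return math.sin((u1 + u2) ** 2)

h = 1e-7
numeric = (f(4.0 + h) - f(4.0 - h)) / (2 * h)
assert abs(grad_table["u1"] - numeric) < 1e-4
```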
Final Answer:
\(\frac{\partial u^{(5)}}{\partial u^{(1)}} = 12 \cdot \cos(36)\)
Complete Calculation:
\(V\) | \(C\) | \(\text{op}\) | \(\frac{dC}{dV}\) | \(\text{inputs}\) | \(\frac{dC}{dV} (\text{inputs})\) | \(G\) |
---|---|---|---|---|---|---|
\(u^{(4)}\) | \(u^{(5)}\) | \(\sin(u^{(4)})\) | \(\cos(u^{(4)})\) | \(u^{(4)} = 36\) | \(\cos(36)\) | \(\cos(36)\) |
\(u^{(3)}\) | \(u^{(4)}\) | \(\left(u^{(3)}\right)^2\) | \(2 \cdot u^{(3)}\) | \(u^{(3)} = 6\) | \((2 \cdot 6) \cdot \cos(36) = 12 \cdot \cos(36)\) | \(12 \cdot \cos(36)\) |
\(u^{(1)}\) | \(u^{(3)}\) | \(u^{(1)} + u^{(2)}\) | \(1\) | \(u^{(1)} = 4, u^{(2)} = 2\) | \(1 \cdot (12 \cdot \cos(36))\) | \(12 \cdot \cos(36)\) |
\(u^{(1)}\) | \(-\) | \(u^{(1)} \text{ (input)}\) | \(1\) | \(-\) | \(12 \cdot \cos(36)\) | \(12 \cdot \cos(36)\) |
Explanation:
\(V\): The node for which the gradient is computed.
\(C\): The consumer node that depends on \(V\).
\(\text{op}\): The operation at the consumer node.
\(\frac{dC}{dV}\): Local gradient of \(C\) with respect to \(V\).
\(\text{inputs}\): Inputs to the consumer node.
\(\frac{dC}{dV}(\text{inputs})\): Gradient contribution to \(V\) from this consumer.
\(G\): Total gradient for \(V\).
Algorithm Definition
The compute_grad function at the top of this guide implements the following procedure:
Input: Node \(V\), graph structure, and gradient table.
Process:
If \(V\) has already been computed, return its value.
For each child node of \(V\):
Compute gradient recursively.
Backpropagate the gradient using the operation’s backward function.
Store the computed gradient in grad_table\([V]\).
Output: Gradient value for \(V\).
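Putting everything together, here is a runnable sketch of this procedure applied to the example graph. The class names, and the convention of recording each forward value inside its operation object, are illustrative assumptions; the compute_grad function itself follows the pseudocode above.

```python
import math

# Forward pass with the walkthrough's values u1 = 4, u2 = 2.
u1, u2 = 4.0, 2.0
u3 = u1 + u2          # 6
u4 = u3 ** 2          # 36
u5 = math.sin(u4)     # sin(36)

class AddBprop:
    """bprop for c = a + b: local gradient w.r.t. either input is 1."""
    def bprop(self, child, grad_child):
        return 1.0 * grad_child

class SquareBprop:
    """bprop for c = x**2: local gradient is 2 * x."""
    def __init__(self, x):
        self.x = x  # input value recorded during the forward pass
    def bprop(self, child, grad_child):
        return 2.0 * self.x * grad_child

class SinBprop:
    """bprop for c = sin(x): local gradient is cos(x)."""
    def __init__(self, x):
        self.x = x
    def bprop(self, child, grad_child):
        return math.cos(self.x) * grad_child

# graph[V] lists (consumer, operation) pairs: each node that consumes V,
# with the operation through which V feeds into that consumer.
graph = {
    "u5": [],                          # output node: no consumers
    "u4": [("u5", SinBprop(u4))],
    "u3": [("u4", SquareBprop(u3))],
    "u2": [("u3", AddBprop())],
    "u1": [("u3", AddBprop())],
}

grad_table = {"u5": 1.0}  # seed gradient at the output

def compute_grad(V, graph, grad_table):
    if V in grad_table:
        return grad_table[V]
    grad_sum = 0
    for child, operation in graph[V]:
        grad_input = compute_grad(child, graph, grad_table)
        grad_sum += operation.bprop(child, grad_input)
    grad_table[V] = grad_sum
    return grad_sum

print(compute_grad("u1", graph, grad_table))   # 12 * cos(36), about -1.5356
```

The recursion asks each consumer for its gradient first, so the table fills in exactly the order of the walkthrough: \(u^{(5)}\), then \(u^{(4)}\), \(u^{(3)}\), and finally \(u^{(1)}\).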