Issue
`local_gradient_for_argument` might read a corrupted float value from the successor.
Say we have the computation graph:
- $h$ has parameters $w$
- $f$ has parameters $v$
- $h : \mathbb{R}^m \to \mathbb{R}$
- $f : \mathbb{R}^n \to \mathbb{R}$
- $h$ is the successor of $f$
- $E$ is the graph
- $\frac{\partial E}{\partial w} = \frac{\partial E}{\partial h} \cdot \frac{\partial h}{\partial w}$
- $\frac{\partial E}{\partial f} = \frac{\partial E}{\partial h} \cdot \frac{\partial h}{\partial f}$
- $\frac{\partial E}{\partial v} = \frac{\partial E}{\partial f} \cdot \frac{\partial f}{\partial v}$
- when doing backpropagation, the steps will be:
  - $h$ computes $\frac{\partial E}{\partial w}$ and caches $\frac{\partial E}{\partial h}$ and $\frac{\partial h}{\partial w}$
  - $h$ updates $w$ to $w'$
  - $f$ computes $\frac{\partial E}{\partial f}$, which requires $\frac{\partial h}{\partial f}$ from the cache
    - $\frac{\partial h}{\partial f}$ is not yet in the cache, so $h$ has to compute it now
    - $\frac{\partial h}{\partial f}$ is computed based on the new parameter $w'$
      - This is the problem!
    - $\frac{\partial h}{\partial f}$ is corrupted
    - $\frac{\partial h}{\partial f}$ is in the cache now
    - $\frac{\partial E}{\partial f}$ is computed by looking up both $\frac{\partial E}{\partial h}$ and $\frac{\partial h}{\partial f}$ in the cache
    - $\frac{\partial E}{\partial f}$ is in the cache now
  - $f$ computes $\frac{\partial f}{\partial v}$ and caches it
  - $f$ computes $\frac{\partial E}{\partial v}$ with $\frac{\partial f}{\partial v}$ and the corrupted $\frac{\partial h}{\partial f}$
  - $f$ updates $v$ based on the corrupted $\frac{\partial E}{\partial v}$
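To make the steps above concrete, here is a minimal scalar sketch of the ordering problem. It is not the book's code: the functions `buggy_step` and `correct_step`, and the toy model $f(x) = v x$, $h(y) = w y$, $E = h(f(x))$, are my own assumptions for illustration only.

```python
# Hypothetical toy model (not the book's code):
#   f(x) = v * x,  h(y) = w * y,  E = h(f(x))
# so dh/df = w, dh/dw = f(x), df/dv = x, and dE/dh = 1.

def buggy_step(v, w, x, lr):
    """h updates w BEFORE dh/df is computed, so dh/df is taken at w'."""
    y = v * x                         # forward value of f
    dE_dh = 1.0                       # E is the output of h
    dh_dw = y
    w_new = w - lr * dE_dh * dh_dw    # h updates w to w'
    dh_df = w_new                     # computed lazily, already sees w'
    dE_df = dE_dh * dh_df             # corrupted
    df_dv = x
    v_new = v - lr * dE_df * df_dv    # f updates v with the corrupted gradient
    return v_new, w_new


def correct_step(v, w, x, lr):
    """All gradients are computed from the old parameters, then both update."""
    y = v * x
    dE_dh = 1.0
    dh_dw = y
    dh_df = w                         # computed while w is still unchanged
    dE_df = dE_dh * dh_df
    df_dv = x
    w_new = w - lr * dE_dh * dh_dw
    v_new = v - lr * dE_df * df_dv
    return v_new, w_new


print(buggy_step(2.0, 3.0, 1.0, 0.5))    # (1.0, 2.0)
print(correct_step(2.0, 3.0, 1.0, 0.5))  # (0.5, 2.0)
```

With $v = 2$, $w = 3$, $x = 1$ and a learning rate of $0.5$, both versions move $w$ to $2$, but the buggy ordering uses $\frac{\partial h}{\partial f} = w' = 2$ instead of $w = 3$, so $v$ ends up at $1.0$ rather than $0.5$.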
Solutions
I can come up with two solutions:
- compute `local_gradient` ($\frac{\partial h}{\partial f}$) at the beginning of `do_gradient_descent_step`, before the parameters ($w$) are modified
- the successor $h$ distributes `local_gradient` ($\frac{\partial h}{\partial f}$) and `global_gradient` ($\frac{\partial E}{\partial h}$) to $f$ before the parameters ($w$) are modified
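A rough sketch of the first solution, assuming a simplified node with a lazily filled cache. The `Node` class and its `prepare` and `update` methods are hypothetical stand-ins, not the book's actual API; only the names `local_gradient` and `global_gradient` are borrowed from the sample code.

```python
class Node:
    """Hypothetical scalar node computing output = param * input."""

    def __init__(self, param, successor=None):
        self.param = param
        self.successor = successor
        self.cache = {}

    def forward(self, x):
        self.cache["input"] = x
        self.cache["output"] = self.param * x
        return self.cache["output"]

    def local_gradient(self):
        # d(output)/d(input) = param; cached on first request.
        if "local_gradient" not in self.cache:
            self.cache["local_gradient"] = self.param
        return self.cache["local_gradient"]

    def global_gradient(self):
        # dE/d(output); the last node's output is E itself.
        if self.successor is None:
            return 1.0
        s = self.successor
        return s.global_gradient() * s.local_gradient()

    def prepare(self):
        # Fill the cache while every parameter still has its old value.
        self.local_gradient()
        self.cache["global_gradient"] = self.global_gradient()

    def update(self, lr):
        # d(output)/d(param) = input, so dE/d(param) = dE/d(output) * input.
        grad = self.cache["global_gradient"] * self.cache["input"]
        self.param -= lr * grad


h = Node(3.0)                 # plays the role of h with parameter w
f = Node(2.0, successor=h)    # plays the role of f with parameter v
h.forward(f.forward(1.0))

for node in (h, f):           # phase 1: cache all gradients
    node.prepare()
for node in (h, f):           # phase 2: only now modify parameters
    node.update(0.5)

print(h.param, f.param)       # 2.0 0.5
```

Because every gradient is cached in the first phase, the order in which nodes update in the second phase no longer matters. The second solution would get the same effect by having $h$ push both values to $f$ before touching $w$.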
PS
Thank you for your book and the sample code; they helped me understand neural networks more deeply!