Hi, I like your blog post and code, it helped me get started with some things on this subject in JAX, so thanks!
I have a comment about the fisher_vp function. It seems to me that two possible two-index tensors one could construct are
- $F(w) = E_{X,Y}[\nabla L(X,Y,w)\, \nabla^T L(X,Y,w)]$
- $F(w) = \nabla E_{X,Y}[L(X,Y,w)]\, \nabla^T E_{X,Y}[L(X,Y,w)]$
and in the way fisher_vp is used in the empirical or true Fisher step:

naturalgradient/natural_grad.py, lines 148 to 151 (commit d7d0a12):

```python
loss, grads = jax.value_and_grad(mean_cross_entropy)(params, batch)
f = lambda w: mean_cross_entropy(w, batch)
fvp = lambda v: fisher_vp(f, params, v)
ngrad, _ = jax.scipy.sparse.linalg.cg(fvp, grads, maxiter=10)  # approx solve
```

naturalgradient/natural_grad.py, line 222 (commit d7d0a12):

```python
ngrad, _ = jax.scipy.sparse.linalg.cg(fvp, grads, maxiter=10)  # approx solve
```
the second of these is being used: the gradient of the loss is averaged over the batch before being passed to fisher_vp. However, the Fisher matrix requires the averaging to occur over the two-index tensors, i.e. over the per-example outer products $\nabla L\, \nabla^T L$, not over the gradients themselves.
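To make the distinction concrete, here is a minimal sketch of the two quantities in JAX. It uses a toy least-squares loss (`example_loss`) and hypothetical helper names of my own (`empirical_fisher_vp`, `mean_grad_fisher_vp`), not the code from the post; the parameters are a flat array to keep the example short:

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example loss (a toy least-squares model, not the
# post's mean_cross_entropy): w is the parameter vector, (x, y) one example.
def example_loss(w, x, y):
    return 0.5 * (jnp.dot(x, w) - y) ** 2

def empirical_fisher_vp(w, xs, ys, v):
    """F v with F = mean_i g_i g_i^T (the first tensor above):
    average the per-example outer products applied to v."""
    def single(x, y):
        g = jax.grad(example_loss)(w, x, y)  # per-example gradient g_i
        return g * jnp.dot(g, v)             # g_i (g_i^T v)
    return jax.vmap(single)(xs, ys).mean(axis=0)

def mean_grad_fisher_vp(w, xs, ys, v):
    """What averaging the gradient first gives (the second tensor above):
    (mean_i g_i)(mean_i g_i)^T v, a rank-one matrix applied to v."""
    mean_loss = lambda w: jax.vmap(example_loss, (None, 0, 0))(w, xs, ys).mean()
    g_bar = jax.grad(mean_loss)(w)
    return g_bar * jnp.dot(g_bar, v)

w = jnp.array([1.0, -2.0])
xs = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
ys = jnp.zeros(8)
v = jnp.array([0.5, 1.5])
# The two products generally disagree: averaging gradients before the
# outer product loses the per-example second-moment information.
```

Both versions are matrix-free, so either could be handed to `jax.scipy.sparse.linalg.cg` as the operator; the point is that only the first matches the empirical Fisher.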