Forward mode "adjoint" #537
Conversation
patrick-kidger
left a comment
There was a problem hiding this comment.
Nita aside this LGTM! Once you fix things up I'll merge this into a new dev branch.
diffrax/_adjoint.py
Outdated
|
|
||
| class ForwardMode(AbstractAdjoint): | ||
| """Differentiate through a differential equation solve during the forward pass. | ||
| (So it is not really an adjoint - it is a different way of quantifying the |
There was a problem hiding this comment.
I might specify what an adjoint is here, to make this clear.
There was a problem hiding this comment.
Gave it a shot!
… but for diffrax)
…e forward-mode autodiff is mentioned in the other adjoints
f75dd51 to
e455522
Compare
|
Completed the above fixes :) I also added references to While going though the errata raised by the other adjoints I was also wondering if there is a case |
|
LGTM! And merged :) As for |
* add .venv * add code, tests and documentation for ForwardAdjoint * make version of mkdocs-autorefs explicit (patrick-kidger/optimistix#91, but for diffrax) * rename, add documentation, explicate lack of test covarage for unit-input case. * rename import of ForwardMode * fix duplicate * Make docstring of ForwardMode more precise, add references to it where forward-mode autodiff is mentioned in the other adjoints --------- Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>


Here you go! This is the pragmatic solution, without support or test coverage for integer inputs and only a small comment explicating that forward mode is not really an adjoint, even though its diffrax interface is that of an adjoint.
Changes with respect to the last PR:
ForwardModeeverywhereAbstractAdjointtest_adjoint.pyand explain that since JAX does not offer this option, we're not writing our own workaround to test it eitherOn the last point: if I understood this correctly, then supporting this would entail writing a gradient-computation directly from a JVP with custom "unit pytrees". This is somewhat annoying for mixed array and non-array types.
I'm happy to try again if computing gradients with respect to integer elements of a PyTree is an expected use case (maybe arising from composed/layered transformations of a solve) that requires test coverage.
Earlier comments here.
(This is now rebased on main.)