(Work in progress)
This repo consists of smaller projects that I am working on in order to learn JAX and its ecosystem. In general, each folder is a self-contained project, though some of them may be interconnected.
Currently, this includes:
- `optimizers` -- Implementation of the Muon optimizer in Optax. My implementation and understanding follow from these great write-ups by the authors/researchers associated with it: Laker Newhouse, Jeremy Bernstein, and Keller Jordan.
- `nano_gpt` -- I implemented Karpathy's character-level nano-gpt codebase (see the video and the original PyTorch code) in JAX without using any of the ecosystem's libraries, mostly with plain `jnp` arrays. This turned out to be a great way to understand many of JAX's core philosophies. Some interesting reads here: Neel Gupta, Kidger.
- `mnist` -- A simple CNN on MNIST that compares the same model in PyTorch and JAX; the JAX version mostly uses higher-level libraries like Flax.
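As an illustration of the "plain `jnp` arrays" style used in `nano_gpt` (a sketch, not the repo's actual code), a single causal self-attention head needs nothing beyond `jax.numpy` and `jax.nn.softmax`. The function name, parameter names, and sizes below are hypothetical.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention using only jax.numpy.

    x: (T, C) sequence of T token embeddings of width C.
    w_q, w_k, w_v: (C, H) projection matrices for a head of size H.
    Returns: (T, H) attended values.
    """
    q = x @ w_q  # (T, H) queries
    k = x @ w_k  # (T, H) keys
    v = x @ w_v  # (T, H) values
    T, H = q.shape
    # Scaled dot-product scores between every pair of positions.
    scores = (q @ k.T) / jnp.sqrt(H)  # (T, T)
    # Lower-triangular mask: position t may only attend to positions <= t.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v


# Hypothetical sizes: context length 8, embedding width 32, head size 16.
key = jax.random.PRNGKey(0)
k_x, k_q, k_k, k_v = jax.random.split(key, 4)
T, C, H = 8, 32, 16
x = jax.random.normal(k_x, (T, C))
params = [jax.random.normal(k, (C, H)) * 0.02 for k in (k_q, k_k, k_v)]
out = jax.jit(causal_self_attention)(x, *params)
print(out.shape)  # (8, 16)
```

Because the function is pure (arrays in, arrays out, no hidden state), `jax.jit` can compile it directly, which is one of the core ideas this kind of from-scratch exercise makes concrete.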
Adam - lr: 2e-3; Muon - lr: 2e-3, polynomial: quintic, beta: 0.95
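In the Muon settings, "quintic" refers to the degree-5 Newton-Schulz iteration that approximately orthogonalizes the momentum matrix, and beta is the momentum coefficient. Below is a minimal sketch of one Muon update for a single 2-D weight in plain JAX (not Optax, and not the repo's implementation). Assumptions: the coefficients are the quintic ones from Keller Jordan's write-up, plain heavy-ball momentum is used (reference implementations often use a Nesterov-style variant), and any shape-dependent learning-rate scaling is omitted.

```python
import jax
import jax.numpy as jnp

# Quintic Newton-Schulz coefficients (assumed: the values from Keller Jordan's Muon write-up).
A, B, C = 3.4445, -4.7750, 2.0315


def newton_schulz_quintic(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2-D matrix g (push its singular values toward ~1)."""
    x = g / (jnp.linalg.norm(g) + eps)  # divide by Frobenius norm -> spectral norm <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # use the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        # x <- A*x + B*(x x^T) x + C*(x x^T)^2 x, applied to every singular value at once
        x = A * x + (B * gram + C * gram @ gram) @ x
    return x.T if transposed else x


def muon_step(param, grad, momentum, lr=2e-3, beta=0.95):
    """One Muon-style update for a single 2-D weight (heavy-ball momentum is an assumption)."""
    momentum = beta * momentum + grad
    update = newton_schulz_quintic(momentum)
    return param - lr * update, momentum


# Usage with hypothetical shapes: a 64x32 weight and a random "gradient".
w = jax.random.normal(jax.random.PRNGKey(0), (64, 32)) * 0.02
grad = jax.random.normal(jax.random.PRNGKey(1), (64, 32))
w, m = muon_step(w, grad, jnp.zeros_like(w))
print(jnp.linalg.svd(newton_schulz_quintic(grad), compute_uv=False)[:4])
```

The quintic coefficients are tuned for speed rather than exact convergence, so the resulting singular values land in a band around 1 instead of exactly at 1.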
Generated examples:
AMPSON: The senators, father more, stay may than may to good sap my trumpets of heart than that undoved that your loves wrongs a gless runns of a gommort-chreation and little duke's toon quarrel me, my nurse sounds a never not must: but He's wish senself to be so; and receive. If I know! SICINUS: I would not spoke my lord, I sold blood that love comprest not my name; And my convey with a much sensely tongue When sister more either greats my father death, When rue moved thankful envirate more ad
Adam - lr: 2e-3
Generated examples:
OFRCUTOLYCUS:LORD: I wash lie; if I do caTse heart him in excle. KING RICHESRY: You hath blood men, I let me with his sea. DUKE VINCENTIO: More say, him sight, and him saidst the lean service And which in his subject. O this sock'd he say, hence you enter'd Him and pity cimpassion-in him in him. AAUTOLYCUS: You have she but them by other black'd oppears to expless' chavise And and Georgest in place the sovereign Second than partlyso more and more harms: what thou he'll'd they crouds, I'll not
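Samples like the ones above come from a character-level model: text is mapped to integer ids, and generation repeatedly draws the next character from the model's output distribution. A minimal sketch of that loop follows, assuming a toy training string and a hypothetical `logits_fn`; here a random bigram table stands in for the trained GPT, so its output is noise.

```python
import jax
import jax.numpy as jnp

# Toy corpus (assumption); the real model is trained on a much larger text.
text = "First Citizen:\nBefore we proceed any further, hear me speak."
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
itos = {i: ch for ch, i in stoi.items()}      # id -> char
vocab_size = len(chars)


def encode(s):
    return [stoi[c] for c in s]


def decode(ids):
    return "".join(itos[int(i)] for i in ids)


def generate(logits_fn, params, key, start_id, num_tokens):
    """Autoregressively sample num_tokens ids, one character at a time."""
    ids = [start_id]
    for _ in range(num_tokens):
        key, subkey = jax.random.split(key)  # fresh randomness for every step
        logits = logits_fn(params, jnp.array(ids))  # (vocab_size,) next-char logits
        next_id = jax.random.categorical(subkey, logits)
        ids.append(int(next_id))
    return ids


# Hypothetical stand-in model: a bigram table that only looks at the last character.
def bigram_logits(table, ids):
    return table[ids[-1]]


table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, vocab_size))
sample = generate(bigram_logits, table, jax.random.PRNGKey(1), stoi["F"], 50)
print(decode(sample))
```

Splitting the key at every step is what keeps sampling reproducible in JAX: the same starting key always produces the same text. The Python loop is simple but slow; a fixed-size context with `jax.lax.scan` would let the whole loop be compiled.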
PyTorch vs. JAX run times (on the same GPU):
| Run   | PyTorch (secs) | JAX (secs) |
|-------|----------------|------------|
| Run 1 | 201.90         | 91.99      |
| Run 2 | 176.23         | 87.59      |
| Run 3 | 170.58         | 87.85      |
| Run 4 | 171.14         | 87.11      |
| Run 5 | 169.91         | 87.39      |
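When timing JAX code, two details matter: a `jax.jit`-compiled function is compiled on its first call, and JAX dispatches work asynchronously, so the timer must wait for the result. How the numbers above were measured isn't stated here; the sketch below shows one way to time a JAX function with both effects handled, using a hypothetical workload (`matmul_chain`) and helper name (`time_fn`).

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, repeats=5):
    """Return per-call wall-clock seconds for jit(fn), excluding compilation."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # first call compiles; not counted
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jitted(*args).block_until_ready()  # wait for the async computation to finish
        times.append(time.perf_counter() - start)
    return times


# Hypothetical workload: a short chain of matmuls with a nonlinearity.
def matmul_chain(x):
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (512, 512)) / 512**0.5
print([f"{t:.4f}" for t in time_fn(matmul_chain, x)])
```

This helper assumes `fn` returns a single array; for functions returning a tuple or dict, each leaf would need to be waited on before stopping the timer.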

