
port2077/jax_projects


JAX Learning Project

(work in progress)

This repo consists of smaller projects that I am working on to learn JAX and its ecosystem. The general structure is that each folder is a self-contained project, although some of them may be interconnected.

Currently, this includes:

  • optimizers -- Implementation of the Muon optimizer in Optax. My implementation and understanding follow from these great write-ups by the authors/researchers associated with it: Laker Newhouse, Jeremy Bernstein, and Keller Jordan.

  • nano_gpt -- I implemented Karpathy's character-level nano-gpt codebase (see the video and the original PyTorch code) in JAX without using any of its ecosystem libraries, mostly with plain jnp arrays. This turned out to be a great way to understand many of JAX's core philosophies. Some interesting reads on this: Neel Gupta, Kidger

  • mnist -- A simple CNN on MNIST, implemented in both PyTorch and JAX to compare the same model. The JAX version mostly uses higher-level libraries like Flax.
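To illustrate the "plain jnp arrays" style described for nano_gpt, here is a minimal sketch (not taken from this repo) of a single causal self-attention head, the basic building block of nano-gpt. The parameters are an ordinary dict of arrays, the forward pass is a pure function for one sequence, jax.vmap adds the batch dimension and jax.jit compiles it. The names init_head, head_forward and toy_loss, the shapes, and the initialization scale are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_head(key, n_embd, head_size):
    # Parameters are a plain dict of arrays (a pytree), not a Flax/Equinox module.
    kq, kk, kv = jax.random.split(key, 3)
    scale = n_embd ** -0.5
    return {
        "wq": jax.random.normal(kq, (n_embd, head_size)) * scale,
        "wk": jax.random.normal(kk, (n_embd, head_size)) * scale,
        "wv": jax.random.normal(kv, (n_embd, head_size)) * scale,
    }

def head_forward(params, x):
    # x: (T, n_embd) embeddings for ONE sequence; returns (T, head_size).
    q, k, v = x @ params["wq"], x @ params["wk"], x @ params["wv"]
    T = x.shape[0]
    att = (q @ k.T) / jnp.sqrt(q.shape[-1])           # scaled dot-product scores
    causal = jnp.tril(jnp.ones((T, T))).astype(bool)  # position t sees only positions <= t
    att = jnp.where(causal, att, -jnp.inf)
    att = jax.nn.softmax(att, axis=-1)
    return att @ v

# vmap maps over the batch axis of x (params are shared); jit compiles the result.
batched_head = jax.jit(jax.vmap(head_forward, in_axes=(None, 0)))

def toy_loss(params, x):
    # Placeholder objective, only here to show jax.grad over a dict of params.
    return jnp.mean(jax.vmap(head_forward, in_axes=(None, 0))(params, x) ** 2)

# Usage
params = init_head(jax.random.PRNGKey(0), n_embd=32, head_size=16)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 8, 32))  # (batch, T, n_embd)
out = batched_head(params, x)          # (4, 8, 16)
grads = jax.grad(toy_loss)(params, x)  # same dict structure as params
```

Because the parameters are just a pytree, the same dict can be passed to jax.grad, jax.tree_util.tree_map, or a hand-written optimizer step without any module system.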

Updates

Updated the nano-gpt code to use the Muon optimizer.

  • Adam - lr: 2e-3
  • Muon - lr: 2e-3, polynomial: quintic, beta: 0.95
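The Muon settings above most likely refer to the two main pieces of the optimizer as described in the Keller Jordan and Jeremy Bernstein write-ups: a momentum buffer with coefficient beta, and a few Newton-Schulz iterations using a degree-5 (quintic) polynomial that approximately orthogonalize the momentum of each 2D weight before it is applied. The sketch below is not the Optax implementation in optimizers/; it is a minimal standalone version assuming the coefficients (3.4445, -4.7750, 2.0315) and 5 iterations from Keller Jordan's public reference code, plain heavy-ball momentum instead of Nesterov, and no shape-dependent scaling of the update. The names newton_schulz_quintic and muon_update are hypothetical.

```python
import jax
import jax.numpy as jnp

# Quintic Newton-Schulz coefficients (assumed, from Keller Jordan's reference Muon code).
NS_COEFFS = (3.4445, -4.7750, 2.0315)

def newton_schulz_quintic(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2D matrix (push its singular values toward ~1)."""
    a, b, c = NS_COEFFS
    x = g / (jnp.linalg.norm(g) + eps)  # Frobenius norm => all singular values <= 1
    transposed = g.shape[0] > g.shape[1]
    if transposed:  # iterate on the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * (A @ A)
        x = a * x + B @ x  # applies p(s) = a*s + b*s^3 + c*s^5 to each singular value
    if transposed:
        x = x.T
    return x

def muon_update(grad, momentum, beta=0.95, lr=2e-3):
    """One Muon step for a single 2D weight; returns (param_delta, new_momentum).

    Simplifications (assumptions): heavy-ball momentum instead of Nesterov,
    and no shape-dependent rescaling of the update.
    """
    momentum = beta * momentum + grad
    delta = -lr * newton_schulz_quintic(momentum)
    return delta, momentum

# Usage: one step on a toy 64x32 weight matrix.
key_w, key_g = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (64, 32))
g = jax.random.normal(key_g, (64, 32))
delta, m = jax.jit(muon_update)(g, jnp.zeros_like(w))
w = w + delta
```

The quintic coefficients trade exact orthogonality for speed: the output's singular values end up roughly in the 0.7 to 1.2 range rather than exactly 1. In the reference write-ups Muon is applied only to 2D hidden-layer weights, with embeddings, biases and the output layer left to a standard optimizer such as Adam; this sketch does not handle that routing.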

Training Loss Plot

  • Generated examples:

    AMPSON:
    The senators, father more, stay may than may to good sap my trumpets of heart
    than that undoved that your loves wrongs a gless
    runns of a gommort-chreation and little duke's toon
    quarrel me, my nurse sounds a never not must: but
    He's wish senself to be so; and receive.
    If I know!
    
    SICINUS:
    I would not spoke my lord,
    I sold blood that love comprest not my name;
    And my convey with a much sensely tongue
    When sister more either greats my father death,
    When rue moved thankful envirate more ad
    

nano-gpt training results

Adam - lr: 2e-3
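Since the nano_gpt project avoids ecosystem libraries, here is a minimal sketch of what an Adam step with lr 2e-3 can look like over a parameter pytree in plain jnp. Whether the repo hand-writes Adam or uses Optax is not stated here; b1=0.9, b2=0.999 and eps=1e-8 are the usual Adam defaults, assumed rather than taken from the repo, and adam_init and adam_step are hypothetical names.

```python
import jax
import jax.numpy as jnp

def adam_init(params):
    # First/second moment estimates share the params' pytree structure.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": jnp.zeros((), dtype=jnp.int32)}

def adam_step(params, grads, state, lr=2e-3, b1=0.9, b2=0.999, eps=1e-8):
    t = state["t"] + 1
    m = jax.tree_util.tree_map(lambda mi, g: b1 * mi + (1 - b1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(lambda vi, g: b2 * vi + (1 - b2) * g**2, state["v"], grads)
    # Bias correction for the zero-initialized moment estimates.
    m_corr = 1.0 / (1.0 - b1**t)
    v_corr = 1.0 / (1.0 - b2**t)
    new_params = jax.tree_util.tree_map(
        lambda p, mi, vi: p - lr * (mi * m_corr) / (jnp.sqrt(vi * v_corr) + eps),
        params, m, v,
    )
    return new_params, {"m": m, "v": v, "t": t}

# Usage: one jitted step on a toy parameter pytree.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads = jax.tree_util.tree_map(jnp.ones_like, params)
state = adam_init(params)
params, state = jax.jit(adam_step)(params, grads, state)
```

If Optax is used instead, optax.adam(2e-3) implements the same update with the same default b1, b2 and eps.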

Training Loss Plot

  • Generated examples:

    OFRCUTOLYCUS:LORD:
    I wash lie; if I do caTse heart him in excle.
    
    KING RICHESRY:
    You hath blood men, I let me with his sea.
    
    DUKE VINCENTIO:
    More say, him sight, and him saidst the lean service
    And which in his subject. O this sock'd he say, hence you enter'd
    Him and pity cimpassion-in him in him.
    
    AAUTOLYCUS:
    You have she but them by other black'd oppears to expless' chavise
    And and Georgest in place the sovereign
    Second than partlyso more and more harms: what
    thou he'll'd they crouds, I'll not
    

mnist: PyTorch vs JAX

Time per run in seconds, with both versions on the same GPU:

| Run Number | PyTorch (secs) | JAX (secs) |
|------------|----------------|------------|
| Run 1      | 201.90         | 91.99      |
| Run 2      | 176.23         | 87.59      |
| Run 3      | 170.58         | 87.85      |
| Run 4      | 171.14         | 87.11      |
| Run 5      | 169.91         | 87.39      |
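The run times in the table can be summarized with a few lines of plain Python. The numbers are copied verbatim from the table; this is only arithmetic on the reported runs, not a new measurement.

```python
from statistics import mean, median

# Seconds per run, copied from the table (runs 1-5).
pytorch_secs = [201.90, 176.23, 170.58, 171.14, 169.91]
jax_secs = [91.99, 87.59, 87.85, 87.11, 87.39]

def summarize(name, secs):
    print(f"{name}: mean {mean(secs):.2f}s, median {median(secs):.2f}s")

summarize("PyTorch", pytorch_secs)
summarize("JAX", jax_secs)

# A ratio > 1 means the JAX version finished faster.
print(f"speedup (mean):   {mean(pytorch_secs) / mean(jax_secs):.2f}x")
print(f"speedup (median): {median(pytorch_secs) / median(jax_secs):.2f}x")
```

With these numbers the JAX version comes out about 2.01x faster on the mean and 1.95x on the median; the median is less affected by the slower first PyTorch run.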

About

trying out JAX for ml
