Skip to content

Conversation

@xibinliu
Copy link
Collaborator

@xibinliu xibinliu commented Jan 13, 2026

Description

Fix rngs used for training.
Also fix the unit tests to make them run in different tpu VMs

Tests

Run a training successfully.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/maxtext_utils.py 0.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

Also fix the unit tests to make them run in different tpu VMs
"""
# Create a mesh shape for a 5D mesh.
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
devices = np.array(jax.devices()[:4]).reshape((4, 1, 1, 1, 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you met issue when testing on other tpu topology? Shall we instead do something like

num_devices = jax.device_count()
devices = np.array(jax.devices()).reshape((num_devices, 1, 1, 1, 1))

per_device_batch_size_increment=1.0,
global_rampup_samples=60,
# global_rampup_samples: (rampup increment number) * (Samples for initial 5 steps)
global_rampup_samples=3 * (1 * jax.device_count() * 5),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aireenmei could you help take a quick look?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants