Skip to content

Get JAX default device from config#119

Open
theo-brown wants to merge 1 commit intowilson-labs:mainfrom
theo-brown:jax-default-device
Open

Get JAX default device from config#119
theo-brown wants to merge 1 commit intowilson-labs:mainfrom
theo-brown:jax-default-device

Conversation

@theo-brown
Copy link

Previously, if on a multi-GPU system, get_default_device() would always trigger AssertionError: array found on more than one device".
Instead, the default device should be retrieved from the jax config.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant