Debug MPS #215

Closed
Aedelon wants to merge 1 commit into Zyphra:main from Aedelon:debug_mps

Aedelon commented Apr 14, 2025

EDIT: THIS IS A BAD WORKAROUND. SEE SOLUTION AND DISCUSSION HERE

Debug MPS Compatibility in Two Locations

This PR includes two fixes to ensure compatibility with the MPS backend (Apple Silicon) in PyTorch:

1. Fix in backbone/_torch.py (Lines 137–138):

if DEFAULT_DEVICE == torch.device("mps"):
    q, k, v = q.cpu(), k.cpu(), v.cpu()

The F.scaled_dot_product_attention function is still not supported on MPS. This patch forces the tensors to the CPU in that case, avoiding runtime errors.

Switching from MPS to CPU and back to MPS significantly reduces training speed.
This is just a temporary workaround to uncover other issues. Please do not consider this PR for merging.
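For illustration only, the CPU round trip described above can be sketched as a small wrapper. `call_with_cpu_fallback` is a hypothetical helper, not code from this PR. It assumes only that its inputs expose `.device.type`, `.cpu()` and `.to(device)`, as `torch.Tensor` does. Unlike the two-line patch, it also moves the result back to the original device, which is where the extra transfer cost comes from.

```python
def call_with_cpu_fallback(fn, *tensors, fallback_device_type="mps", **kwargs):
    """Run fn on CPU copies when the inputs live on the fallback device.

    Hypothetical helper (not part of this PR). Works with any objects that
    expose `.device.type`, `.cpu()` and `.to(device)`, e.g. torch.Tensor.
    """
    device = tensors[0].device
    if device.type != fallback_device_type:
        # Inputs are not on the problematic device: call fn unchanged.
        return fn(*tensors, **kwargs)
    # Copy every input to the CPU, run fn there, then copy the result back.
    out = fn(*(t.cpu() for t in tensors), **kwargs)
    return out.to(device)
```

With PyTorch this would be called as `call_with_cpu_fallback(F.scaled_dot_product_attention, q, k, v)`. Each call pays for three host copies and one copy back, which is consistent with the slowdown noted above.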

2. Fix in model.py (Lines 238–243):

if DEFAULT_DEVICE == torch.device("mps"):
    decode_one_token = torch.compile(
        decode_one_token, dynamic=True, backend="aot_eager", disable=cg or disable_torch_compile
    )
else:
    decode_one_token = torch.compile(decode_one_token, dynamic=True, disable=cg or disable_torch_compile)

The default backend (inductor) used by torch.compile is not supported on MPS. This change conditionally switches the backend to aot_eager when the selected device is MPS.
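The if/else above duplicates the `torch.compile` call. One possible way to keep a single call site is to build the keyword arguments first. `compile_kwargs` is a hypothetical helper name, not part of this PR. The sketch relies only on the `dynamic`, `backend` and `disable` keyword arguments that `torch.compile` accepts.

```python
def compile_kwargs(device_type, disable=False):
    """Build keyword arguments for torch.compile based on the device type.

    Hypothetical helper (not part of this PR). On "mps" it selects the
    aot_eager backend; otherwise it leaves the backend unset so that
    torch.compile falls back to its default (inductor).
    """
    kwargs = {"dynamic": True, "disable": disable}
    if device_type == "mps":
        kwargs["backend"] = "aot_eager"
    return kwargs
```

Assuming `DEFAULT_DEVICE` is a `torch.device`, the call in model.py would then read `decode_one_token = torch.compile(decode_one_token, **compile_kwargs(DEFAULT_DEVICE.type, disable=cg or disable_torch_compile))`. Comparing on `.type` also matches an indexed device such as `mps:0`, which the equality check against `torch.device("mps")` does not.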

These adjustments allow the model to run properly on Apple devices using the MPS backend.


Aedelon commented Apr 15, 2025

==> #190

Aedelon closed this Apr 15, 2025
Aedelon deleted the debug_mps branch April 15, 2025 00:08