Add MPS support#190
Conversation
|
@12v have you done testing on this branch on an Apple device? Does it speed this up considerably? This is awesome, because inference is insanely slow on mac (I am using an m4 Max and a few seconds of audio takes minutes.) I tried solving this myself, but continued to run into issues with tokenization. |
|
Actually I just cloned your branch and tested it. Still no support for the Hybrid, but that's ok, the transformer is massively faster on MPS. |
|
How much is the quality affected? Is there a way to get the behavior/quality to match the CPU case? What exactly is causing the quality loss? |
|
@ReadyPlayerEmma to be fair (and complete) the quality loss happens on CUDA too. My guess would simply be floating point precision is less precise. |
|
I just saw your PR. I made mine without being aware of your work. In model.py, you can compile with the backend "aot_eager". It is also indicated on the link you shared: LINK
|
Hello, this is an attempt to add MPS support.
There are two issues preventing MPS from working with the transformer backend:
torch.compile doesn't support MPS (see: here)
The solution for this issue is straightforward, just adding another condition check before using torch.compile.
Grouped Query Attention (GQA) only works with CUDA (see: here)
This is more complex. Falling back to Multi-Headed Attention (MHA) on MPS requires the same number of heads for KV as for Q, but the pre-trained weights expect a smaller number of KV heads than Q heads. My current solution for this is to duplicate the KV heads and weights to match the number of Q heads.
Aside from the extra code, the main downside of this approach is that the weights saved from a model on MPS can't be loaded again (whether on MPS or any other backend). Possible paths forward:
Additionally, an alternative to this approach of transformer the weights within the model is to instead transform the pre-trained weights outside of the model before loading.
Notes
The recording generated on MPS doesn't sound as good as the recording generated on CPU.