Add cache-optimized embedding ops (~12x lookup speedup) #39
Open
dev-erik wants to merge 1 commit into maderix:main from
Summary
Drop-in replacement for `embed_lookup` and `embed_backward` that eliminates stride-`seq` cache misses by using a contiguous `memcpy` gather + `vDSP_mtrans` transpose.

Before / After
Benchmarked against upstream `stories_cpu_ops.h` on an Apple M4 Max, compiled with `clang -O2`. Stories110M config: dim=768, seq=256, vocab=32000. 500 iterations, 10 warmup. Results for `embed_lookup` and `embed_backward` were consistent across 3 consecutive runs: 11.5x-12.0x speedup for lookup, 1.1x for backward.
Why it's faster
The original `embed_lookup` writes `x[d*seq + t]` in a double loop -- every write strides by `seq` floats (1 KB at seq=256), causing an L1 cache miss per element. The optimized version:

- gathers each token's embedding row into a temp buffer with a contiguous `memcpy`
- transposes the whole buffer to the expected layout with a single `vDSP_mtrans` call

Same approach for backward: transpose `dx` first, then scatter-add contiguous rows with `vDSP_vadd`.

Correctness
Bit-exact match (max |diff| = 0.00e+00) with upstream functions. Bounds checks preserved.
Usage
Requires a caller-provided scratch buffer of `seq * dim` floats. No new dependencies (uses Accelerate `vDSP_mtrans` / `vDSP_vadd`, already linked).