diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 000000000000..992d11128793 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,48 @@ +name: Report a bug +description: Report triton failing to compile a kernel, or giving incorrect results +labels: ["bug"] + +body: +- type: markdown + attributes: + value: | + #### Disclaimer + The core triton team is small and has very limited capacity. We may not have time to look into your report. + For the best results, please: + - Avoid submitting duplicates. Search through [the existing and past issues](https://github.com/triton-lang/triton/issues?q=is%3Aissue+sort%3Acreated-desc+) first to see if it's been reported previously. + - Check if the issue persists with a build from the latest source. + - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion. + - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions. +- type: textarea + attributes: + label: Describe the bug + description: | + Please provide a clear and concise description of what the bug is. + + If relevant, add a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the bug. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did, so include both the kernel and launching code as well as any relevant imports. + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + + Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + placeholder: | + A clear and concise description of what the bug is. + + ```python + # Sample code to reproduce the problem + ``` + + ``` + The error message you got, with the full traceback. + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment details + description: | + Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using. + placeholder: | + Triton: ... + GPU: ... + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000000..9c1ad58162c7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: Community help + url: https://discord.gg/gpumode + about: GPU-mode discord community has a triton channel which is a great resource for help writing/learning triton diff --git a/.github/ISSUE_TEMPLATE/performance.yml b/.github/ISSUE_TEMPLATE/performance.yml new file mode 100644 index 000000000000..33dddabc41ae --- /dev/null +++ b/.github/ISSUE_TEMPLATE/performance.yml @@ -0,0 +1,44 @@ +name: Report a performance issue +description: Report cases where triton is generating sub-optimal (but functionally correct) PTX/LLVM IR +labels: ["performance"] + +body: +- type: markdown + attributes: + value: | + #### Disclaimer + The core triton team is small and has very limited capacity. We may not have time to look into your report. + For the best results, please: + - Avoid submitting duplicates. Search through [the existing and past issues](https://github.com/triton-lang/triton/issues?q=is%3Aissue+sort%3Acreated-desc+) first to see if it's been reported previously. + - Check if the issue persists with a build from the latest source. + - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion. + - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions. +- type: textarea + attributes: + label: Describe the issue + description: | + Please provide a clear and concise description of the issue. + + Include a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the issue. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did. + + A reproducer could be a python program that runs a triton kernel and prints out the relevant suboptimal IR, or an IR file with an accompanying triton-opt command. + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + placeholder: | + A clear and concise description of the issue. + + ```python + # Sample code to reproduce the problem + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment details + description: | + Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using. + placeholder: | + Triton: ... + GPU: ... + validations: + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 63d98f2fc084..71ac17d036d9 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,3 +1,4 @@ + +# New contributor declaration - [ ] I am not making a trivial change, such as fixing a typo in a comment. - [ ] I have written a PR description following these diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml new file mode 100644 index 000000000000..805c6b8dc7b0 --- /dev/null +++ b/.github/workflows/build-test.yml @@ -0,0 +1,158 @@ +name: Build and test +run-name: ${{ inputs.run_name }} + +on: + workflow_dispatch: + pull_request: + branches: + - main + # You can name your branch dev-foo to get CI runs. + - 'dev-**' + push: + branches: + - main + +jobs: + pre-commit: + name: Pre-commit checks + runs-on: + - glados + - intel + - x86 + steps: + - name: Print inputs + run: | + echo "${{ toJSON(github.event.inputs) }}" + echo INSTALL_IPEX=${{ env.INSTALL_IPEX }} + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Run pre-commit checks + run: | + pip install --upgrade pre-commit + + # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed + python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true + # If first run of yapf worked and made changes reset the tree to the original state + git reset --hard + + python3 -m pre_commit run --show-diff-on-failure --color=always --all-files --verbose + + build-test: + name: Build and test on ${{ matrix.config.runner }} + runs-on: ${{ matrix.config.runs_on }} + strategy: + matrix: + python: ['3.11'] + config: + - {runner: 'Ubuntu Intel x86', runs_on: ['glados', 'intel', 'x86'], target-os: 'ubuntu', arch: 'x86'} + - {runner: 'MacOS-latest ARM64', runs_on: ['macos-latest'], target-os: 'macos', arch: 'arm64'} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install pip and apt dependencies + env: + RUNNER_TARGET_OS: ${{ matrix.config.target-os }} + run: | + echo "RUNNER_TARGET_OS: ${RUNNER_TARGET_OS}" + python3 -m pip install --upgrade pip + python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit pybind11 + if [[ "${RUNNER_TARGET_OS}" == "ubuntu" ]]; then + sudo apt-get update + sudo apt-get install -y zlib1g-dev g++ + fi + pip install torch==2.1.2 + + - name: Install Triton + run: | + echo "PATH is '$PATH'" + cd python + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run python unit tests for MacOS Arm64 + if: matrix.config.target-os == 'macos' + run: | + export CC=$(which clang) + export TRITON_DISABLE_OPENMP=1 # temporary + export TRITON_CPU_BACKEND=1 + + # Document some versions/flags + echo "xcode-select:"; xcode-select -p + echo "CC: ${CC}" + clang --version + echo "TRITON_DISABLE_OPENMP=${TRITON_DISABLE_OPENMP}" + echo "TRITON_CPU_BACKEND=${TRITON_CPU_BACKEND}" + + # Skip bfloat16 tests for now + # We are generating bfcvt for bfloat16 tests when converting to fp32. + # This is only for Clang15, works OK for Clang16 + # TODO - fix this using driver flags. + python -m pytest -s -n 32 --device cpu \ + python/test/unit/language/test_core.py -m cpu -k "not bfloat16" + python -m pytest -s -n 32 --device cpu \ + python/test/unit/cpu/test_math.py \ + python/test/unit/cpu/test_opt.py \ + python/test/unit/language/test_annotations.py \ + python/test/unit/language/test_block_pointer.py \ + python/test/unit/language/test_compile_errors.py \ + python/test/unit/language/test_conversions.py \ + python/test/unit/language/test_decorator.py \ + python/test/unit/language/test_pipeliner.py \ + python/test/unit/language/test_random.py \ + python/test/unit/language/test_standard.py \ + python/test/unit/runtime/test_autotuner.py \ + python/test/unit/runtime/test_bindings.py \ + python/test/unit/runtime/test_cache.py \ + python/test/unit/runtime/test_driver.py \ + python/test/unit/runtime/test_jit.py \ + python/test/unit/runtime/test_launch.py \ + python/test/unit/runtime/test_subproc.py \ + python/test/unit/test_debug_dump.py \ + -k "not bfloat16" + + - name: Run python unit tests for Intel + if: matrix.config.target-os == 'ubuntu' + run: | + python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu + python -m pytest -s -n 32 --device cpu \ + python/test/unit/cpu/test_math.py \ + python/test/unit/cpu/test_opt.py \ + python/test/unit/language/test_annotations.py \ + python/test/unit/language/test_block_pointer.py \ + python/test/unit/language/test_compile_errors.py \ + python/test/unit/language/test_conversions.py \ + python/test/unit/language/test_decorator.py \ + python/test/unit/language/test_pipeliner.py \ + python/test/unit/language/test_random.py \ + python/test/unit/language/test_standard.py \ + python/test/unit/runtime/test_autotuner.py \ + python/test/unit/runtime/test_bindings.py \ + python/test/unit/runtime/test_cache.py \ + python/test/unit/runtime/test_driver.py \ + python/test/unit/runtime/test_jit.py \ + python/test/unit/runtime/test_launch.py \ + python/test/unit/runtime/test_subproc.py \ + python/test/unit/test_debug_dump.py + + - name: Run lit tests + run: | + cd python + LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" + if [ ! -d "${LIT_TEST_DIR}" ]; then + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 + fi + lit -v "${LIT_TEST_DIR}/TritonCPU" diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7ef502ad25dd..6eb8a614930d 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -9,22 +9,27 @@ name: Integration Tests on: workflow_dispatch: - pull_request: - branches-ignore: ['llvm-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] +# Disabled automatic triggers because tests in this workflow fail to run. +# pull_request: +# # You can name your branch dev-foo to get CI runs. +# branches-ignore: ['llvm-**'] +# merge_group: +# branches: [main, 'dev-**'] +# types: [checks_requested] +# push: +# branches: [main] + concurrency: group: ${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: runs-on: ubuntu-latest @@ -39,6 +44,11 @@ jobs: if: github.event_name == 'pull_request' run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 @@ -141,10 +151,6 @@ jobs: - name: Check pre-commit run: | python3 -m pip install --upgrade pre-commit - # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed - python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true - # If first run of yapf worked and made changes reset the tree to the original state - git reset --hard python3 -m pre_commit run --all-files --verbose - name: Print diff of changes if pre-commit failed if: failure() @@ -158,6 +164,8 @@ jobs: strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -203,37 +211,45 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Install pip dependencies run: | python3 -m pip install --upgrade pip - python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit + python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -279,8 +295,16 @@ jobs: ctest -j32 - name: Run Proton tests run: | - cd third_party/proton - python3 -m pytest -s test + cd third_party/proton/test + python3 -m pytest -s . + cd .. + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -290,28 +314,23 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} Integration-Tests-AMD: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} timeout-minutes: 30 + env: + RUNNER_TYPE: ${{ matrix.runner[1] }} strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}}) container: - image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4 + image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root steps: - name: Checkout @@ -358,40 +377,46 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - - name: Update PATH - run: | - echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH - - name: Install pip dependencies + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache + - name: Update compiler to clang run: | - python3 -m pip install --upgrade pip - python3 -m pip install lit + export CC=/usr/bin/clang + export CXX=/usr/bin/clang++ - name: Install Triton id: amd-install-triton run: | echo "PATH is '$PATH'" - pip uninstall -y triton + pip uninstall -y triton pytorch-triton-rocm cd python + ccache --zero-stats pip install -v -e '.[tests]' - name: Clean up after an unsuccessful build if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} run: | rm -rf ~/.triton + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -407,6 +432,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ @@ -425,13 +451,21 @@ jobs: python3 -m pytest -s -n 8 ./test_cast_matmul.py - name: Run Proton tests run: | - cd third_party/proton - python3 -m pytest -s test + cd third_party/proton/test + python3 -m pytest -s . + cd .. - name: Run C++ unittests run: | cd python cd "build/$(ls build | grep -i cmake)" ctest -j32 + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -441,17 +475,10 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - name: Clean up caches run: | rm -rf ~/.triton/cache @@ -459,10 +486,12 @@ jobs: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 40 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -471,7 +500,7 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm@19 lld + brew install ccache llvm@19 lld coreutils - name: Compute cache keys id: cache-key run: | @@ -512,22 +541,28 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -537,10 +572,9 @@ jobs: python3 -m venv ~/.venv source ~/.venv/bin/activate python3 -m pip install --upgrade pip - python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11 + python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11 - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_O1: "true" # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories @@ -549,7 +583,17 @@ jobs: source ~/.venv/bin/activate echo "PATH is '$PATH'" cd python - python3 -m pip install --no-build-isolation . + ccache --zero-stats + python3 -m pip install -v --no-build-isolation . + - name: CCache Stats + run: ccache --print-stats + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -559,14 +603,7 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index d84ac6f33466..5341b7d1028d 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -8,13 +8,15 @@ name: Integration Tests on: workflow_dispatch: - pull_request: - branches-ignore: ['llvm-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] +# Disabled automatic triggers because tests in this workflow fail to run. +# pull_request: +# # You can name your branch dev-foo to get CI runs. +# branches-ignore: ['llvm-**'] +# merge_group: +# branches: [main, 'dev-**'] +# types: [checks_requested] +# push: +# branches: [main] concurrency: group: ${{ github.ref }} @@ -23,10 +25,12 @@ concurrency: permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: @@ -43,6 +47,12 @@ jobs: run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV + - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 @@ -155,10 +165,6 @@ jobs: - name: Check pre-commit run: | python3 -m pip install --upgrade pre-commit - # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed - python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true - # If first run of yapf worked and made changes reset the tree to the original state - git reset --hard python3 -m pre_commit run --all-files --verbose - name: Print diff of changes if pre-commit failed @@ -178,6 +184,9 @@ jobs: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} + steps: - name: Checkout uses: actions/checkout@v4 @@ -229,24 +238,30 @@ jobs: # files over time. - &restore-build-artifacts-step name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - &inspect-cache-directory-step - name: Inspect cache directory + - &inspect-cache-directories-step + name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton + du -h -d 1 ~/.triton + + mkdir -p ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | @@ -255,16 +270,20 @@ jobs: - name: Install pip dependencies run: | python3 -m pip install --upgrade pip - python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit + python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + + - &print-ccache-stats + name: CCache Stats + run: ccache --print-stats - &run-lit-tests-step name: Run lit tests @@ -319,8 +338,11 @@ jobs: - &run-proton-tests-step name: Run Proton tests run: | - cd third_party/proton - python3 -m pytest -s test + cd third_party/proton/test + python3 -m pytest -s . + cd .. + + - *inspect-cache-directories-step # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. @@ -332,19 +354,10 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - - &inspect-cache-directories-step - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} Integration-Tests-AMD: needs: Runner-Preparation @@ -353,6 +366,9 @@ jobs: runs-on: ${{ matrix.runner }} timeout-minutes: 30 + env: + RUNNER_TYPE: ${{ matrix.runner[1] }} + strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} @@ -360,7 +376,7 @@ jobs: name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}}) container: - image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4 + image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root steps: @@ -372,23 +388,20 @@ jobs: - *compute-cache-keys-step - *cache-build-dependencies-step - *restore-build-artifacts-step - - *inspect-cache-directory-step - - - name: Update PATH - run: | - echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH + - *inspect-cache-directories-step - - name: Install pip dependencies + - name: Update compiler to clang run: | - python3 -m pip install --upgrade pip - python3 -m pip install lit + export CC=/usr/bin/clang + export CXX=/usr/bin/clang++ - name: Install Triton id: amd-install-triton run: | echo "PATH is '$PATH'" - pip uninstall -y triton + pip uninstall -y triton pytorch-triton-rocm cd python + ccache --zero-stats pip install -v -e '.[tests]' - name: Clean up after an unsuccessful build @@ -396,6 +409,7 @@ jobs: run: | rm -rf ~/.triton + - *print-ccache-stats - *run-lit-tests-step - name: Run python tests on HIP @@ -405,6 +419,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py cd python/test/unit pytest --capture=tee-sys -rfs -n 16 language runtime \ --ignore=language/test_line_info.py \ @@ -425,8 +440,8 @@ jobs: - *run-proton-tests-step - *run-cpp-unittests-step - - *save-build-artifacts-step - *inspect-cache-directories-step + - *save-build-artifacts-step - name: Clean up caches run: | @@ -436,10 +451,14 @@ jobs: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 40 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} + steps: - name: Checkout uses: actions/checkout@v4 @@ -448,12 +467,12 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm@19 lld + brew install ccache llvm@19 lld coreutils - *compute-cache-keys-step - *cache-build-dependencies-step - *restore-build-artifacts-step - - *inspect-cache-directory-step + - *inspect-cache-directories-step - name: Update PATH run: | @@ -464,10 +483,9 @@ jobs: python3 -m venv ~/.venv source ~/.venv/bin/activate python3 -m pip install --upgrade pip - python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11 + python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11 - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_O1: "true" # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories @@ -476,7 +494,9 @@ jobs: source ~/.venv/bin/activate echo "PATH is '$PATH'" cd python - python3 -m pip install --no-build-isolation . + ccache --zero-stats + python3 -m pip install -v --no-build-isolation . - - *save-build-artifacts-step + - *print-ccache-stats - *inspect-cache-directories-step + - *save-build-artifacts-step diff --git a/.github/workflows/llvm-build.yml b/.github/workflows/llvm-build.yml index 9cca050e4904..53755ae8ede7 100644 --- a/.github/workflows/llvm-build.yml +++ b/.github/workflows/llvm-build.yml @@ -28,11 +28,11 @@ jobs: config: - {runner: 'Ubuntu 20.04', runs_on: 'ubuntu-20.04', target-os: 'ubuntu', arch: 'x64'} - {runner: 'Ubuntu 20.04 ARM64', runs_on: 'ubuntu-20.04', target-os: 'ubuntu', arch: 'arm64'} + - {runner: 'CentOS 7', runs_on: ['self-hosted', 'CPU'], target-os: 'centos', arch: 'x64'} - {runner: 'AlmaLinux 8', runs_on: ['self-hosted', 'CPU'], target-os: 'almalinux', arch: 'x64'} - - {runner: 'MacOS X64', runs_on: 'macos-12', target-os: 'macos', arch: 'x64'} - - {runner: 'MacOS ARM64', runs_on: 'macos-12', target-os: 'macos', arch: 'arm64'} - # TODO(#2805): add back once the workflow works and runs in comparable time to the other ones - # - {runner: 'Windows Latest', runs_on: 'windows-latest', target-os: 'windows', arch: 'x64'} + - {runner: 'MacOS X64', runs_on: 'macos-13', target-os: 'macos', arch: 'x64'} + - {runner: 'MacOS ARM64', runs_on: 'macos-13', target-os: 'macos', arch: 'arm64'} + - {runner: 'Windows Latest', runs_on: 'windows-latest', target-os: 'windows', arch: 'x64'} steps: @@ -126,7 +126,8 @@ jobs: -DLLVM_BUILD_TOOLS=ON -DLLVM_ENABLE_ASSERTIONS=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON - -DLLVM_ENABLE_PROJECTS="clang;mlir" + -DLLVM_ENABLE_PROJECTS="mlir;llvm" + -DLLVM_ENABLE_DIA_SDK=OFF -DLLVM_INSTALL_UTILS=ON -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" -DLLVM_ENABLE_TERMINFO=OFF @@ -233,15 +234,16 @@ jobs: tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - name: Configure, Build, Test, and Install LLVM (AlmaLinux) - if: matrix.config.target-os == 'almalinux' + + - name: Configure, Build, Test, and Install LLVM (CentOS) + if: matrix.config.target-os == 'centos' run: | # if this step crashes, it can leave behind a stale docker container docker container prune -f docker rmi -f $(docker images -q) docker build --tag llvm-build --build-arg llvm_dir=llvm-project \ - -f llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile . + -f llvm-build/.github/workflows/llvm-build/centos.Dockerfile . # Create temporary container to copy cache and installed artifacts. CONTAINER_ID=$(docker create llvm-build) @@ -256,6 +258,31 @@ jobs: docker rm "${CONTAINER_ID}" + - name: Configure, Build, Test, and Install LLVM (AlmaLinux) + if: matrix.config.target-os == 'almalinux' + run: | + # if this step crashes, it can leave behind a stale docker container + docker container prune -f + docker rmi -f $(docker images -q) + + docker build --tag llvm-build --build-arg llvm_dir=llvm-project \ + -f llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile . + + # Create temporary container to copy cache and installed artifacts. + CONTAINER_ID=$(docker create llvm-build) + + # We remove the existing directories, otherwise docker cp will + # create a subdirectory inside the existing directory. + rm -rf "${{ env.SCCACHE_DIR }}" "${{ env.llvm_install_dir }}" + + docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}" + tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" + + docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}" + sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}" + + docker rm "${CONTAINER_ID}" + - name: Upload Build Artifacts uses: actions/upload-artifact@v4 with: @@ -273,6 +300,7 @@ jobs: - name: Upload LLVM Artifacts to Azure if: ${{ (github.repository == 'triton-lang/triton') }} + shell: bash -el {0} run: | az storage blob upload --account-name oaitriton --auth-mode login --container-name public --file "${{ env.llvm_install_dir }}.tar.gz" --name "llvm-builds/${{ env.llvm_install_dir }}.tar.gz" --overwrite diff --git a/.github/workflows/llvm-build/centos.Dockerfile b/.github/workflows/llvm-build/centos.Dockerfile new file mode 100644 index 000000000000..670d211df1ee --- /dev/null +++ b/.github/workflows/llvm-build/centos.Dockerfile @@ -0,0 +1,56 @@ +FROM centos:7 +ARG llvm_dir=llvm-project +# Add the cache artifacts and the LLVM source tree to the container +ADD sccache /sccache +ADD "${llvm_dir}" /source/llvm-project +ENV SCCACHE_DIR="/sccache" +ENV SCCACHE_CACHE_SIZE="2G" + +RUN echo -e "[llvmtoolset-build]\nname=LLVM Toolset 13.0 - Build\nbaseurl=https://buildlogs.centos.org/c7-llvm-toolset-13.0.x86_64/\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/llvmtoolset-build.repo + +# Note: This is required patch since CentOS have reached EOL +# otherwise any yum install setp will fail +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + +# Install build dependencies +RUN yum install --assumeyes centos-release-scl + +# The definition of insanity is doing the same thing and expecting a different result +RUN sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo +RUN sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo +RUN sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + +RUN yum install --assumeyes --nogpgcheck llvm-toolset-13.0 +RUN yum install --assumeyes rh-python38-python-devel rh-python38-python-pip +SHELL [ "/usr/bin/scl", "enable", "llvm-toolset-13.0", "rh-python38" ] + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --upgrade cmake ninja sccache + +# Install MLIR's Python Dependencies +RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt + +# Configure, Build, Test, and Install LLVM +RUN cmake -GNinja -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_ASM_COMPILER=clang \ + -DCMAKE_C_COMPILER_LAUNCHER=sccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache \ + -DCMAKE_CXX_FLAGS="-Wno-everything" \ + -DCMAKE_LINKER=lld \ + -DCMAKE_INSTALL_PREFIX="/install" \ + -DLLVM_BUILD_UTILS=ON \ + -DLLVM_BUILD_TOOLS=ON \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_ENABLE_TERMINFO=OFF \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \ + /source/llvm-project/llvm + +RUN ninja -C build install diff --git a/.github/workflows/torch-inductor-tests.yml b/.github/workflows/torch-inductor-tests.yml deleted file mode 100644 index 3d8f98095291..000000000000 --- a/.github/workflows/torch-inductor-tests.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Torchinductor - -on: - workflow_run: - workflows: ["Wheels"] - types: [completed] - workflow_dispatch: - -permissions: read-all - -jobs: - Runner-Preparation: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} - steps: - - name: Prepare runner matrix - id: set-matrix - run: | - echo '::set-output name=matrix::[["self-hosted", "A100"]]' - - Torch-Inductor-Tests: - needs: Runner-Preparation - timeout-minutes: 240 # 4 hours - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}} - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Packages - run: | - ./.github/workflows/torch-inductor/scripts/install_torchinductor.sh torchbench - - name: Environment - run: | - source /tmp/torchinductor_venv/bin/activate - ./.github/workflows/torch-inductor/scripts/install_triton.sh - - name: Performance - run: | - ./.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh torchbench - # Runs too long time - #- name: Accuracy - # run: | - # ./.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh torchbench diff --git a/.github/workflows/torch-inductor/scripts/check_acc.py b/.github/workflows/torch-inductor/scripts/check_acc.py deleted file mode 100644 index c89976acab11..000000000000 --- a/.github/workflows/torch-inductor/scripts/check_acc.py +++ /dev/null @@ -1,11 +0,0 @@ -import csv -import sys - -file_path = sys.argv[1] -with open(file_path) as f: - reader = csv.reader(f) - for i, row in enumerate(reader): - if i == 0: - continue - if row[3] != "pass": - print(f"{row[1]} failed on device {row[0]} with batch size {row[2]}") diff --git a/.github/workflows/torch-inductor/scripts/check_perf.py b/.github/workflows/torch-inductor/scripts/check_perf.py deleted file mode 100644 index 212eadad55ae..000000000000 --- a/.github/workflows/torch-inductor/scripts/check_perf.py +++ /dev/null @@ -1,70 +0,0 @@ -import argparse -import csv -from collections import namedtuple - -# Create a named tuple for the output of the benchmark -BenchmarkOutput = namedtuple('BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) - - -def parse_output(file_path: str) -> dict: - entries = {} - with open(file_path) as f: - reader = csv.reader(f) - for i, row in enumerate(reader): - if i == 0 or len(row) < 5: - continue - dev = row[0] - name = row[1] - batch_size = row[2] - speedup = float(row[3]) - latency = float(row[4]) - entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency) - return entries - - -def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool: - baseline_geomean = 1.0 - new_geomean = 1.0 - for key in new: - if key not in baseline: - print(f"New benchmark {key} not found in baseline") - baseline_latency = baseline[key].latency - new_latency = new[key].latency - if baseline_latency == 0: - print(f"Baseline latency for {key} is 0") - continue - elif new_latency == 0: - print(f"New latency for {key} is 0") - continue - - if new_latency < baseline_latency * (1 - threshold): - print(f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") - elif new_latency > baseline_latency * (1 + threshold): - print(f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") - else: - print(f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}") - baseline_geomean *= baseline[key].speedup - new_geomean *= new[key].speedup - - baseline_geomean = baseline_geomean**(1 / len(baseline)) - new_geomean = new_geomean**(1 / len(new)) - print(f"Baseline geomean: {baseline_geomean}") - print(f"New geomean: {new_geomean}") - assert new_geomean >= baseline_geomean * (1 - geomean_threshold), \ - f"New geomean is slower than baseline: {new_geomean} vs {baseline_geomean}" - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--baseline', required=True) - parser.add_argument('--new', required=True) - parser.add_argument('--threshold', type=float, default=0.1) - parser.add_argument('--geomean-threshold', type=float, default=0.02) - args = parser.parse_args() - baseline = parse_output(args.baseline) - new = parse_output(args.new) - compare(baseline, new, args.threshold, args.geomean_threshold) - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/torch-inductor/scripts/common.sh b/.github/workflows/torch-inductor/scripts/common.sh deleted file mode 100755 index 7e212a06a1ba..000000000000 --- a/.github/workflows/torch-inductor/scripts/common.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -TEST_REPORTS_DIR=/tmp/torchinductor_reports -PYTORCH_DIR=/tmp/pytorch -MODELS=(timm_models huggingface torchbench) - -echo "$TEST_REPORTS_DIR" -echo "$PYTORCH_DIR" -echo "${MODELS[@]}" diff --git a/.github/workflows/torch-inductor/scripts/install_torchinductor.sh b/.github/workflows/torch-inductor/scripts/install_torchinductor.sh deleted file mode 100755 index 18bea1f1716f..000000000000 --- a/.github/workflows/torch-inductor/scripts/install_torchinductor.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -MODEL_SPEC=$1 - -# torchinductor venv -whoami - -sudo apt-get update && sudo apt-get install -y python3-venv libgl1 - -# clean up old venv -rm -rf /tmp/torchinductor_venv -python3 -m venv /tmp/torchinductor_venv -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source ./.github/workflows/torch-inductor/scripts/common.sh - -pip3 install --upgrade pip wheel setuptools - -# Install torchtext stable first. Bundling it in the same install as torch -# nightly forces torch stable release to be installed instead. -# From https://github.com/pytorch/text?tab=readme-ov-file#torchtext, -# "WARNING: TorchText development is stopped and the 0.18 release (April 2024) -# will be the last stable release of the library." -pip3 install --force-reinstall torchtext - -# pytorch nightly -pip3 install --force-reinstall --pre torch torchvision torchaudio torchrec --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -# pytorch source to get torchbench for dynamo -cd /tmp || exit -# cleanup old pytorch -rm -rf pytorch -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch || exit -# if you are updating an existing checkout -git submodule sync -git submodule update --init --recursive -cd .. - -# required packages -# https://github.com/pytorch/benchmark/blob/main/docker/gcp-a100-runner-dind.dockerfile#L17 -sudo apt-get install --yes libpango-1.0-0 libpangoft2-1.0-0 -pip3 install expecttest psutil lightning-utilities pyre_extensions - -# torchbench -if [ "$MODEL_SPEC" == "torchbench" ] || [ "$MODEL_SPEC" != "all" ]; then - # clean up old torchbench - rm -rf benchmark - pip3 install pyyaml - git clone https://github.com/pytorch/benchmark.git - cd benchmark || exit - python3 install.py - cd .. -fi - -# timm -if [ "$MODEL_SPEC" == "timm_models" ] || [ "$MODEL_SPEC" != "all" ]; then - # clean up old timm - rm -rf pytorch-image-models - git clone https://github.com/huggingface/pytorch-image-models.git - cd pytorch-image-models || exit - pip3 install -e . - cd .. -fi - -# clean up cache -rm -rf /tmp/torchinductor_"$(whoami)"/ -rm -rf ~/.triton/cache -rm -rf "$TEST_REPORTS_DIR" - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/install_triton.sh b/.github/workflows/torch-inductor/scripts/install_triton.sh deleted file mode 100755 index 43367a02f527..000000000000 --- a/.github/workflows/torch-inductor/scripts/install_triton.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source ./.github/workflows/torch-inductor/scripts/common.sh - -# Triton build-time dependencies -pip3 install --upgrade cmake ninja lit - -# build our own triton and preserve the wheel build for later re-use in this test run. -cd python || exit -pip3 uninstall pytorch-triton -y -rm -rf build dist -python3 setup.py bdist_wheel -pip3 install dist/triton*.whl - -# clean up cache -rm -rf ~/.triton/cache - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh deleted file mode 100755 index aefd798f39ff..000000000000 --- a/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torch-inductor -MODEL_SPEC=$1 - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source "$INDUCTOR"/scripts/common.sh - -# Dependency of 'torch/fx/experimental/validator.py'. -pip3 install --upgrade z3-solver - -# Install our own triton. -pip3 uninstall pytorch-triton -y -cd $ROOT/python || exit -if [ -d "./dist" ]; then - pip3 install dist/triton*.whl -else - rm -rf build - pip3 install -e . -fi - -cd "$PYTORCH_DIR" || exit -TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc -mkdir -p "$TEST_REPORTS_DIR" - -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running accuracy test for $model" - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --inference --device cuda \ - --output "$TEST_REPORTS_DIR"/inference_"$model".csv - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --training --amp --device cuda \ - --output "$TEST_REPORTS_DIR"/training_"$model".csv - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --training --dynamic-shapes --device cuda \ - --output "$TEST_REPORTS_DIR"/dynamic_shapes_"$model".csv -done - -cd "$ROOT" || exit -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Checking accuracy test for $model" - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/inference_"$model".csv - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/training_"$model".csv - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/dynamic_shapes_"$model".csv -done - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh deleted file mode 100755 index 35853d97c8fe..000000000000 --- a/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torch-inductor -MODEL_SPEC=$1 - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source "$INDUCTOR"/scripts/common.sh - -# lock GPU clocks to 1350 MHz -sudo nvidia-smi -i 0 -pm 1 -sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 - -cd "$PYTORCH_DIR" || exit -TRITON_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/perf -BASE_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc -mkdir -p "$TRITON_TEST_REPORTS_DIR" -mkdir -p "$BASE_TEST_REPORTS_DIR" - -# Dependency of 'pytorch/benchmarks/dynamo/common.py'. -pip3 install pandas scipy - -echo "Running with Triton Nightly" -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running performance test for $model" - python3 benchmarks/dynamo/"$model".py --ci --float32 --training --inductor --performance --device cuda \ - --output "$TRITON_TEST_REPORTS_DIR"/"$model".csv -done - -# install pytorch-triton -pip3 uninstall triton -y -pip3 install --pre pytorch-triton --extra-index-url https://download.pytorch.org/whl/nightly/cu121 - -echo "Running with pytorch-triton" -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running performance test for $model" - python3 benchmarks/dynamo/"$model".py --ci --float32 --training --inductor --performance --device cuda \ - --output "$BASE_TEST_REPORTS_DIR"/"$model".csv -done - -# uninstall pytorch-triton -pip3 uninstall pytorch-triton -y - -cd "$ROOT" || exit -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Checking performance test for $model" - python3 "$INDUCTOR"/scripts/check_perf.py --new "$TRITON_TEST_REPORTS_DIR"/"$model".csv --baseline "$BASE_TEST_REPORTS_DIR"/"$model".csv - EXIT_STATUS=$? - if [ "$EXIT_STATUS" -ne 0 ]; then - echo "Performance test for $model failed" - exit "$EXIT_STATUS" - fi -done - -# unlock GPU clocks -sudo nvidia-smi -i 0 -rgc - -# go back to where we started -cd "$ROOT" || exit diff --git a/.gitignore b/.gitignore index 2de228701c87..a78598e107d0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,11 @@ python/triton*.egg-info/ python/triton/_C/*.pyd python/triton/_C/*.so +python/triton/_C/*.so.* python/triton/_C/*.dylib +python/triton/_C/*.pdb +python/triton/_C/*.exe +python/triton/_C/*.ilk # Backends copied from submodules python/triton/backends/ @@ -23,6 +27,9 @@ python/triton/language/extra # Proton python/triton/profiler +# Pytest +pytest.ini + # Instrumentation python/triton/instrumentation diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..b2b6bf04a546 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sleef"] + path = third_party/sleef + url = https://github.com/shibatch/sleef diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2aab636bbde..a85e54d05d18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ +default_stages: [pre-commit, pre-push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -17,12 +18,11 @@ repos: - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.7.1 hooks: - id: ruff files: '^python/.*' - args: ["--fix", "--line-length", "120"] - stages: [pre-commit, pre-push, manual] + args: ["--fix", "--exit-non-zero-on-fix"] exclude: | (?x)( ^python/triton/runtime/.*| @@ -31,18 +31,16 @@ repos: ) - repo: https://github.com/google/yapf - rev: be72557 + rev: "7e21823" hooks: - id: yapf args: ["-p", "-i"] - stages: [pre-commit, pre-push, manual] exclude: "python/test/unit/language/test_line_info.py" - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v19.1.2 hooks: - id: clang-format - stages: [pre-commit, pre-push, manual] # Expand YAML anchors in files used by github workflows, because github can't # do this itself. This lets us use anchors, which avoids code duplication. diff --git a/CMakeLists.txt b/CMakeLists.txt index a892a666d56c..3836abadc867 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,30 +12,59 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_INCLUDE_CURRENT_DIR ON) -project(triton) +project(triton CXX C) include(CTest) -if(NOT WIN32) - list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -endif() - - +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) +option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") +if(TRITON_BUILD_WITH_CCACHE) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" + CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" + CACHE STRING "CXX compiler launcher") + else() + message( + STATUS + "Could not find ccache. Consider installing ccache to speed up compilation." + ) + endif() +endif() + +set(TRITON_PARALLEL_LINK_JOBS "" CACHE STRING + "Define the maximum number of concurrent link jobs (Ninja only).") +if (TRITON_PARALLEL_LINK_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${TRITON_PARALLEL_LINK_JOBS}) + set(CMAKE_JOB_POOL_LINK link_job_pool) +endif() + + # Ensure Python3 vars are set correctly # used conditionally in this file and by lit tests # Customized release build type with assertions: TritonRelBuildWithAsserts -set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") -set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") +if(NOT MSVC) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") + set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") +else() + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor") + set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") +endif() # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -48,12 +77,19 @@ if(NOT WIN32) endif() if(TRITON_BUILD_UT) + # This is an aggregate target for all unit tests. + add_custom_target(TritonUnitTests) + set_target_properties(TritonUnitTests PROPERTIES FOLDER "Triton/Tests") include(AddTritonUnitTest) endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS") +endif() # ######### @@ -107,7 +143,11 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX- /wd4244 /wd4624 /wd4715 /wd4530") +endif() include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) @@ -121,10 +161,6 @@ include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files add_subdirectory(include) add_subdirectory(lib) -# find_package(PythonLibs REQUIRED) -set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") -set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") - # TODO: Figure out which target is sufficient to fix errors; triton is # apparently not enough. Currently set linking libstdc++fs for all targets # to support some old version GCC compilers like 8.3.0. @@ -141,22 +177,9 @@ if(TRITON_BUILD_PYTHON_MODULE) set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) include_directories(${PYTHON_SRC_PATH}) - if(PYTHON_INCLUDE_DIRS) - # We have PYTHON_INCLUDE_DIRS set--this is what we expect when building - # using pip install. - include_directories(${PYTHON_INCLUDE_DIRS}) - include_directories(${PYBIND11_INCLUDE_DIR}) - else() - # Otherwise, we might be building from top CMakeLists.txt directly. - # Try to find Python and pybind11 packages. - find_package(Python3 REQUIRED COMPONENTS Development Interpreter) - find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") - include_directories(${Python3_INCLUDE_DIRS}) - include_directories(${pybind11_INCLUDE_DIR}) - link_directories(${Python3_LIBRARY_DIRS}) - link_libraries(${Python3_LIBRARIES}) - add_link_options(${Python3_LINK_OPTIONS}) - endif() + # Python Interpreter is used to run lit tests + find_package(Python3 REQUIRED COMPONENTS Development.Module Interpreter) + find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") if (DEFINED TRITON_PLUGIN_DIRS) foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS}) @@ -182,6 +205,9 @@ if(TRITON_BUILD_PYTHON_MODULE) if (TRITON_BUILD_PROTON) add_subdirectory(third_party/proton) endif() + # We always build proton dialect + list(APPEND TRITON_PLUGIN_NAMES "proton") + add_subdirectory(third_party/proton/dialect) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) @@ -219,6 +245,9 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMAMDGPUCodeGen LLVMAMDGPUAsmParser + Python3::Module + pybind11::headers + ) if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR # macOS arm64 @@ -227,7 +256,7 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMAArch64CodeGen LLVMAArch64AsmParser ) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") list(APPEND TRITON_LIBRARIES LLVMX86CodeGen LLVMX86AsmParser @@ -259,9 +288,11 @@ if(TRITON_BUILD_PYTHON_MODULE) ${PYTHON_SRC_PATH}/llvm.cc) # Link triton with its dependencies - target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) + target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") else() target_link_libraries(triton PRIVATE z) endif() @@ -280,19 +311,28 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) set(PYTHON_LDFLAGS "-undefined dynamic_lookup") endif() - target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) + target_link_options(triton PRIVATE ${PYTHON_LDFLAGS}) endif() if(NOT TRITON_BUILD_PYTHON_MODULE) foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + add_subdirectory(third_party/proton/dialect) endif() +find_package(Threads REQUIRED) + add_subdirectory(third_party/f2reduce) add_subdirectory(bin) add_subdirectory(test) if(TRITON_BUILD_UT) add_subdirectory(unittest) + # This target runs all the unit tests. + add_custom_target(check-triton-unit-tests + COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure + DEPENDS TritonUnitTests + USES_TERMINAL + ) endif() diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000000..3232ed665566 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/README.md b/README.md index 4685ae30fc4c..8f8f585aada2 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,36 @@ -
- Triton logo -
+# Triton-CPU + +A long-lived development branch to build an experimental CPU backend for [Triton](https://github.com/openai/triton). + +This repository clones the main Triton repository, but we intend to minimize +divergences in the core (and ideally upstream anything that needs to change and +isn't too CPU-specific). Most of the CPU work should be in a backend +subdirectory (similar to how GPU vendors are supported today). We're starting +with a clone to give ourselves maximum development flexibility as this project +gets off the ground! + +# How to use it? + +Build it like a normal Triton, but just pass TRITON_CPU_BACKEND=1 to use the CPU backend over a GPU backend, if any. + +``` +TRITON_CPU_BACKEND=1 python3 tutorials/01-vector-add.py +``` -The Triton Conference is happening again on September 17th, 2024 in Fremont (CA)! +**NOTE: It's still work in progress.** -If you are interested in attending, please fill up [this form](https://docs.google.com/forms/d/e/1FAIpQLSecHC1lkalcm0h3JDUbspekDX5bmBvMxgVTLaK3e-61bzDDbg/viewform). +--- +# Upstream README + +
+ Triton logo +
| **`Documentation`** | **`Nightly Wheels`** | |-------------------- | -------------------- | | [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) | - # Triton This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. @@ -24,20 +43,21 @@ The [official documentation](https://triton-lang.org) contains installation inst You can install the latest stable release of Triton from pip: -```bash +```shell pip install triton ``` + Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9. And the latest nightly release: -```bash +```shell pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly ``` # Install from source -``` +```shell git clone https://github.com/triton-lang/triton.git; cd triton; @@ -47,7 +67,7 @@ pip install -e python Or with a virtualenv: -``` +```shell git clone https://github.com/triton-lang/triton.git; cd triton; @@ -156,14 +176,14 @@ $ lit test You may find it helpful to make a symlink to the builddir and tell your local git to ignore it. -``` +```shell $ ln -s python/build/cmake<...> build $ echo build >> .git/info/exclude ``` Then you can e.g. rebuild and run lit with the following command. -``` +```shell $ ninja -C build && ( cd build ; lit test ) ``` @@ -177,6 +197,9 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only. - Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try cleaning your triton cache: `rm -r ~/.triton/cache/*` - `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR. +- `TRITON_REPRODUCER_PATH=` will generate an MLIR reproducer file + at `` before each MLIR compiler stage. If any of the stages fail, + `` will be a local MLIR reproducer captured right before the failing pass. - `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the GPU. You can insert Python breakpoints in your kernel code! - `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of @@ -213,10 +236,30 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). - `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks. +- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx. +- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx when `TRITON_KERNEL_DUMP` is set to 1. +- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage. +- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx files when `TRITON_KERNEL_OVERRIDE` is set to 1. + +**Kernel Override Steps** + +```bash +export TRITON_ALWAYS_COMPILE=1 +export TRITON_KERNEL_DUMP=1 +export TRITON_DUMP_DIR= +export TRITON_KERNEL_OVERRIDE=1 +export TRITON_OVERRIDE_DIR= +# Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR +# Step 2: Copy $TRITON_DUMP_DIR/ to $TRITON_OVERRIDE_DIR +# Step 3: Delete the stages that you do not want to override and modify the stage you do want to override +# Step 4: Run the kernel again to see the overridden result +``` + # Changelog Version 2.0 is out! New features include: + - Many, many bug fixes - Performance improvements - Backend rewritten to use MLIR @@ -226,13 +269,14 @@ Version 2.0 is out! New features include: Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/triton-lang/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md). - # Compatibility Supported Platforms: - * Linux + +- Linux Supported Hardware: - * NVIDIA GPUs (Compute Capability 7.0+) - * AMD GPUs (ROCm 5.2+) - * Under development: CPUs + +- NVIDIA GPUs (Compute Capability 8.0+) +- AMD GPUs (ROCm 5.2+) +- Under development: CPUs diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 25a891c2f7d9..b608057d3fa1 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -2,7 +2,9 @@ #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "amd/include/TritonAMDGPUTransforms/Passes.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -15,12 +17,20 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h" +#include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" +#include "cpu/include/Xsmm/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "cpu/include/TritonRaiseBlockPointer/Passes.h" + #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/InitAllPasses.h" @@ -45,6 +55,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::test::registerTestMembarPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerAllocateSharedMemoryPass(); + mlir::triton::registerTritonGPUGlobalScratchAllocationPass(); mlir::triton::registerConvertTritonGPUToLLVMPass(); mlir::triton::registerConvertNVGPUToLLVMPass(); mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); @@ -60,16 +71,32 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); - mlir::registerTritonAMDGPUStreamPipelineV2(); + mlir::registerTritonAMDGPUStreamPipeline(); mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + + // CPU passes + mlir::triton::cpu::registerTritonToTritonCPUPasses(); + mlir::triton::cpu::registerTritonCPUTransformsPasses(); + mlir::triton::cpu::registerTritonCPUToLLVMPasses(); + mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); + mlir::triton::cpu::registerTritonCPUXsmmPasses(); + + mlir::triton::cpu::registerTritonRaiseBlockPointerPass(); // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 4087ac135022..7c635dafaa3d 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -22,7 +22,7 @@ using namespace mlir; // clang-format off // Example usage: // -// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" // // triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt // @@ -30,8 +30,8 @@ using namespace mlir; // // An input file usually looks like: // ''' -// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> -// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> // ''' // clang-format on @@ -83,7 +83,7 @@ LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace(); // Dispatch to the corresponding dialect helper function to print the layout. - if (dialectName == "triton_gpu") { + if (dialectName == "ttg") { os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); return success(); } diff --git a/build.sh b/build.sh new file mode 100755 index 000000000000..8a55be8919fa --- /dev/null +++ b/build.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +source ./../miniforge/bin/activate triton + +# Note, build and install LLVM on this directory with hash to avoid conflicts +export LLVM_BUILD_DIR=$PWD/../llvm-project/build +export TRITON_BUILD_WITH_CCACHE=false +export TRITON_BUILD_WITH_CLANG_LLD=true +export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib +export LLVM_SYSPATH=$LLVM_BUILD_DIR + +echo "===================================== Build" +pip install -e python/ +if [ $? != 0 ]; then + exit 1 +fi + +echo "===================================== CMake Tests" +ctest --test-dir python/build/cmake* +if [ $? != 0 ]; then + exit 1 +fi + +echo "===================================== Setting up LIBXSMM for paddeing" +export XSMM_ROOT_DIR=$(realpath $(find python/build/ -type d -name xsmm-src | grep -v third_party)) +export XSMM_LIB_DIR=$(realpath python/triton/_C) +cd third_party/cpu/python +python setup.py install +cd ../../../ + +conda deactivate diff --git a/cmake/AddTritonUnitTest.cmake b/cmake/AddTritonUnitTest.cmake index 24fb20a72b8c..a9efb9ad1ad8 100644 --- a/cmake/AddTritonUnitTest.cmake +++ b/cmake/AddTritonUnitTest.cmake @@ -35,5 +35,8 @@ function(add_triton_ut) # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac # laptop. I think the issue may be that the very first time you run a program # it's a bit slow. - gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60) + gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60) + + # Add the unit test to the top-level unit test target. + add_dependencies(TritonUnitTests ${__NAME}) endfunction() diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index b000a3129912..50d024794663 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -b5cc222d7429fe6f18c787f633d5262fac2e676f +86b69c31642e98f8357df62c09d118ad1da4e16a diff --git a/cmake/xsmm.cmake b/cmake/xsmm.cmake new file mode 100644 index 000000000000..c70d61b20ec2 --- /dev/null +++ b/cmake/xsmm.cmake @@ -0,0 +1,59 @@ +# Use LIBXSMM (make PREFIX=/path/to/libxsmm) given by LIBXSMMROOT +set(LIBXSMMROOT $ENV{LIBXSMMROOT}) +# Fetch LIBXSMM (even if LIBXSMMROOT is present) +set(LIBXSMMFETCH $ENV{LIBXSMMFETCH}) + +if(LIBXSMMROOT AND NOT LIBXSMMFETCH) + message(STATUS "Found LIBXSMM (${LIBXSMMROOT})") + file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS ${LIBXSMMROOT}/include/libxsmm/*.c) + list(REMOVE_ITEM XSMM_SRCS ${LIBXSMMROOT}/include/libxsmm/libxsmm_generator_gemm_driver.c) +else() + message(STATUS "Fetching LIBXSMM") + include(FetchContent) + + FetchContent_Declare( + xsmm + URL https://github.com/libxsmm/libxsmm/archive/89bca96616d657d46787c6bfd3244f8a7f213855.tar.gz + URL_HASH SHA256=54ba72ad80dbe9db60e67649f8e8e08a205635621502cc1b464fb909e619474f + ) + + FetchContent_GetProperties(xsmm) + if(NOT xsmm_POPULATED) + FetchContent_Populate(xsmm) + endif() + + set(LIBXSMMROOT ${xsmm_SOURCE_DIR}) +endif() + +if(NOT XSMM_SRCS) + file(GLOB XSMM_SRCS LIST_DIRECTORIES false CONFIGURE_DEPENDS ${LIBXSMMROOT}/src/*.c) + list(REMOVE_ITEM XSMM_SRCS ${LIBXSMMROOT}/src/libxsmm_generator_gemm_driver.c) +endif() + +set(XSMM_INCLUDE_DIRS ${LIBXSMMROOT}/include) + +add_mlir_library(xsmm SHARED ${XSMM_SRCS}) +target_include_directories(xsmm PUBLIC + $ + $ +) +add_definitions(-DLIBXSMM_DEFAULT_CONFIG -U_DEBUG -D__BLAS=0) + +set_property(TARGET xsmm PROPERTY POSITION_INDEPENDENT_CODE ON) # -fPIC +set_property(TARGET xsmm PROPERTY COMPILE_WARNING_AS_ERROR ON) + +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +target_link_libraries(xsmm PUBLIC Threads::Threads) +target_link_libraries(xsmm PUBLIC ${CMAKE_DL_LIBS}) + +include(CheckLibraryExists) +check_library_exists(m sqrt "" XSMM_LIBM) +if(XSMM_LIBM) + target_link_libraries(xsmm PUBLIC m) +endif() +check_library_exists(rt sched_yield "" XSMM_LIBRT) +if(XSMM_LIBRT) + target_link_libraries(xsmm PUBLIC rt) +endif() +#target_link_libraries(xsmm PUBLIC c) diff --git a/docs/conf.py b/docs/conf.py index eac5168d5160..ffaab561a7dc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -145,7 +145,7 @@ def documenter(app, obj, parent): autosummary_generate = True # versioning config -smv_tag_whitelist = r'^(v3.0.0)$' +smv_tag_whitelist = r'^(v3.2.0)$' smv_branch_whitelist = r'^main$' smv_remote_whitelist = None smv_released_pattern = r'^tags/.*$' diff --git a/docs/python-api/triton-semantics.rst b/docs/python-api/triton-semantics.rst index e35a355d3222..bdf25411108a 100644 --- a/docs/python-api/triton-semantics.rst +++ b/docs/python-api/triton-semantics.rst @@ -14,9 +14,7 @@ The algorithm is as follows: 2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32`` -3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32`` - - 3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``. +3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16`` 4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32`` diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index ecd0fb3b94b6..415091a1000e 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -59,6 +59,7 @@ Linear Algebra Ops :nosignatures: dot + dot_scaled Memory/Pointer Ops diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 3a488e65ed03..91bc895b2050 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -18,6 +18,12 @@ namespace mlir { namespace triton { class AllocationAnalysis; +/// Callback to allow backends to specify target-specific scratch sizes for +/// some operations. +using AllocationAnalysisScratchSizeFn = std::function; + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op); + // To convert a tensor from one layout to another, we need to allocate a // temporary buffer (i.e., scratch buffer) in shared memory. The conversion may // require multiple iterations, with each iteration involving multiple @@ -102,7 +108,8 @@ class Allocation { explicit Allocation(Operation *operation) : operation(operation) {} /// Runs allocation analysis on the given top-level operation. - void run(FuncAllocMapT &funcAllocMap); + void run(FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter); /// Returns the operation this analysis was constructed from. Operation *getOperation() const { return operation; } @@ -173,8 +180,8 @@ class Allocation { private: /// A class that represents a shared memory buffer struct BufferT { - /// Explicit: triton_gpu.local_alloc - /// Scratch: triton_gpu.convert_layout + /// Explicit: ttg.local_alloc + /// Scratch: ttg.convert_layout /// Virtual: triton.call enum class BufferKind { Explicit, Scratch, Virtual }; @@ -250,7 +257,9 @@ class ModuleAllocation : public CallGraph { public: using FuncOffsetMapT = DenseMap; - explicit ModuleAllocation(ModuleOp moduleOp) + ModuleAllocation(ModuleOp moduleOp, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter = + triton::defaultAllocationAnalysisScratchSizeFn) : CallGraph(moduleOp) { walk( // Pre-order edge walk callback @@ -259,7 +268,7 @@ class ModuleAllocation : public CallGraph { [&](FunctionOpInterface funcOp) { auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); if (inserted) - iter->second.run(funcMap); + iter->second.run(funcMap, scratchSizeGetter); }); } diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index aad4503b4840..1bf9c8a690dc 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -27,11 +27,12 @@ class AxisInfo { public: AxisInfo() : AxisInfo({}, {}, {}) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy) : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, - std::optional constantValue) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy, std::optional constantValue) : contiguity(contiguity), divisibility(divisibility), constancy(constancy), constantValue(constantValue) { assert(divisibility.size() == contiguity.size()); diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index ae05e2049834..1a33904a354c 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -5,7 +5,9 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" namespace mlir { @@ -152,6 +154,19 @@ class ScanLoweringHelper { SmallVector srcElementTypes; }; +// Helper class for lowering `tt.gather` operations. This class shares lowering +// logic between shared memory allocation and LLVM codegen. +class GatherLoweringHelper { +public: + GatherLoweringHelper(triton::GatherOp gatherOp); + + // Get the shared memory scratch size required by this op. + unsigned getScratchSizeInBytes(); + +private: + triton::GatherOp gatherOp; +}; + // Decomposes a reshape into simpler pieces. // // As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. @@ -189,6 +204,14 @@ bool supportMMA(triton::DotOp op, int version); bool supportMMA(Value value, int version); +// Conversion from `srcTy` to `dstTy` involving the minimum amount of data +// transfer provided that both types can be converted to LL (if it can't it'll +// return nullopt). The output will be such that layout.getInDimNames() == +// layout.getOutDimNames() and the conversion will not include kBlock (resp. +// kWarp or kLane) if it can be avoided +std::optional +minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy); + // Conversion from `srcTy` to `dstTy` only involves reordering of registers. // There is no need for data exchange across threads, warps, or blocks. bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy); @@ -203,11 +226,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); - -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); - -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 22c8f9c8a330..c37917a35d82 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -15,17 +15,6 @@ namespace mlir::triton { namespace gpu { -SmallVector reorderValues(const SmallVector &values, Type inType, - Type ouType); - -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - Type getElementType(Value value); class MultipleOperandsRange @@ -187,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -209,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } it += curr.size(); } - if (op->getNumOperands() > 0) { - auto argTy = op->getOperand(0).getType(); - resultVals = reorderValues(resultVals, argTy, resultTy); - } resultVals = maybeDeduplicate(op, resultVals); - resultVals = - packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/include/triton/Conversion/TritonGPUToLLVM/Passes.h index b013f26289ce..8dfd01fc91ec 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -20,6 +20,8 @@ namespace triton { namespace gpu { std::unique_ptr> createAllocateSharedMemoryPass(); +std::unique_ptr createTritonGPUGlobalScratchAllocationPass(); + } // namespace gpu #define GEN_PASS_REGISTRATION diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index 04ced17670d0..3a2686ba513a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -15,4 +15,18 @@ def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()"; } +def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> { + let summary = "Assign global scratch memory allocation"; + + let description = [{ + Decide on global scratch space memory allocation and assign attributes to each allocation. + }]; + + let constructor = "mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + #endif diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 29aec5904e8e..d6530b093346 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; constexpr int patternBenefitClampOptimizedPattern = 20; constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +struct BackendCallbacks { + /** + * A backend-specific callback for appending auxiliary data during + * `LocalStoreOp` conversion. + * + * @param[in] op The reference to the re-written `LocalStoreOp`. + * @param[in] count The number of issued LLVM instructions. + * @param[in] type The input type of issued LLVM instructions. + */ + std::function + localStoreOpConversion = nullptr; +}; + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, @@ -74,6 +92,10 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 45cfbbd181c7..87db94f25f1d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -4,6 +4,7 @@ #include "triton/Conversion/MLIRTypes.h" namespace mlir::triton { + class TargetInfoBase { public: virtual bool supportMaximumMinimum() const = 0; @@ -37,6 +38,12 @@ class TargetInfoBase { pred); } + virtual bool canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const = 0; + virtual void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const = 0; @@ -82,6 +89,8 @@ class TargetInfoBase { virtual int getSharedAddressSpace() const = 0; + virtual bool supportVectorizedAtomics() const = 0; + virtual ~TargetInfoBase() {} }; } // namespace mlir::triton diff --git a/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h index 5ae547c39218..60c0ed7b61e8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -18,11 +18,12 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis = nullptr); - Type getElementTypeForStruct(TensorOrMemDesc type); + Type getElementTypeForStruct(triton::gpu::TensorOrMemDesc type); Type convertTritonPointerType(triton::PointerType type); Type convertTritonTensorType(RankedTensorType type, const TargetInfoBase &targetInfo); - Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo); + Type convertMemDescType(triton::gpu::MemDescType type, + const TargetInfoBase &targetInfo); Type convertAsyncToken(triton::gpu::AsyncTokenType type); }; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03ae..a1c37efb52f1 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -10,9 +10,11 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" @@ -367,8 +369,9 @@ inline bool isKernel(FunctionOpInterface funcOp) { inline Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { + // See NOTE: [Additional Function Arguments] if (!isKernel(funcOp)) { - return funcOp.getArgument(funcOp.getNumArguments() - 1); + return funcOp.getArgument(funcOp.getNumArguments() - 2); } auto mod = funcOp->getParentOfType(); @@ -377,6 +380,58 @@ inline Value getStackPointer(RewriterBase &rewriter, return rewriter.create(funcOp.getLoc(), globalBase); } +inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + FunctionOpInterface funcOp, + Value allocOffset = {}) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + // Base for this function + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + if (!allocOffset) { + return gmemBase; + } + + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1); + return gep(ptrTy, i8_ty, gmemBase, allocOffset); + } + + // Base for entire kernel + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto allocSizeAttr = mod.getOperation()->getAttrOfType( + "ttg.global_scratch_memory_size"); + if (!allocSizeAttr) { + return gmemBase; + } + + Value gridIdx[3]; + Value gridDim[2]; + for (int k = 0; k < 3; ++k) { + gridIdx[k] = rewriter.create(loc, k); + } + for (int k = 0; k < 2; ++k) { + gridDim[k] = rewriter.create(loc, k); + } + + Value linearId = gridIdx[2]; + for (int k = 0; k < 2; ++k) { + linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k])); + } + + auto allocSize = allocSizeAttr.getValue().getZExtValue(); + + Value offset = mul(linearId, i32_val(allocSize)); + if (allocOffset) { + offset = add(offset, allocOffset); + } + + auto *ctx = rewriter.getContext(); + auto res = + gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); + return res; +} + inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Operation *op) { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), @@ -391,6 +446,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } + +// ----------------------------------------------------------------------- +// MXFP utilities +// ----------------------------------------------------------------------- + +// Convert each value, which is an int8 containing 2 packed mxfp4 values, +// into 2 standalone bf16 values +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values); + +// Scale a mxfp4 value by a given scale. +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale); + } // namespace LLVM /* ------------------------------------ */ @@ -453,15 +521,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); - auto order = blockedLayout.getOrder(); + auto threadOrder = blockedLayout.getThreadOrder(); + auto warpOrder = blockedLayout.getWarpOrder(); auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); unsigned rank = shape.size(); // delinearize threadId to get the base index SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { @@ -530,122 +599,6 @@ emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, // Mma layout indices // ----------------------------------------------------------------------- -inline SmallVector -emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter, - const NvidiaMmaEncodingAttr &mmaLayout, - RankedTensorType type) { - auto shape = type.getShape(); - auto wpt = mmaLayout.getWarpsPerCTA(); - static constexpr std::array fpw{{2, 2, 1}}; - auto [isARow, isBRow, isAVec4, isBVec4, _] = - mmaLayout.decodeVoltaLayoutStates(); - - Value thread = getThreadId(rewriter, loc); - auto *ctx = thread.getContext(); - Value _1 = i32_val(1); - Value _2 = i32_val(2); - Value _4 = i32_val(4); - Value _16 = i32_val(16); - Value _32 = i32_val(32); - Value _fpw0 = i32_val(fpw[0]); - Value _fpw1 = i32_val(fpw[1]); - - // A info - auto aRep = mmaLayout.getMMAv1Rep(0); - auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); - // B info - auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); - auto bRep = mmaLayout.getMMAv1Rep(1); - - SmallVector rep({aRep[0], bRep[1]}); - SmallVector spw({aSpw[0], bSpw[1]}); - SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); - - Value lane = urem(thread, _32); - Value warp = udiv(thread, _32); - - Value warp0 = urem(warp, i32_val(wpt[0])); - Value warp12 = udiv(warp, i32_val(wpt[0])); - Value warp1 = urem(warp12, i32_val(wpt[1])); - - // warp offset - Value offWarpM = mul(warp0, i32_val(spw[0])); - Value offWarpN = mul(warp1, i32_val(spw[1])); - // quad offset - Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); - Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); - // pair offset - Value offPairM = udiv(urem(lane, _16), _4); - offPairM = urem(offPairM, _fpw0); - offPairM = mul(offPairM, _4); - Value offPairN = udiv(urem(lane, _16), _4); - offPairN = udiv(offPairN, _fpw0); - offPairN = urem(offPairN, _fpw1); - offPairN = mul(offPairN, _4); - offPairM = mul(offPairM, i32_val(rep[0] / 2)); - offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); - offPairN = mul(offPairN, i32_val(rep[1] / 2)); - offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); - // quad pair offset - Value offLaneM = add(offPairM, offQuadM); - Value offLaneN = add(offPairN, offQuadN); - // a, b offset - Value offsetAM = add(offWarpM, offLaneM); - Value offsetBN = add(offWarpN, offLaneN); - // m indices - Value offsetCM = add(and_(lane, _1), offsetAM); - // n indices - Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); - return {offsetCM, offsetCN}; -} - -inline SmallVector> -emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout, - RankedTensorType type) { - auto shape = type.getShape(); - - auto [isARow, isBRow, isAVec4, isBVec4, _] = - mmaLayout.decodeVoltaLayoutStates(); - - // TODO: seems like the pattern below to get `rep`/`spw` appears quite often - // A info - auto aRep = mmaLayout.getMMAv1Rep(0); - auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); - // B info - auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); - auto bRep = mmaLayout.getMMAv1Rep(1); - - auto wpt = mmaLayout.getWarpsPerCTA(); - static constexpr std::array fpw{{2, 2, 1}}; - SmallVector rep({aRep[0], bRep[1]}); - SmallVector spw({aSpw[0], bSpw[1]}); - SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); - - SmallVector idxM; - for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) - for (unsigned mm = 0; mm < rep[0]; ++mm) - idxM.push_back(m + mm * 2); - - SmallVector idxN; - for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { - for (int nn = 0; nn < rep[1]; ++nn) { - idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]); - idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1); - } - } - - SmallVector> ret; - for (unsigned x1 : idxN) { // N - for (unsigned x0 : idxM) { // M - SmallVector idx(2); - idx[0] = x0; // M - idx[1] = x1; // N - ret.push_back(std::move(idx)); - } - } - return ret; -} - inline SmallVector> emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout, RankedTensorType type) { @@ -1111,9 +1064,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, blockedLayout, type); } else if (auto mmaLayout = mlir::dyn_cast(layout)) { - if (mmaLayout.isVolta()) - result = - emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type); if (mmaLayout.isAmpere() || mmaLayout.isHopper()) result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout, type); @@ -1173,8 +1123,19 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } +// Emit code to compute the (blockId, warpId, laneId) for the current thread. +std::tuple +emitHardwareTuple(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, bool withCTAOffset, + unsigned threadsPerWarp); + // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. +// +// For example, for a thread a owns `elemsPerThread` elements of a tensor with +// type `type` and layout `layout`, the result will contain `elemsPerThread` +// vectors. Each vector contains the SSA values of the indices required to +// access the corresponding element, starting from the inner dimension. SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset); @@ -1192,8 +1153,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, // // Returns true on success. [[nodiscard]] bool emitTransferBetweenRegistersAndShared( - RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, Value shmemBase, + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, + Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback); @@ -1361,16 +1322,17 @@ inline DenseMap getSwizzledSharedPtrs( } SmallVector loadSharedToDistributed(RankedTensorType dstTy, - MemDescType srcTy, Type elemLlvmTy, + triton::gpu::MemDescType srcTy, + Type elemLlvmTy, SharedMemoryObject smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target); +void storeDistributedToShared( + triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, @@ -1468,18 +1430,6 @@ inline Value packLLVector(Location loc, ValueRange vals, return vec; } -inline bool isLayoutMmaV1(Attribute layout) { - bool isMmaV1 = false; - if (auto mmaLayout = dyn_cast(layout)) { - isMmaV1 = mmaLayout.isVolta(); - } - if (auto sliceLayout = dyn_cast(layout)) { - isMmaV1 = isa(sliceLayout.getParent()) && - cast(sliceLayout.getParent()).isVolta(); - } - return isMmaV1; -} - } // namespace mlir #endif diff --git a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt index 99d90c4d75e6..51ad71b4c2f8 100644 --- a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) -add_public_tablegen_target(TritonConversionPassIncGen) +add_public_tablegen_target(TritonConversionToGPUPassIncGen) diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.h b/include/triton/Conversion/TritonToTritonGPU/Passes.h index e159406b3ed4..112269bfb369 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.h +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_PASSES_H -#define TRITON_CONVERSION_PASSES_H +#ifndef TRITON_CONVERSION_TO_GPU_PASSES_H +#define TRITON_CONVERSION_TO_GPU_PASSES_H #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.td b/include/triton/Conversion/TritonToTritonGPU/Passes.td index f20c3604090e..81dc45a9ae59 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.td +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_PASSES -#define TRITON_CONVERSION_PASSES +#ifndef TRITON_CONVERSION_TO_GPU_PASSES +#define TRITON_CONVERSION_TO_GPU_PASSES include "mlir/Pass/PassBase.td" diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index 78917fdfdd7e..ad8e6404132d 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -12,11 +12,11 @@ template class OperationPass; namespace triton { -constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; -constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; -constexpr static char AttrTargetName[] = "triton_gpu.target"; +constexpr static char AttrNumWarpsName[] = "ttg.num-warps"; +constexpr static char AttrNumCTAsName[] = "ttg.num-ctas"; +constexpr static char AttrTargetName[] = "ttg.target"; -constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; +constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp"; // Create the pass with numWarps passed from cl::opt. std::unique_ptr> createConvertTritonToTritonGPUPass(); diff --git a/include/triton/Dialect/CMakeLists.txt b/include/triton/Dialect/CMakeLists.txt index 6ef40db00f52..c964bdcea534 100644 --- a/include/triton/Dialect/CMakeLists.txt +++ b/include/triton/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Triton) +add_subdirectory(TritonCPU) add_subdirectory(TritonGPU) add_subdirectory(TritonNvidiaGPU) diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index f682f54a1c44..fecd5adf6219 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -20,8 +20,8 @@ set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) -set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) -mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index b1f1597c5aa7..56a1aa7032fd 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -13,6 +13,7 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" #include "triton/Dialect/Triton/IR/Traits.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -77,6 +78,16 @@ class DialectInferLayoutInterface Attribute operandEncodingB) const = 0; }; +class DialectVerifyTensorLayoutInterface + : public DialectInterface::Base { +public: + DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + verifyTensorLayout(Attribute layout, RankedTensorType type, ModuleOp module, + function_ref emitError) const = 0; +}; + } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/Triton/IR/OpInterfaces.h b/include/triton/Dialect/Triton/IR/OpInterfaces.h new file mode 100644 index 000000000000..1489422d3e25 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -0,0 +1,21 @@ +#ifndef TRITON_IR_OP_INTERFACES_H_ +#define TRITON_IR_OP_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +namespace triton { + +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op); + +} // namespace impl + +} // namespace triton +} // namespace mlir + +#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc" + +#endif // TRITON_IR_OP_INTERFACES_H_ diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index 7f0e5109e6b9..804b1648e943 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -69,9 +69,9 @@ class DotLike : public TraitBase { static LogicalResult verifyTrait(Operation *op) { if (op->getNumOperands() < 3) return op->emitOpError("expected at least 3 operands"); - auto aTy = cast(op->getOperand(0).getType()); - auto bTy = cast(op->getOperand(1).getType()); - auto cTy = cast(op->getOperand(2).getType()); + auto aTy = cast(op->getOperand(0).getType()); + auto bTy = cast(op->getOperand(1).getType()); + auto cTy = cast(op->getOperand(2).getType()); auto aShape = aTy.getShape(); auto bShape = bTy.getShape(); auto cShape = cTy.getShape(); diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index f3159338bd0a..04e4c25fd6d8 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// Type for F8F6F4 kind of floats. -def TT_F8F6F4TypeAttr : I32EnumAttr< - "F8F6F4Type", "", +// Type for ScaleDotElemType kind of floats. +def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", [ I32EnumAttrCase<"E4M3", 0, "e4m3">, I32EnumAttrCase<"E5M2", 1, "e5m2">, I32EnumAttrCase<"E2M3", 2, "e2m3">, I32EnumAttrCase<"E3M2", 3, "e3m2">, - I32EnumAttrCase<"E2M1", 4, "e2m1"> + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16"> ]>{ let cppNamespace = "::mlir::triton"; diff --git a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td new file mode 100644 index 000000000000..4208f966b357 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -0,0 +1,36 @@ +#ifndef TRITON_OP_INTERFACES +#define TRITON_OP_INTERFACES + +include "mlir/IR/OpBase.td" + + +def TransposeOpInterface : OpInterface<"TransposeOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a transpose. + It provides methods to access common properties such as the order attribute and the source operand. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Get the source operand of the transposition. + }], + /*retType=*/"::mlir::Value", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/[{ + Get the order of the transposition. + }], + /*retType=*/"::mlir::ArrayRef", + /*methodName=*/"getOrder", + /*args=*/(ins)> + ]; + + let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }]; +} + + +#endif // TRITON_OP_INTERFACES diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d3bb95ca959c..7ac54764066e 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -8,16 +8,12 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface -include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface -include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface -include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" // @@ -44,8 +40,7 @@ class TT_Op traits = []> : def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast int64 to pointer"; let arguments = (ins TT_I64Like:$src); @@ -58,8 +53,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast pointer to int64"; let arguments = (ins TT_PtrLike:$src); @@ -73,8 +67,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast between types of the same bitwidth"; let arguments = (ins TT_Type:$src); @@ -86,10 +79,10 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, // TODO: Add verifier } -def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, +def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise, + SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Floating point casting for custom types"; let description = [{ @@ -99,15 +92,17 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, }]; let arguments = ( - ins TT_FloatTensor:$src, + ins TT_FloatLike:$src, OptionalAttr:$rounding ); - let results = (outs TT_FloatTensor:$result); + let results = (outs TT_FloatLike:$result); let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; let hasVerifier = 1; + + let hasFolder = 1; } // @@ -115,8 +110,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, // def TT_ClampFOp : TT_Op<"clampf", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Clamp operation for floating point types"; let description = [{ @@ -146,8 +141,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise, // def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Precise sqrt for floating point types"; let description = [{ @@ -162,8 +157,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, } def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Precise div for floating point types"; let description = [{ @@ -178,8 +173,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, } def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Most significant N bits of the 2N-bit product of two integers"; let description = [{ @@ -197,17 +192,18 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, // Pointer Arith Ops // def TT_AddPtrOp : TT_Op<"addptr", - [Pure, - Elementwise, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding, - TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">]> { + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); let results = (outs TT_PtrLike:$result); let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; + let hasFolder = 1; } def TT_AdvanceOp : TT_Op<"advance", @@ -542,6 +538,7 @@ def TT_SplitOp : TT_Op<"split", [ } def TT_TransOp : TT_Op<"trans", [Pure, + TransposeOpInterface, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { @@ -575,16 +572,15 @@ def TT_TransOp : TT_Op<"trans", [Pure, }]; let arguments = ( - ins TT_TensorOrMemDesc:$src, + ins TT_Tensor:$src, DenseI32ArrayAttr:$order ); - let results = (outs TT_TensorOrMemDesc:$result); + let results = (outs TT_Tensor:$result); let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; let hasFolder = 1; - let hasVerifier = 1; } // @@ -673,9 +669,10 @@ def TT_DotOp : TT_Op<"dot", [Pure, // DotScaled Op // def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, - DotLike, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { + AttrSizedOperandSegments, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { let summary = "dot_scaled"; let description = [{ @@ -685,23 +682,23 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, let arguments = ( ins - // inputs are integer types as they are packed types and we currently - // don't have a representation for those. - TT_IntTensor:$lhs, - TT_IntTensor:$rhs, + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$lhs, + RankedTensorOf<[TT_Float,I8]>:$rhs, TT_FloatTensor:$c, - TT_IntTensor:$lhs_scale, - Optional:$rhs_scale, - TT_F8F6F4TypeAttr:$lhs_type, - TT_F8F6F4TypeAttr:$rhs_type + Optional>:$lhs_scale, + Optional>:$rhs_scale, + TT_ScaleDotElemTypeAttr:$lhs_type, + TT_ScaleDotElemTypeAttr:$rhs_type ); let results = (outs TT_FloatTensor:$d); // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file let assemblyFormat = [{ - $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict - `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) + $lhs (`scale` $lhs_scale^)? `,` $rhs (`scale` $rhs_scale^)? `,` $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict + `:` type($lhs) (`,` type($lhs_scale)^)? `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) }]; } @@ -727,6 +724,10 @@ def TT_ReduceOp: TT_Op<"reduce", llvm::SmallVector getInputTypes(); llvm::SmallVector getElementTypes(); unsigned getNumOperands(); + + // Returns the CombineOp iff this ReduceOp's region contains only + // one CombineOp other than the return, or nullptr if not applicable. + ::mlir::Operation *getSingleCombiner(); }]; } @@ -774,9 +775,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return", // External Elementwise op // def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, - SameOperandsAndResultEncoding, - SameVariadicOperandSize, - DeclareOpInterfaceMethods]> { + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { let description = [{ call an external function $symbol implemented in $libpath/$libname with $args @@ -788,6 +790,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, let results = (outs TT_Type:$result); let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + } // @@ -861,6 +869,32 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> { }]; } +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + }]; + + let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + // // Print Op // @@ -891,7 +925,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { `tt.assert` takes a condition tensor and a message string. If the condition is false, the message is printed, and the program is aborted. }]; - let arguments = (ins TT_Tensor:$condition, StrAttr:$message); + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; } @@ -938,6 +972,57 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", ]; } +// +// Make Tensor Descriptor Op +// +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { + let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; + + let description = [{ + `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, + and returns a descriptor object which can be used to load/store from the tensor in global memory. + }]; + + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides + ); + + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; + + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape)> + ]; + + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } + }]; +} + +def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> { + let summary = "Reinterpret a pointer as a tensor descriptor"; + + let description = [{ + This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. + Ideally, we can remove this once the APIs are fully fleshed out. + }]; + + let arguments = (ins TT_Ptr:$rawDesc); + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = [{ + $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result)) + }]; +} + // The following ops, including `call`, `func`, and `return` are copied and modified from // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td // We could revert it back once MLIR has a better inliner interface. @@ -1145,12 +1230,11 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable } -def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ - MemoryEffects<[MemRead]>]> { +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead]>]> { let summary = "Load from descriptor"; let description = [{ This operation will be lowered to Nvidia TMA load operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + `desc` is a tensor descriptor object. The destination tensor type and shape must match the descriptor otherwise the result is undefined. This is an escape hatch and is only there for testing/experimenting. @@ -1158,7 +1242,7 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ }]; let arguments = ( ins - TT_PtrType:$desc_ptr, + TT_TensorDescType:$desc, Variadic:$indices, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict @@ -1167,21 +1251,22 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ let results = (outs TT_Tensor:$result); let assemblyFormat = [{ - $desc_ptr `[` $indices `]` + $desc `[` $indices `]` oilist( `cacheModifier` `=` $cache | `evictionPolicy` `=` $evict ) - attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + attr-dict `:` qualified(type($desc)) `->` type($result) }]; } def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ - MemoryEffects<[MemRead, MemWrite]>]> { + MemoryEffects<[MemRead, MemWrite]>, +]> { let summary = "store value based on descriptor"; let description = [{ This operation will be lowered to Nvidia TMA store operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + `desc` is a tensor descriptor object. The shape and types of `src` must match the descriptor otherwise the result is undefined. This is an escape hatch and is only there for testing/experimenting. @@ -1189,14 +1274,14 @@ def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ }]; let arguments = ( ins - TT_PtrType:$desc_ptr, + TT_TensorDescType:$desc, TT_Tensor:$src, Variadic:$indices ); let assemblyFormat = [{ - $desc_ptr `[` $indices `]` `,` $src - attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) }]; } diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 4c709cd4420b..a70b97dbc879 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -92,53 +92,16 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; // Any Type in Triton IR def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; -// Memory descriptor type. -def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { - let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system"; +// Result type of ExperimentalMakeTensorDescriptor +def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> { + let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; - let description = [{ - Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. - If mutable memory is false that means the memory is constant and can only be allocated and stored once. - A constant memory allocation is different than a tensor as it can have multiple views and the descriptor - can be changed without changing the underlying memory. - }]; - - let parameters = (ins - ArrayRefParameter<"int64_t">:$shape, - "Type":$elementType, - "Attribute":$encoding, - "Attribute":$memorySpace, - "bool":$mutable_memory - ); - let extraClassDeclaration = [{ - MemDescType cloneWith(std::optional> shape, - Type elementType) const { - return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory()); - } - - bool hasRank() const { return true; } + let description = [{ + A portable abstraction for nvidia-TMA descriptors. }]; - let builders = [ - TypeBuilderWithInferredContext<(ins - "llvm::ArrayRef":$shape, - "Type":$elementType, - "Attribute":$encoding, - "Attribute":$memorySpace - ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false); - }]>, - TypeBuilderWithInferredContext<(ins - "llvm::ArrayRef":$shape, - "Type":$elementType, - "Attribute":$encoding, - "Attribute":$memorySpace, - "bool":$mutableMemory - ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory); - }]> - ]; - let hasCustomAssemblyFormat = 1; -} + let parameters = (ins "RankedTensorType":$blockType); + let assemblyFormat = "`<` $blockType `>`"; +} #endif diff --git a/include/triton/Dialect/Triton/IR/Types.h b/include/triton/Dialect/Triton/IR/Types.h index 74fa4ba961ac..6bcac9522ec3 100644 --- a/include/triton/Dialect/Triton/IR/Types.h +++ b/include/triton/Dialect/Triton/IR/Types.h @@ -8,8 +8,6 @@ #define GET_TYPEDEF_CLASSES #include "triton/Dialect/Triton/IR/Types.h.inc" -#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc" - namespace mlir { namespace triton { diff --git a/include/triton/Dialect/TritonCPU/CMakeLists.txt b/include/triton/Dialect/TritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/include/triton/Dialect/TritonCPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/include/triton/Dialect/TritonCPU/IR/Attributes.h b/include/triton/Dialect/TritonCPU/IR/Attributes.h new file mode 100644 index 000000000000..7d4b98019d50 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Attributes.h @@ -0,0 +1,9 @@ +#ifndef TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ + +#include "triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ diff --git a/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..ace7d4ee7439 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonCPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_cpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_cpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_cpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_cpu) +add_mlir_doc(TritonCPUDialect TritonCPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonCPUOps TritonCPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonCPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonCPUAttrDefs.td) +mlir_tablegen(TritonCPUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonCPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonCPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonCPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonCPUAttrDefsIncGen) diff --git a/include/triton/Dialect/TritonCPU/IR/Dialect.h b/include/triton/Dialect/TritonCPU/IR/Dialect.h new file mode 100644 index 000000000000..e8e8de322bb4 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Dialect.h @@ -0,0 +1,17 @@ +#ifndef TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonCPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Attributes.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonCPU/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.h.inc" + +#endif // TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td new file mode 100644 index 000000000000..57f6c7c9bd71 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td @@ -0,0 +1,24 @@ +#ifndef TRITONCPU_ATTRDEFS +#define TRITONCPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonCPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonCPU_AttrTrait : AttrInterface<"TritonCPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::cpu"; +} + +class TritonCPU_Attr traits = [], + Dialect dialect = TritonCPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{TritonCPU attr.}]; + let attrName = "triton.cpu." # attrMnemonic; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td new file mode 100644 index 000000000000..260db2743046 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -0,0 +1,32 @@ +#ifndef TRITONCPU_DIALECT +#define TRITONCPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonCPU_Dialect : Dialect { + let name = "triton_cpu"; + + let cppNamespace = "::mlir::triton::cpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton CPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "tensor::TensorDialect", + "mlir::memref::MemRefDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h b/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h new file mode 100644 index 000000000000..de27597a76ef --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h @@ -0,0 +1,6 @@ +#ifndef TRITON_CPU_DIALECT_INTERFACES_H +#define TRITON_CPU_DIALECT_INTERFACES_H + +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrInterfaces.h.inc" + +#endif // TRITON_CPU_DIALECT_INTERFACES_H diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td new file mode 100644 index 000000000000..b58fd9320354 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -0,0 +1,194 @@ +#ifndef TRITONCPU_OPS +#define TRITONCPU_OPS + +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUTypes.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +class TTC_Op traits = []> : + Op { +} + +// +// External Elementwise op +// +def TTC_ExternElementwiseOp : TTC_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { + + let description = [{ + Similar to TT_ExternElementwiseOp, but only supports calls to libsleef at the moment. + The string "%s(numel)" in $symbol will be interpolated with the number of elements of + the vector argument(s). + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TTC_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; +} + +def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { + let summary = "Extract base memref from a block pointer"; + + let description = [{ + Extract base memref from a block pointer. It covers whole base tensor memory, + not only the block referenced. Base pointer, shape, and strides are used + in the resulting memref. Offsets and block shape are ignored. + + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs AnyRankedOrUnrankedMemRef:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { + let summary = "Extract indices from a block pointer."; + + let description = [{ + Extract indices that can be used to access the block using its base memref. + Indices are supposed to be used for vector loads/stores with the base + memref extracted from the same block pointer. + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs Variadic:$result); + + let builders = [ + OpBuilder<(ins "Value":$src)> + ]; + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> { + let summary = "Build a memref for a pointer."; + + let description = [{ + Build memref with static shape, offset, strides, and specified base pointer. + }]; + + let arguments = (ins TT_Ptr:$src); + + let results = (outs AnyStaticShapeMemRef:$result); + + let hasCanonicalizer = 0; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + + +def TTC_LoadOp : TTC_Op<"load", [ + MemoryEffects<[MemRead]>, +]> { + let summary = "Load from a memref to triton tensor"; + + let description = [{ + Operation to allow load from allocated temporary buffer to triton tensor. + }]; + + let arguments = (ins AnyMemRef:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_StoreOp : TTC_Op<"store", [ + MemoryEffects<[MemWrite]>, +]> { + let summary = "Store triton tensor to memref"; + + let description = [{ + Operation to allow store triton tensor to allocated temporary buffer. + }]; + + let arguments = ( + ins + TT_Type:$src, + AnyMemRef:$dst + ); + + let assemblyFormat = "$src `,` $dst attr-dict `:` type($src) `,` type($dst)"; +} + +def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { + let summary = "Print at most a single scalar or vector (converted from tensor) on each line"; + + let description = [{ + For converting tensor types to vector types. + It only takes a single scalar or vector (tensor) element. + }]; + + let arguments = (ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$val, + DenseI32ArrayAttr:$isSigned + ); + + let assemblyFormat = [{ + $prefix attr-dict (`:` $val^ `:` type($val))? + }]; + + let hasVerifier = 1; +} + +def TTC_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "For correctness checking"; + let description = [{ + Takes a condition tensor, a message string, a file string, a function string, and a line number. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins I1:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + +def TTC_DotOp : TTC_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{Same as tt.dot but on vectors.}]; + + let arguments = ( + ins + TTC_Vector:$a, + TTC_Vector:$b, + TTC_Vector:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TTC_Vector:$d); + + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td new file mode 100644 index 000000000000..d6ac013804c8 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td @@ -0,0 +1,31 @@ +#ifndef TRITONCPU_TYPES +#define TRITONCPU_TYPES + +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTC_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTC_TokenType : TTC_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +def TTC_Vector : VectorOf<[TT_Float, TT_Int]>; + +def TTC_Type : AnyTypeOf<[TT_Float, TT_Int, TTC_Vector]>; + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/Types.h b/include/triton/Dialect/TritonCPU/IR/Types.h new file mode 100644 index 000000000000..e8c984628aa5 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONCPU_IR_TYPES_H_ +#define TRITONCPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/Attributes.h b/include/triton/Dialect/TritonGPU/IR/Attributes.h index a99ddfc17d22..1f93b3d935f2 100644 --- a/include/triton/Dialect/TritonGPU/IR/Attributes.h +++ b/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -5,6 +5,6 @@ #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #define GET_ATTRDEF_CLASSES -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" +#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc" #endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index 73c9401c18ed..a211c7bc8751 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,21 +1,26 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg) add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) -mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) -mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(TritonGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td) +mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(TritonGPUTypeInterfacesIncGen) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 74ea99b58891..85c789635a96 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -9,10 +9,10 @@ // TritonGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Types.h" #define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" namespace mlir { @@ -76,9 +76,8 @@ SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); // Returns the dimensions of the tensor from minor (fast-varying) to -// major (slow-varying). For blocked, mma, and dotOperand layouts, -// though the elements are in registers, the order refers to memory -// layout of the original tensor in global memory. +// major (slow-varying). For distributed layouts, this represents +// the order of the elements within a thread. // For shared Layout, the order refers to which dimension of the original tensor // is contiguous in shared memory. SmallVector getOrder(Attribute layout); @@ -117,9 +116,7 @@ SmallVector getCTAOrder(Attribute layout); * (3) In the implementation of emitIndices, ShapePerCTATile will * be replicated or wrapped to fit ShapePerCTA. */ -SmallVector -getShapePerCTATile(Attribute layout, - ArrayRef tensorShape = ArrayRef()); +SmallVector getShapePerCTATile(Attribute layout); SmallVector getShapePerCTA(ArrayRef CTASplitNum, ArrayRef shape); @@ -130,6 +127,17 @@ unsigned getNumWarpsPerCTA(Attribute layout); unsigned getNumCTAs(Attribute layout); +// Return the order that represents that the batch is in row-major or +// column-major order for a batch of matrices of shape [*, m, n] with +// len(shape) == rank. +SmallVector getMatrixOrder(unsigned rank, bool rowMajor); + +// Return the order that represents that the dot operand is in kMajor +// (contiguous in the inner dimension) or it's contiguous on the outer +// dimension. +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kMajor); + bool isExpensiveCat(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. @@ -141,6 +149,75 @@ triton::gpu::BlockedEncodingAttr getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, int numWarps, int threadsPerWarp, int numCTAs); +// For each output dimension d, ensure that the layout's output size (i.e., its +// codomain) does not exceed shape[d]. Do this without changing the size of the +// layout's inputs (i.e., leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +// +// We achieve this by setting the largest value in each output dimension d to 0 +// because bases that map to a location larger than shape[d] +// effectively duplicate along that dimension. For example, consider a layout +// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to +// shrink the output dimension size to 8: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16 +// +// In the first step, we shrink the output dimension size to 16 by setting +// L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// This means that lane=2 has the same data as lane=0. +// +// Now the output dimension of this layout has a size of 16, which is still +// larger than 8. We find the current largest value in the output dimension, +// which is L(register=1) = 8, and we set L(register=1) to 0: +// +// L(register=1) = 0 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// Now the output dimension of this layout has a size of 8, which is the desired +// size. Note that this method works only because the bases are powers of two, +// which is the case for DistributedLayouts If broadcastRegisters is false, we +// remove any register that's larger than the desired shape. In the example +// above we would have +// L(register=1) = 4 +// L(register=2) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters = true); + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape); + +SmallVector standardOutDimNames(MLIRContext *ctx, int rank); +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order); + // Dump information about which threads/registers contain each of the tensor // elements. void dumpLayout(RankedTensorType tensorType); diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 1367f65a031f..7c81b2496cdf 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -44,10 +44,6 @@ std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth = std::nullopt); -// Given a linear layout with input dims and output dims containing a "block" -// dimension, determines if the layout moves data across block boundaries. -bool isCrossCTAConversion(const LinearLayout &layout); - // Given a linear layout where the input dimensions contain a "block" dimension, // this method sets the "block" dimension to 0 and removes the corresponding // output dimensions. @@ -245,11 +241,11 @@ LinearLayout chooseShemLayoutForRegToRegConversion( // TODO(Keren): We should replace tensorTy with a LinearLayout and the element // bit width of the tensor in the future to support more flexible tensor // encodings -std::optional -chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, - ArrayRef repShape, - ArrayRef paddedRepShape, - ArrayRef order, int swizzleByteSize); +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c8512fce57fa..b900c3d2e3b7 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -2,8 +2,8 @@ #define TRITONGPU_ATTRDEFS include "mlir/IR/AttrTypeBase.td" -include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" //===----------------------------------------------------------------------===// // TritonGPU Attribute Definitions @@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute code extraBaseClassDeclaration = [{ unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; - ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; }]; } @@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to let genVerifyDecl = 1; let skipDefaultBuilders = 1; } - //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// @@ -346,23 +344,8 @@ compared to 1*64 when the hasLeadingOffset is false. // index of the inner dimension in `order` unsigned inner = (opIdx == 0) ? 0 : 1; - // ---- begin Volta ---- - if (mmaEnc.isVolta()) { - int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8)); - perPhase = std::max(perPhase, 1); - bool is_row = order[0] != 0; - bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) : - is_row && (shapePerCTA[order[0]] <= 16); - int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : - ((is_row && !is_vec4) ? 2 : 1); - int rep = 2 * pack_size; - int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; - int vec = 2 * rep; - return get(context, vec, perPhase, maxPhase, order, CTALayout); - } - - // ---- begin Ampere ---- - if (mmaEnc.isAmpere()) { + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; @@ -377,7 +360,7 @@ compared to 1*64 when the hasLeadingOffset is false. int k = (needTrans) ? matShape[0] : matShape[2]; int vec = (order[0] == rank-1) ? k : m; int mmaStride = (order[0] == rank-1) ? m : k; - int maxPhase = mmaStride / perPhase; + int maxPhase = std::max(mmaStride / perPhase, 1); return get(context, vec, perPhase, maxPhase, order, CTALayout); } @@ -390,20 +373,13 @@ compared to 1*64 when the hasLeadingOffset is false. int k = needTrans ? matShape[1] : matShape[2]; int vec = (order[0] == rank-1) ? n : k; int mmaStride = (order[0] == rank-1) ? k : n; - int maxPhase = mmaStride / perPhase; + int maxPhase = std::max(mmaStride / perPhase, 1); return get(context, vec, perPhase, maxPhase, order, CTALayout); } llvm_unreachable("invalid operand index"); } - // ---- begin version 3 ---- - if (mmaEnc.isHopper()) { - llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" - " is Hopper has not been implemented yet"); - return $_get(context, 1, 1, 1, order, CTALayout, true); - } - // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -481,9 +457,16 @@ layout = [0 4 8 12] [3 7 11 15] For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +We call each individual tile "rep". }]; let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrder">, + // Interface for the meta information about the multiple thread hierarchy. InterfaceMethod<"Get the shape of the CTAs per CGA.", "SmallVector", @@ -517,11 +500,6 @@ For the Threads Per Warp and Values Per Thread level, the linear id distribution "SmallVector", "getCTASplitNum">, - InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", - "SmallVector", - "getShapePerCTATile", - (ins "ArrayRef":$tensorShape)>, - InterfaceMethod<"Gets the number of contiguous elements per thread.", "SmallVector", "getContigPerThread">, @@ -570,6 +548,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + SmallVector getRepOrder() const; SmallVector getCTAsPerCGA() const; SmallVector getCTAOrder() const; SmallVector getCTASplitNum() const; @@ -579,12 +558,39 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getThreadOrder() const; SmallVector getSizePerThread() const; - SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; std::optional toLinearLayout(ArrayRef shape) const; }]; } +//===----------------------------------------------------------------------===// +// Linear Layout Encoding +//===----------------------------------------------------------------------===// + +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { + let mnemonic = "linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = (ins "LinearLayout":$linearLayout); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getContigPerThread() const; + SmallVector getOrder() const; + }]; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + + //===----------------------------------------------------------------------===// // Blocked Layout Encoding //===----------------------------------------------------------------------===// @@ -610,7 +616,7 @@ Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warp for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} @@ -636,7 +642,7 @@ Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warp [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} @@ -666,7 +672,7 @@ CTA [1,0] CTA [1,1] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] for -#triton_gpu.blocked_layout<{ +#ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} warpsPerCTA = {1, 2} @@ -770,7 +776,7 @@ for //===----------------------------------------------------------------------===// // MMA Layout Encoding //===----------------------------------------------------------------------===// -// TODO: MMAv1 and MMAv2 should be two instances of the same class + def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { let cppNamespace = "::mlir::triton::gpu"; let methods = [ @@ -779,26 +785,16 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "bool", "supportReduction">, - InterfaceMethod<"Return shape per CTA.", - "SmallVector", - "getShapePerCTATileForOperand", - (ins "ArrayRef":$tensorShape, - "int":$kWidth, - "int":$opIdx)>, - - InterfaceMethod<"Return total element size per thread for dot operands.", - "unsigned", - "getTotalElemsPerThreadForOperand", - (ins "ArrayRef":$tensorShape, - "Type":$eltTy, - "int":$kWidth, - "int":$opIdx)>, - InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", "getSizePerThreadForOperand", (ins "int":$opIdx, "int":$kWidth)>, + + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, ]; } @@ -917,10 +913,10 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; SmallVector getContigPerThread() { auto rank = getWarpsPerCTA().size(); @@ -1024,11 +1020,11 @@ Row | warp 0 warp 2 return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getElemsPerInstrForOperands() const; SmallVector getRepForOperand(ArrayRef operandShape, Type elemType, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; static SmallVector getMNKDimPerInstr(); SmallVector getContigPerThread() { @@ -1136,69 +1132,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: ArrayRefParameter<"unsigned">:$instrShape ); - let builders = [ - // Specially for MMAV1(Volta) - AttrBuilder<(ins "int":$versionMajor, - "int":$numWarps, - "CTALayoutAttr":$CTALayout, - "ArrayRef":$instrShape, - "ArrayRef":$shapeC, - "bool":$isARow, - "bool":$isBRow, - "bool":$isAVec4, - "bool":$isBVec4, - "int":$id), [{ - assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); - // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4] - int versionMinor = (isARow * (1<<0)) |\ - (isBRow * (1<<1)) |\ - (isAVec4 * (1<<2)) |\ - (isBVec4 * (1<<3)); - - // TODO: Share code with - // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the - // rep,spw and fpw. - SmallVector wpt({1, 1}); - SmallVector wpt_nm1; - - SmallVector rep(2), spw(2); - std::array fpw{{2, 2, 1}}; - int packSize0 = (isARow || isAVec4) ? 1 : 2; - rep[0] = 2 * packSize0; - spw[0] = fpw[0] * 4 * rep[0]; - - int packSize1 = (isBRow && !isBVec4) ? 2 : 1; - rep[1] = 2 * packSize1; - spw[1] = fpw[1] * 4 * rep[1]; - - do { - wpt_nm1 = wpt; - if (wpt[0] * wpt[1] < numWarps) - wpt[0] = std::clamp(wpt[0] * 2, 1, shapeC[0] / spw[0]); - if (wpt[0] * wpt[1] < numWarps) - wpt[1] = std::clamp(wpt[1] * 2, 1, shapeC[1] / spw[1]); - } while (wpt_nm1 != wpt); - - return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape); - }]>, - - - AttrBuilder<(ins "int":$versionMajor, - "int":$numWarps, - "CTALayoutAttr":$CTALayout, - "ArrayRef":$instrShape, - "ArrayRef":$shapeA, - "ArrayRef":$shapeB, - "ArrayRef":$shapeC, - "bool":$isARow, - "bool":$isBRow, - "int":$id), [{ - assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); - bool isAVec4 = !isARow && (shapeA[isARow] <= 16); - bool isBVec4 = isBRow && (shapeB[isBRow] <= 16); - return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id); - }]> - ]; let extraClassDeclaration = extraDistributedDeclaration # [{ bool isVolta() const; @@ -1206,26 +1139,10 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: bool isAmpere() const; bool isHopper() const; - unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; - - // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor - std::tuple decodeVoltaLayoutStates() const; - - // Number of bits in versionMinor to hold the ID of the MMA encoding instance. - // Here 5 bits can hold 32 IDs in a single module. - static constexpr int numBitsToHoldMmaV1ID{5}; - - // For MMA v1, method `getMMAv1IsRow` returns whether e.g. the a operand is used - // in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation - // section 9.7.13.4.1 for more details. - bool getMMAv1IsRow(int opIdx) const; - bool getMMAv1IsVec4(int opIdx) const; - int getMMAv1NumOuter(ArrayRef shape, int opIdx) const; - SmallVector getMMAv1Rep(int opIdx) const; - SmallVector getMMAv1ShapePerWarp(int opIdx) const; - int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2RepForOperand(ArrayRef shape, - int bitwidth, int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int kWidth, + int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; bool supportReduction() const { if (isAmpere() || isHopper()) { @@ -1234,11 +1151,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: return false; }; SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; - unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getContigPerThread() { - assert(isVolta() || isAmpere() || isHopper()); + assert(isAmpere() || isHopper()); auto rank = getWarpsPerCTA().size(); SmallVector contigPerThread(rank, 1); contigPerThread[rank - 1] = 2; @@ -1319,6 +1234,27 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) }]; let parameters = ( @@ -1329,16 +1265,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim ); let builders = [ - // Specially for MMAV1(Volta) AttrBuilder<(ins "unsigned":$opIdx, "Attribute":$parent, "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); - if (!parentAttr || !parentAttr.isAmpere()) + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) return $_get(context, opIdx, parent, 0); + // For MMAV2 and V3 unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 10f2c8c68828..be8487be1e2f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -4,7 +4,7 @@ include "mlir/IR/OpBase.td" def TritonGPU_Dialect : Dialect { - let name = "triton_gpu"; + let name = "ttg"; let cppNamespace = "::mlir::triton::gpu"; @@ -17,28 +17,27 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "mlir::gpu::GPUDialect", - "tensor::TensorDialect", ]; let extraClassDeclaration = [{ - static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static std::string getNumWarpsAttrName() { return "ttg.num-warps"; } static int getNumWarps(ModuleOp mod) { - if (!mod->hasAttr("triton_gpu.num-warps")) + if (!mod->hasAttr("ttg.num-warps")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); - return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + "TritonGPU module should contain a ttg.num-warps attribute"); + return cast(mod->getAttr("ttg.num-warps")).getInt(); } static int getNumCTAs(ModuleOp mod) { - if (!mod->hasAttr("triton_gpu.num-ctas")) + if (!mod->hasAttr("ttg.num-ctas")) return 1; - return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + return cast(mod->getAttr("ttg.num-ctas")).getInt(); } void registerTypes(); - static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; } static int getThreadsPerWarp(ModuleOp mod) { - Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); + Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp"); if(!threadsPerWarp) { return 32; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h index 9cf2876d2c31..1e76237dac02 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -1,6 +1,9 @@ #ifndef TRITON_GPU_DIALECT_INTERFACES_H #define TRITON_GPU_DIALECT_INTERFACES_H + +// clang-format off #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc" +// clang-format on #endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index a290cb20310a..9aa3e0b62667 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -3,10 +3,12 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType @@ -94,7 +96,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ let arguments = ( ins TT_PtrTensor:$src, - TT_MemDescType:$result, + TTG_MemDescType:$result, Optional:$mask, Optional:$other, DefaultValuedAttr:$cache, @@ -167,7 +169,7 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; } @@ -211,23 +213,48 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { Then in Python syntax, the subview covers input[1][0:4][4:8]. }]; let arguments = ( - ins TT_MemDescType:$src, Variadic:$offsets); + ins TTG_MemDescType:$src, Variadic:$offsets); - // Use qualified() otherwise "!tt.memdesc" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; - let results = (outs TT_MemDescType:$result); + let results = (outs TTG_MemDescType:$result); let hasVerifier = 1; } +def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, + TransposeOpInterface, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "transpose the descriptor"; + + let description = [{ + This operation returns a new descriptor + representing a transposed view of the buffer. + }]; + + let arguments = (ins TTG_MemDescType:$src, Variadic:$order); + + let arguments = ( + ins TTG_MemDescType:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))"; + + let hasFolder = 1; +} + def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods]> { let summary = "Load a buffer from local memory into a distributed tensor"; let description = [{ Load a tensor from the local memory descriptor into a distributed tensor. }]; - let arguments = (ins TT_MemDescType:$src, Optional :$token); + let arguments = (ins TTG_MemDescType:$src, Optional :$token); let builders = [ OpBuilder<(ins "Type":$retType, "Value":$src), @@ -235,7 +262,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods(nullptr)); }]>]; - // Use qualified() otherwise "!tt.memdesc" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; let results = (outs TT_Tensor:$result); @@ -247,10 +274,10 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods" is printed as "". + // Use qualified() otherwise "!ttg.memdesc" is printed as "". let assemblyFormat = [{ $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) }]; @@ -268,7 +295,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods]>]> { + let summary = "allocate a global memory buffer"; + let description = [{ + This operation allocates a buffer in global memory that is private to the current program. + }]; + let arguments = ( + ins + I32Attr:$nbytes, + I32Attr:$alignment + ); + let results = (outs TT_Ptr:$result); + + let builders = [ + OpBuilder<(ins "Type":$result, "int32_t":$nbytes, "int32_t":$alignment), + [{ build($_builder, $_state, result, + $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]> + ]; + + let assemblyFormat = [{attr-dict `:` qualified(type($result))}]; +} + #endif diff --git a/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td similarity index 75% rename from include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td rename to include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td index e3aed226277c..a0415b62c632 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td @@ -1,11 +1,11 @@ -#ifndef TRITON_TYPE_INTERFACES -#define TRITON_TYPE_INTERFACES +#ifndef TRITON_GPU_TYPE_INTERFACES +#define TRITON_GPU_TYPE_INTERFACES include "mlir/IR/OpBase.td" // Interface dynamically attached to RankedTensorType and MemDescType. -def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { - let cppNamespace = "::mlir"; +def TTG_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir::triton::gpu"; let methods = [ InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", "mlir::Attribute", "getEncoding", (ins)>, @@ -17,8 +17,7 @@ def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { "int64_t", "getRank", (ins)>, InterfaceMethod<"Returns the element type bit width", "int64_t", "getElementTypeBitWidth", (ins)>, - ]; } -#endif // TRITON_TYPE_INTERFACES +#endif // TRITON_GPU_TYPE_INTERFACES diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td index 6765ac40cbbe..8061a98797b7 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -1,8 +1,9 @@ #ifndef TRITONGPU_TYPES #define TRITONGPU_TYPES -include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" class TTG_TypeDef traits = []> : TypeDef { @@ -23,8 +24,7 @@ def TTG_TokenType : TTG_TypeDef<"Token", "token"> { let skipDefaultBuilders = 1; } -def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", - "async.token", []> { +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", "async.token", []> { let summary = "async token type"; let description = [{ `ttg.async.token` is a type returned by an asynchronous operation. @@ -33,4 +33,69 @@ def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", }]; } +// Memory descriptor type. +def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::gpu::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + ArrayRefParameter<"int64_t">:$allocShape + ); + + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape()); + } + + bool hasRank() const { return true; } + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + "llvm::ArrayRef":$allocShape + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape); + }]> + + ]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + + #endif diff --git a/include/triton/Dialect/TritonGPU/IR/Types.h b/include/triton/Dialect/TritonGPU/IR/Types.h index edf37fef606d..82ab3ae457d5 100644 --- a/include/triton/Dialect/TritonGPU/IR/Types.h +++ b/include/triton/Dialect/TritonGPU/IR/Types.h @@ -1,10 +1,13 @@ #ifndef TRITONGPU_IR_TYPES_H_ #define TRITONGPU_IR_TYPES_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" #define GET_TYPEDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/Types.h.inc" +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc" + #endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index f2b79d222a91..9020bf8d3994 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -23,6 +23,38 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { ]; } +def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> { + let summary = "test assigning latencies to interesting ops ahead of pipelining"; + + let description = [{ + This is a test pass that tests `assignLatencies` method of `TritonGPULoopScheduling`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> { + let summary = "test scheduling a loop for software pipelining"; + + let description = [{ + This is a test pass that tests `scheduleLoop` method of `TritonGPULoopScheduling`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { let summary = "3xTF32 trick"; @@ -179,4 +211,29 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init" "mlir::triton::TritonDialect"]; } +def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> { + let summary = "Generate loop scheduling for SWP"; + + let description = "This pass sets up stages and clustering for software pipelining."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> { + let summary = "Improve coalescing for async global to local copies"; + + let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than " + "the blocked encoding's sizePerThread, this pass improves coalescing by clipping the " + "sizePerThread value"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index 88f062a01023..cdf22d15d499 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -8,6 +8,11 @@ namespace mlir { namespace triton { static const char *kNumStagesAttrName = "tt.num_stages"; +static const char *kLoopStageAttrName = "loop.stage"; +static const char *kLoopClusterAttrName = "loop.cluster"; + +bool loopHasDistGreaterThanOne(scf::ForOp forOp); +bool isOuterLoop(scf::ForOp forOp); /// Function to mask operations during scheduling. Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); @@ -29,6 +34,11 @@ void addOps(scf::ForOp forOp, int stage, /// mutable. void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, Value val); + +// Return the minClusterId and maxClusterId for the given ForOp. +std::pair getMinMaxCluster(scf::ForOp &forOp); +std::pair getStageCluster(Operation *op); +void setStageCluster(Operation *op, int stage, int cluster); } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index 1dd1fc686034..916c9b252267 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -11,6 +11,18 @@ namespace mlir { namespace triton { +namespace gpu { + +/// Discover operations that should become async and assign latencies to them +/// based on the numStages value provided by the user. +DenseMap assignLatencies(ModuleOp forOp, int numStages); + +/// Schedule the loop based on the latencies assigned to the operations. +void scheduleLoop(scf::ForOp forOp, + const DenseMap &opLatency); + +}; // namespace gpu + /// This fill out the pipelining options including schedule and annotations /// for wait ops. This also does pre-processing by converting some of the /// loads into async loads so that the IR is ready to be pipelined. @@ -100,8 +112,16 @@ class CoarseSchedule { std::vector> createFinalSchedule(scf::ForOp forOp); void dump(); + bool empty() { return opToStageAndCluster.size() == 0; } + void serialize(scf::ForOp &forOp); + // Create a CoarseSchedule based on forOp's . + void deSerialize(scf::ForOp &forOp); }; +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule); + } // namespace triton } // namespace mlir #endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index e688b52245ee..0f6bd57afaf1 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -192,6 +192,21 @@ bool isPureUnaryInlineAsm(Operation *op); // read the compute capability from the module attributes int getNVIDIAComputeCapability(Operation *module); +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); + +enum class MMALoadType { + SharedV3, + Registers, // may be v2 or v3 + DoNotPipeline, // could be a valid shared/registers MMA operand, but skip + // pipelining +}; +MMALoadType getMMALoadType(Operation *loadOp); + +// Returns composed LinearLayout for register to shared copy +std::optional +getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, + Attribute srcEnc, Attribute dstEnc, int elemBitWidth); } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index b7ce83fe7ea6..45c70e15c2e9 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,12 +1,12 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng) add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonNvidiaGPUTableGen) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td index 67ece715d2f6..ff19458e4754 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td" def TritonNvidiaGPU_Dialect : Dialect { - let name = "triton_nvidia_gpu"; + let name = "ttng"; let cppNamespace = "::mlir::triton::nvidia_gpu"; @@ -39,22 +39,21 @@ def TritonNvidiaGPU_Dialect : Dialect { "triton::TritonDialect", "triton::gpu::TritonGPUDialect", "mlir::gpu::GPUDialect", - "tensor::TensorDialect", ]; let extraClassDeclaration = [{ - static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static std::string getNumWarpsAttrName() { return "ttg.num-warps"; } static int getNumWarps(ModuleOp mod) { - if(!mod->hasAttr("triton_gpu.num-warps")) + if(!mod->hasAttr("ttg.num-warps")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-warps attribute"); - return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + "TritonGPU module should contain a ttg.num-warps attribute"); + return cast(mod->getAttr("ttg.num-warps")).getInt(); } static int getNumCTAs(ModuleOp mod) { - if(!mod->hasAttr("triton_gpu.num-ctas")) + if(!mod->hasAttr("ttg.num-ctas")) llvm::report_fatal_error( - "TritonGPU module should contain a triton_gpu.num-ctas attribute"); - return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + "TritonGPU module should contain a ttg.num-ctas attribute"); + return cast(mod->getAttr("ttg.num-ctas")).getInt(); } void registerTypes(); }]; diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 243b934367ad..f363032a3748 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -28,7 +28,8 @@ include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" -include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType @@ -80,8 +81,8 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods:$useC, DefaultValuedAttr:$inputPrecision, @@ -100,8 +101,8 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods, AllTypesMatch<["inputs", "outputs"]>]> { let summary = "warp group dot wait"; - let arguments = (ins Variadic:$inputs, I32Attr:$pendings); - let results = (outs Variadic:$outputs); + let arguments = (ins Variadic:$inputs, I32Attr:$pendings); + let results = (outs Variadic:$outputs); let description = [{ Waits until there are $pendings or fewer outstanding async dot operations. @@ -125,7 +126,7 @@ def TTNG_InitBarrierOp : TTNG_Op<"init_barrier", [DeclareOpInterfaceMethods { + let summary = "Convert tensor descriptor to pointer to tma descriptor"; + + let arguments = (ins TT_TensorDescType:$desc); + let results = (outs TT_Ptr:$ptr); + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr)) + }]; + + let builders = [ + OpBuilder<(ins "Value":$desc), [{ + auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1); + build($_builder, $_state, ptrTy, desc); + }]> + ]; + + let hasCanonicalizeMethod = 1; +} + def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods]> { let summary = "copy data based on descriptor from global memory to local memory asynchronously"; @@ -201,8 +222,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", let arguments = ( ins TT_PtrType:$desc_ptr, Variadic:$coord, - TT_MemDescType:$barrier, - TT_MemDescType:$result, + TTG_MemDescType:$barrier, + TTG_MemDescType:$result, I1:$pred, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict, @@ -230,7 +251,7 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global", let arguments = ( ins TT_PtrType:$desc_ptr, Variadic:$coord, - TT_MemDescType:$src); + TTG_MemDescType:$src); let assemblyFormat = [{ $desc_ptr `[` $coord `]` $src diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 41a3621a971e..9ddec8881269 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -9,6 +9,7 @@ #include #include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -432,6 +433,7 @@ class LinearLayout { // (e.g. by reshaping) then the order doesn't really affect anything. auto getInDimNames() const { return llvm::make_first_range(bases); } auto getOutDimNames() const { return llvm::make_first_range(outDims); } + auto getOutDimSizes() const { return llvm::make_second_range(outDims); } // Gets the position that this outDim occupies in getOutDimNames(). Asserts // if the dim is not present. @@ -575,29 +577,20 @@ class LinearLayout { return *this; } - // divideLeft and divideRight are the inverses of operator*. - // - // Consider `a = c.divideRight(b)`, where `a` is a linear layout with - // `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove - // some empty dimensions from `a` to form `a'` and still have `a' * b == c`. - // Therefore, there are multiple possible values that we could return for - // `(a * b).divideRight(b)` which would satisfy - // `((a * b).divideRight(b)) * b == a * b`. - // - // In the following example, we have `a * b == a' * b` when "in1" is an empty - // dimension that maps everything to 0: - // - // a = L("in1", "in2") -> ("out1", "out2") - // a' = L("in1") -> ("out1") - // b = L("in2") -> ("out2") - // - // divideLeft and divideRight resolve this ambiguity by always returning the - // "canonical" quotient, namely the one with the fewest possible size-zero - // input and output dimensions. - // - // TODO(jlebar): Implement divideLeft. - // std::optional divideLeft(const LinearLayout &divisor); - std::optional divideRight(const LinearLayout &divisor) const; + // Returns true if this layout acts trivially (as the identity) on the given + // dimensions. This means that it's the identity on those dimensions, and it + // does not map other dimensions onto those or these onto other dimensions. + bool isTrivialOver(ArrayRef dimNames) const; + + // For an endomorphism on dimNames (linear map that maps dimNames to dimNames) + // checks whether it is the identity map on these dimensions (i.e + // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the + // remaining dimensions. + // nb. The isTrivialOver condition is more restrictive than the usual + // "leaves the subspace invariant" condition in maths. + // We can always relax it if we know how to take advantage of a conversion + // layout being block-diagonal in the future. + std::optional quotient(ArrayRef dimNames) const; // Gets a layout with only these in/out dimensions. // @@ -614,10 +607,10 @@ class LinearLayout { bool sublayoutIsZero(ArrayRef inDimNames, ArrayRef outDimNames) const; - // Is the sublayout restricted to inDimNames + outDimNames and then flattened - // to 1D the identity layout (ignoring out-dim sizes)? - bool sublayoutIsIdentity(ArrayRef inDimNames, - ArrayRef outDimNames) const; + // Is the sublayout defined from dimNames to dimNames the identity? + // In particular, is the input and output size in these dimensions + // the same, and are the bases the identity? + bool squareSublayoutIsIdentity(ArrayRef dimNames) const; // Computes and returns L(x, y, z). // @@ -695,6 +688,7 @@ class LinearLayout { return !(lhs == rhs); } bool equalIgnoringOutDimSizes(const LinearLayout &other) const; + friend size_t hash_value(const LinearLayout &layout); private: // Factory function that gracefully fails rather than asserts if the layout is diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 43e7df13585c..c0d845eb6843 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -13,6 +13,7 @@ namespace mlir::triton { inline const std::set CACHE_INVALIDATING_ENV_VARS = { // clang-format off "AMDGCN_ENABLE_DUMP", + "AMDGCN_USE_BUFFER_OPS", "DISABLE_FAST_REDUCTION", "DISABLE_LLVM_OPT", "DISABLE_MMA_V3", @@ -27,6 +28,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_DISABLE_LINE_INFO", "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_HIP_STREAM_PREFETCH", "TRITON_LLVM_DEBUG_ONLY", "USE_IR_LOC", "NVPTX_ENABLE_DUMP", diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 3840bf4199e5..020f513bacf1 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -28,7 +28,7 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation( bool pessimistic = true; auto result = op->getResult(0); // skip ops that return memdesc in a different memory space. - if (auto memdescTy = dyn_cast(result.getType())) { + if (auto memdescTy = dyn_cast(result.getType())) { if (!isa_and_nonnull( memdescTy.getMemorySpace())) return success(); @@ -38,13 +38,12 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation( if (isa(op)) { aliasInfo.insert(result); pessimistic = false; - } else if (isa(op)) { - // extract_slice %src - // trans %src + } else if (isa( + op)) { aliasInfo = AliasInfo(operands[0]->getValue()); pessimistic = false; } else { - assert(!isa(result.getType()) && + assert(!isa(result.getType()) && "unknown operation creating memory descriptor"); } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 276a6e7004df..c79e81e65ca6 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -4,9 +4,7 @@ #include #include -#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/Liveness.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Alias.h" @@ -15,19 +13,6 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" -using ::mlir::triton::gpu::AMDMfmaEncodingAttr; -using ::mlir::triton::gpu::BlockedEncodingAttr; -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getContigPerThread; -using ::mlir::triton::gpu::getOrder; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; -using ::mlir::triton::gpu::getSizePerThread; -using ::mlir::triton::gpu::getUniqueContigPerThread; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -using ::mlir::triton::gpu::SharedEncodingAttr; -using ::mlir::triton::gpu::SliceEncodingAttr; - namespace mlir { //===----------------------------------------------------------------------===// @@ -38,27 +23,6 @@ namespace triton { // Bitwidth of pointers constexpr int kPtrBitWidth = 64; -static std::pair, SmallVector> -getCvtOrder(Attribute srcLayout, Attribute dstLayout) { - auto srcMmaLayout = mlir::dyn_cast(srcLayout); - auto srcDotLayout = mlir::dyn_cast(srcLayout); - auto dstMmaLayout = mlir::dyn_cast(dstLayout); - auto dstDotLayout = mlir::dyn_cast(dstLayout); - - assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() && - !srcMmaLayout.isHopper()) && - "mma -> mma layout conversion is only supported on Ampere"); - - // mma or dot layout does not have an order, so the order depends on the - // layout of the other operand. - auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) - : getOrder(srcLayout); - auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) - : getOrder(dstLayout); - - return {inOrd, outOrd}; -} - static SmallVector getRepShapeForCvt(RankedTensorType srcTy, RankedTensorType dstTy) { Attribute srcLayout = srcTy.getEncoding(); @@ -70,15 +34,18 @@ static SmallVector getRepShapeForCvt(RankedTensorType srcTy, if (shouldUseDistSmem(srcLayout, dstLayout)) { // TODO: padding to avoid bank conflicts - return convertType(getShapePerCTA(srcTy)); + return convertType(gpu::getShapePerCTA(srcTy)); } assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()"); - auto srcShapePerCTA = getShapePerCTA(srcTy); - auto dstShapePerCTA = getShapePerCTA(dstTy); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); + auto srcShapePerCTA = gpu::getShapePerCTA(srcTy); + auto dstShapePerCTA = gpu::getShapePerCTA(dstTy); + auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout); + + assert(srcTy.getRank() == dstTy.getRank() && + "src and dst must have the same rank"); unsigned rank = dstTy.getRank(); SmallVector repShape(rank); @@ -113,20 +80,16 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - assert(!isMfmaToDotShortcut(srcTy, dstTy)); + assert(cvtNeedsSharedMemory(srcTy, dstTy)); - // FIXME This is NOT entirely correct - // This should be getElemOrder, but we don't have such a method - // TODO Implement getElemOrder and make sure it's consistent with - // getContigPerThread - auto inOrd = gpu::getThreadOrder(srcLayout); - auto outOrd = gpu::getThreadOrder(dstLayout); + const auto &inOrd = gpu::getOrder(srcLayout); + const auto &outOrd = gpu::getOrder(dstLayout); scratchConfig.order = outOrd; unsigned srcContigPerThread = - getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; unsigned dstContigPerThread = - getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; // TODO: Fix the legacy issue that ourOrd[0] == 0 always means // that we cannot do vectorization. unsigned innerDim = rank - 1; @@ -135,17 +98,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, : srcContigPerThread; scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; - if (auto mma = mlir::dyn_cast(srcLayout)) { - if (mma.getVersionMajor() == 1) { - // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the - // codegen. - scratchConfig.inVec = srcContigPerThread; - } else if (mlir::isa(dstLayout)) { - // when storing from mma layout and loading in blocked layout vectorizing - // the load back gives better performance even if there is a - // transposition. - scratchConfig.outVec = dstContigPerThread; - } + if (mlir::isa(srcLayout) && + mlir::isa(dstLayout)) { + // when storing from mma layout and loading in blocked layout vectorizing + // the load back gives better performance even if there is a + // transposition. + scratchConfig.outVec = dstContigPerThread; } // No padding is required if the tensor is 1-D, or if all dimensions except @@ -158,13 +116,74 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, return scratchConfig; } +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + return helper.getScratchSizeInBytes(); + } + if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + return helper.getScratchSizeInBytes(); + } + if (auto gatherOp = dyn_cast(op)) { + GatherLoweringHelper helper(gatherOp); + return helper.getScratchSizeInBytes(); + } + if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + return std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + } + if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return 0; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + auto elems = getNumScratchElements(scratchConfig.paddedRepShape); + return isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + } + if (isa(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + return 0; + } + auto smemShape = getRepShapeForAtomic(op->getResult(0)); + auto elems = getNumScratchElements(smemShape); + auto elemTy = cast(value.getType()).getPointeeType(); + assert(!isa(elemTy) && "unexpected pointer type"); + return elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + } + if (auto createTensormap = dyn_cast(op)) { + constexpr int32_t kTMASize = 128; + return kTMASize; + } + return 0; +} + class AllocationAnalysis { public: AllocationAnalysis(Operation *operation, Allocation::FuncAllocMapT *funcAllocMap, - Allocation *allocation) + Allocation *allocation, + AllocationAnalysisScratchSizeFn scratchSizeGetter) : operation(operation), funcAllocMap(funcAllocMap), - allocation(allocation) { + allocation(allocation), scratchSizeGetter(scratchSizeGetter) { run(); } @@ -186,12 +205,12 @@ class AllocationAnalysis { /// Initializes explicitly defined shared memory values for a given operation. void getExplicitValueSize(Operation *op) { for (Value result : op->getResults()) { - auto alloc = result.getDefiningOp(); + auto alloc = result.getDefiningOp(); if (alloc && alloc.isSharedMemoryAlloc()) { // Bytes could be a different value once we support padding or other // allocation policies. auto allocType = alloc.getType(); - auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); + auto shapePerCTA = gpu::getShapePerCTA(allocType); auto bytes = product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; @@ -217,79 +236,19 @@ class AllocationAnalysis { /// Initializes temporary shared memory for a given operation. void getScratchValueSize(Operation *op) { - const size_t scratchAlignment = 128; - if (auto reduceOp = dyn_cast(op)) { - ReduceOpHelper helper(reduceOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto scanOp = dyn_cast(op)) { - ScanLoweringHelper helper(scanOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto histogram = dyn_cast(op)) { - auto dstTy = histogram.getType(); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( - op->getParentOfType()); - auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * - std::max(8, dstTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto cvtLayout = dyn_cast(op)) { - auto srcTy = cvtLayout.getSrc().getType(); - auto dstTy = cvtLayout.getType(); - auto srcEncoding = srcTy.getEncoding(); - auto dstEncoding = dstTy.getEncoding(); - if (mlir::isa(srcEncoding) || - mlir::isa(dstEncoding)) { - // Conversions from/to shared memory do not need scratch memory. - return; - } - // ConvertLayoutOp with both input/output non-shared_layout - // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's - // also possible to realize it with other approaches in restricted - // conditions, such as warp-shuffle - auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); - auto elems = getNumScratchElements(scratchConfig.paddedRepShape); - auto bytes = - isa(srcTy.getElementType()) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (isa(op)) { - auto value = op->getOperand(0); - // only scalar requires scratch memory - // make it explicit for readability - if (dyn_cast(value.getType())) { - // nothing to do - } else { - auto smemShape = getRepShapeForAtomic(op->getResult(0)); - auto elems = getNumScratchElements(smemShape); - auto elemTy = - cast(value.getType()).getPointeeType(); - auto bytes = - isa(elemTy) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } - } else if (auto callOp = dyn_cast(op)) { + constexpr size_t scratchAlignment = 128; + if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto funcOp = dyn_cast(callable); auto *funcAlloc = &(*funcAllocMap)[funcOp]; auto bytes = funcAlloc->getSharedMemorySize(); maybeAddScratchBuffer(op, bytes, scratchAlignment); - } else if (auto createTensormap = - dyn_cast(op)) { - constexpr int32_t kTMASize = 128; - constexpr int32_t kTMAAlign = 128; - maybeAddScratchBuffer(op, kTMASize, - kTMAAlign); + return; } + unsigned bytes = scratchSizeGetter(op); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { @@ -346,7 +305,7 @@ class AllocationAnalysis { /// arguments are involved. void resolveAliasBufferLiveness( function_ref(Value value)> getLiveness) { - for (auto aliasBufferIter : allocation->aliasBuffer) { + for (const auto &aliasBufferIter : allocation->aliasBuffer) { auto value = aliasBufferIter.first; auto buffers = aliasBufferIter.second; auto range = getLiveness(value); @@ -486,7 +445,7 @@ class AllocationAnalysis { std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { auto xRange = bufferRange[buffer]; bool res = xRange.intersects(range); - for (auto val : tripleMap) + for (const auto &val : tripleMap) res = res && !val.second.intersects(xRange); // only one buffer intersect return res; @@ -589,12 +548,16 @@ class AllocationAnalysis { Allocation::FuncAllocMapT *funcAllocMap; Allocation *allocation; BufferRangeMapT bufferRange; + AllocationAnalysisScratchSizeFn scratchSizeGetter; }; } // namespace triton -void Allocation::run(FuncAllocMapT &funcAllocMap) { - triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +void Allocation::run( + FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this, + scratchSizeGetter); } std::map> diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index f0c5ae3167ec..fc6a2c73befc 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1084,9 +1084,11 @@ LogicalResult AxisInfoAnalysis::visitOperation( void AxisInfoAnalysis::visitForOpInductionVar( scf::ForOp op, ArrayRef *> argLattices) { - ProgramPoint programPoint(op); - auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue(); - auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue(); + ProgramPoint *programPoint = getProgramPointAfter(op); + const auto &lb = + getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + const auto &step = + getLatticeElementFor(programPoint, op.getStep())->getValue(); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index a84f0649b623..693d222f2f39 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -7,7 +7,9 @@ add_triton_library(TritonAnalysis DEPENDS TritonTableGen + TritonGPUTableGen TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen LINK_LIBS PUBLIC MLIRAnalysis diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 7b2f7a4f6d05..3a8be9ee3347 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -69,18 +69,25 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { } unsigned threadOffset = 1; - if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { - auto parentLayout = sliceLayout.getParent(); - auto threadsPerWarp = getThreadsPerWarp(parentLayout); - threadOffset = threadsPerWarp[sliceLayout.getDim()]; - } else { - auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getThreadOrder(srcLayout); - for (unsigned i = 0; i < order.size(); i++) { - if (order[i] == axis) - break; - threadOffset *= threadsPerWarp[order[i]]; - } + SmallVector dimsRemoved; + while (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + dimsRemoved.push_back(sliceLayout.getDim()); + srcLayout = sliceLayout.getParent(); + } + // In case of slice layout we want to know the axis dimension relative to the + // most inner parent layout. `adjustedAxis` is the matching axis dim in the + // parent layout. + int adjustedAxis = axis; + for (auto dim : dimsRemoved) { + if (dim <= adjustedAxis) + adjustedAxis++; + } + auto threadsPerWarp = getThreadsPerWarp(srcLayout); + auto order = getThreadOrder(srcLayout); + for (unsigned i = 0; i < order.size(); i++) { + if (order[i] == adjustedAxis) + break; + threadOffset *= threadsPerWarp[order[i]]; } return threadOffset; } @@ -401,6 +408,17 @@ unsigned ScanLoweringHelper::getAxisBlockStride() { llvm_unreachable("Axis not found in order"); } +GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) + : gatherOp(gatherOp) {} + +unsigned GatherLoweringHelper::getScratchSizeInBytes() { + // For now, lower the gather op by writing the source tensor to shared memory. + // TODO(jeff): Leverage locality to avoid using scratch space when possible. + RankedTensorType srcType = gatherOp.getSrc().getType(); + return product(srcType.getShape()) * + ceil(srcType.getElementTypeBitWidth(), 8); +} + unsigned getNumScratchElements(ArrayRef shape) { if (shape.empty()) return 0; @@ -526,7 +544,8 @@ bool supportMMA(Value value, int version) { // types of both the operands are identical here. assert((version == 1 || version == 2 || version == 3) && "Unexpected MMA layout version found"); - auto elemTy = cast(value.getType()).getElementType(); + auto elemTy = + cast(value.getType()).getElementType(); // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || @@ -536,7 +555,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -605,22 +624,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { return matrixDimsCompatible && bDimCompatible; } -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - if (mfmaLayout == nullptr || dotOperandLayout == nullptr) - return false; - // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is - // improved. In addition, we can enable this shortcut for regular MFMA - // layout when opIdx == 1. - return mfmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && - dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && - dotOperandLayout.getParent() == mfmaLayout && - (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && - (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); -} - // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -636,73 +639,76 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, dotOperandLayout.getOpIdx() == 0 && mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && - (elementTypeSize == 16 || elementTypeSize == 8); + (elementTypeSize == 16 || elementTypeSize == 8) && + dotOperandLayout.getKWidth() == 32 / elementTypeSize; return ans; } -bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { +// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity +// under kBlock, kWarp or kLane (in that order). The idea here is that if we +// have a transformation that's the identity on kBlock, we don't need to use +// distributed shared memory. If it's also the identity on kWarp, we can +// transfer via warp-shuffles, and if it's the identity on kLane just have to +// reorder the registers +std::optional minimalCvtLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { MLIRContext *ctx = srcTy.getContext(); std::optional srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); std::optional dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (srcLayout.has_value() && dstLayout.has_value()) { - // comp describes the layout function for converting from src to dst. - LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - // TODO(jlebar): These checks are overly-restrictive. For example, we can - // transfer by shuffling registers (case 1) if and only if all of the bases - // for `register` have 0s for lane, warp, and block. But the check below is - // stronger than this, checking also that the choice of lane/warp/block does - // not affect the permutation of registers. If we allow different - // lane/warp/blocks to have different permutations, we can generalize this. - if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane), - kLane, kLane) * - LinearLayout::identity1D(comp.getInDimSize(kWarp), - kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), - kBlock, kBlock)) - .has_value()) { - return true; + if (!(srcLayout.has_value() && dstLayout.has_value())) + return std::nullopt; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + auto comp = dstLayout->invertAndCompose(*srcLayout); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { + auto quotient = comp.quotient(StringAttr::get(ctx, dim)); + if (!quotient.has_value()) { + break; } + comp = *quotient; } - return false; + return comp; +} + +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + if (!layout.has_value()) { + return false; + } + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(layout->getOutDimNames()); + return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); } bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); MLIRContext *ctx = srcTy.getContext(); - std::optional srcLayout = - toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (srcLayout.has_value() && dstLayout.has_value()) { - // comp describes the layout function for converting from src to dst. - LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp), - kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), - kBlock, kBlock)) - .has_value()) { - return true; - } + if (!layout.has_value()) { + return false; } - return false; + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + return llvm::to_vector(layout->getOutDimNames()) == + llvm::SmallVector{kRegister, kLane}; } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, - // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully - // subsumed by the linear-layout checks. + // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and + // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout + // checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !isMmaToDotShortcut(srcTy, dstTy) && - !isMfmaToDotShortcut(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { @@ -712,20 +718,6 @@ bool atomicNeedsSharedMemory(Value value) { return true; } -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) - return true; - // dot_op = #mma - // when #mma = MmaEncoding - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && - mmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getParent() == mmaLayout && - !srcTy.getElementType().isF32(); -} - namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c58b7fa0a347..e8ae340f2d81 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(Target) add_subdirectory(Tools) +add_subdirectory(Instrumentation) diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index aae9faf0ee49..0115383947d7 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -44,7 +44,7 @@ struct AllocateSharedMemory IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); }); - mod->setAttr("triton_gpu.shared", + mod->setAttr("ttg.shared", mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), allocation.getSharedMemorySize())); } diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index 20558c440add..7a3c8ce27abd 100644 --- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -35,6 +35,14 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { } } llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + barrier(); + } rewriter.eraseOp(op); return success(); } @@ -42,7 +50,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { // know about the op to split the block. void llAssert(Operation *op, Value condition, StringRef message, ConversionPatternRewriter &rewriter) const { - ConversionPatternRewriter::InsertionGuard guard(rewriter); auto ctx = rewriter.getContext(); auto loc = op->getLoc(); @@ -79,6 +86,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { rewriter.create(loc, thenBlock); rewriter.setInsertionPointToEnd(prevBlock); rewriter.create(loc, condition, ifBlock, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); } protected: diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 4d57131d029c..d6cc4387f79e 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp DotOpToLLVM/FMA.cpp + GlobalScratchMemoryAllocation.cpp TypeConverter.cpp Utility.cpp ElementwiseOpToLLVM.cpp @@ -12,6 +13,7 @@ add_triton_library(TritonGPUToLLVM AllocateSharedMemory.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp + GatherOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp ControlFlowOpToLLVM.cpp FuncOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index 8d5a63eb1465..06e19029ebb8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -85,10 +85,21 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { if (!caller->hasAttr("allocation.offset")) { auto base = LLVM::getStackPointer(rewriter, caller); promotedOperands.push_back(base); - return promotedOperands; + } else { + auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp); + promotedOperands.push_back(base); } - promotedOperands.push_back(LLVM::getSharedMemoryBase( - callOp->getLoc(), rewriter, targetInfo, callOp)); + + auto opOffsetAttr = caller->getAttrOfType( + "ttg.global_scratch_memory_offset"); + Value opOffsetVal; + if (opOffsetAttr) { + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + opOffsetVal = i32_val(opOffset); + } + + promotedOperands.push_back( + LLVM::getGlobalScratchPtr(loc, rewriter, caller, opOffsetVal)); return promotedOperands; } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 71d587d0d92d..7e8f6b783609 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -16,7 +16,6 @@ namespace { -using ::mlir::isLayoutMmaV1; using ::mlir::LLVM::getMultiDimOffset; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; @@ -56,8 +55,7 @@ struct ConvertLayoutOpConversion return isa( srcLayout) && isa( - dstLayout) && - !isLayoutMmaV1(srcLayout) && !isLayoutMmaV1(dstLayout); + dstLayout); } // shared memory rd/st for blocked or mma layout with data padding @@ -176,8 +174,8 @@ struct ConvertLayoutOpConversion SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout); auto shapePerCTA = getShapePerCTA(srcLayout, shape); for (unsigned d = 0; d < rank; ++d) { @@ -282,83 +280,69 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion const auto &shape = op.getType().getShape(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - std::optional srcLayout = - toLinearLayout(shape, srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(shape, dstTy.getEncoding()); - if (!srcLayout.has_value() || !dstLayout.has_value()) { - return failure(); - } - // There are four cases to handle. - // - // 1. Transfer between values in the same thread, in which case we simply - // reorder the elements of adaptor.getSrc(). - // 2. Transfer between values in the same warp, in which case we try to - // move values using warp shuffles, though if the pattern is complicated - // enough we may fall back to using shared memory (case 3). - // 3. Transfer between values in the same CTA, in which case we move values - // through shared memory. - // 4. Transfer between values in different CTAs, in which case we move - // values through distributed shared memory. - // - // We can tell which case we're in by examining `conversion`. - // For example, if the block -> block mapping is an identity layout: {1, 2, - // 4, ...}, then there's no movement between data in different CTAs, and we - // know we're not in case 4. - if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1. - return transferWithinThread(op, *srcLayout, *dstLayout, adaptor, - rewriter); + auto conversion = minimalCvtLayout(srcTy, dstTy); + if (!conversion.has_value()) { + return rewriter.notifyMatchFailure( + op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2. - return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter); - } + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); - return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor, - rewriter); // Case 3 and 4 + assert(to_vector(conversion->getInDimNames()) == + to_vector(conversion->getOutDimNames())); + auto dims = conversion->getInDimNames(); + if (llvm::is_contained(dims, kBlock)) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, kWarp)) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kLane)) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory + // TODO(Keren): implement warp shuffle instead of using the general + // approach that uses shared memory + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, *conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - // There are three possible cases: - // - // 1. `srcLayout` has the same number of registers as `dstLayout`. - // 2. `srcLayout` has fewer registers than `dstLayout`. - // 3. `srcLayout` has more registers than `dstLayout`. - // - // In the second case `srcLayout . dstLayout^-1` is not surjective - // because not all destination registers are covered. - // Since the goal is to cover all of the destination - // registers, we can instead use `dstLayout . srcLayout^-1`. - LinearLayout conversion = dstLayout.invertAndCompose(srcLayout); - auto dstToSrc = conversion.divideRight( - LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) * - LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock, - kBlock)); - assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) == - ArrayRef{kRegister}); - assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) == - ArrayRef{kRegister}); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals; - outVals.resize(dstToSrc->getInDimSize(kRegister)); - for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) { - auto srcIdx = dstToSrc->apply({{kRegister, i}}); - outVals[i] = inVals[srcIdx.begin()->second]; + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); @@ -366,61 +350,32 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return success(); } - LogicalResult transferWithinLane(ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // TODO(Keren): implement warp shuffle instead of using the general approach - // that uses shared memory - return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor, - rewriter); - } - - LogicalResult - transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - LinearLayout conversion = srcLayout.invertAndCompose(dstLayout); - - // TODO(Keren): LLs support cross-CTA conversions, this function does not - if (isCrossCTAConversion(conversion)) - return failure(); - + LogicalResult transferWithinBlock(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - // TODO (Keren): Currently, we handle general mma/blocked/slice -> - // mma/blocked/slice conversions. - // The following tasks must be completed before we can remove the layoutIsOK - // check: - // 1. Support for AMD's MFMA and WMMA + // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere) + // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be + // completed before we can remove the layoutIsOK check: + // 1. Support for AMD's WMMA dot operand std::function layoutIsOK = [&](Attribute layout) { - if (auto nvidiaMma = dyn_cast(layout)) { - if (useLegacyMMAConversion) { - return false; - } - return true; + if (isa(layout)) { + return !useLegacyMMAConversion; } if (auto dotOperand = dyn_cast(layout)) { - if (auto nvidiaMma = - dyn_cast(dotOperand.getParent())) { - if (product(getCTAsPerCGA(nvidiaMma)) > 1) { - return false; - } - if (useLegacyMMAConversion) { - return false; - } - // FIXME [Dot LL] - // Enabling LL path for buggy kWidth path - bool largeKWidth = - dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; - return largeKWidth && nvidiaMma.isAmpere(); + if (isa( + dotOperand.getParent())) { + return !useLegacyMMAConversion; } + return false; } - if (isa(layout)) { + if (isa(layout)) { return true; } if (auto slice = dyn_cast(layout)) { @@ -431,6 +386,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) { return failure(); } + // FIXME [Dot LL] Remove this once we implement this trick in LLs + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { + return failure(); + } assert(cvtNeedsSharedMemory(srcTy, dstTy)); @@ -461,11 +420,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } + // Pretty sure this is the identity function ATM + // It'd be better to simply call `quotient({kBlock})` and + // remove kBlock from transferWithinBlockImpl auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout); auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout); SmallVector outVals = - transferWithinBlock(inVals, op, srcLayoutWithinBlock, - dstLayoutWithinBlock, adaptor, rewriter); + transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock, + dstLayoutWithinBlock, adaptor, rewriter); // Unmunge output values for (const auto &it : llvm::enumerate(outVals)) { @@ -476,22 +438,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - // FIXME [Dot LL] - // We know it's just for largeKWidth case in Ampere - // In this case, we need to pack the outputs into i32 - if (isa(dstTy.getEncoding())) { - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -499,10 +445,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } SmallVector - transferWithinBlock(ArrayRef inVals, ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + transferWithinBlockImpl(ArrayRef inVals, ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); @@ -534,19 +480,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // don't need to avoid duplicate writes. // Input dims: [reg, lane, warp] // Output dims: [offset, iteration] - std::optional shmemStoreLayout = - chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape, - scratchConfig.paddedRepShape, scratchConfig.order, - /*swizzleByteSize=*/0); - bool isStMatrix = shmemStoreLayout.has_value(); - if (!isStMatrix) { - shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout); - } - assert(shmemStoreLayout.has_value()); + bool isStMatrix = targetInfo.canUseStMatrix( + op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); + LinearLayout shmemStoreLayout = + isStMatrix ? chooseStMatrixLayout( + ctx, op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0) + : srcLayout.invertAndCompose(sharedLayout); const int shmemAllocatedNumElems = getNumScratchElements(scratchConfig.paddedRepShape); - assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems); + assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems); // Layout for the load from shmem to registers. LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); @@ -554,14 +501,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Check that the `register` fully determines the `iteration`. That is, // each thread does exactly the same reads and writes to shmem on each // iteration, just with different input/output registers. - assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock}, - {kIteration})); + assert( + shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); assert( shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); // iteration -> registers SmallVector> inRegsForIter = - collectRegsForIter(ctx, *shmemStoreLayout); + collectRegsForIter(ctx, shmemStoreLayout); SmallVector> outRegsForIter = collectRegsForIter(ctx, shmemLoadLayout); @@ -618,7 +565,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return vecAddr; }; - auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout, + auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, {{kRegister, i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, @@ -641,11 +588,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // When using `stmatrix`, we can store `inVec` elements even if they are // not contiguous - auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut() + auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut() : scratchConfig.inVec; for (int j = 0; j < inVals.size() / iterations; j += inVec) { auto inRegSlice = inRegs[j]; - Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice); + Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice); SmallVector inValsVec; for (int k = 0; k < inVec; k++) inValsVec.push_back(inVals[inRegSlice + k]); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index be2e6f584f1c..4914fd712b87 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -12,6 +12,7 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; SmallVector diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 1346cc143ed2..d5afb6e2b188 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -90,6 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + auto dotParent = dyn_cast(dstDotOp.getParent()); + if (dotParent) { + return; + } Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8762942c311c..5869ab36f08b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -17,134 +17,17 @@ Type getElementType(Value value) { return tensorType.getElementType(); return type; } -// MMA encoding has a different order depending on the element's bit width; -// reorder if we're in this case. -SmallVector reorderValues(const SmallVector &values, Type inType, - Type ouType) { - auto inTensorTy = dyn_cast(inType); - auto ouTensorTy = dyn_cast(ouType); - if (!inTensorTy || !ouTensorTy) - return values; - auto inEncoding = dyn_cast(inTensorTy.getEncoding()); - auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); - assert(inEncoding == ouEncoding); - if (!inEncoding) - return values; - // If the parent of the dot operand is in block encoding, we don't need to - // reorder elements - auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding) - return values; - size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); - size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); - auto ouEltTy = ouTensorTy.getElementType(); - if (inBitWidth == ouBitWidth) - return values; - if (inBitWidth == 16 && ouBitWidth == 32) { - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 8) { - ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 7]); - } - return ret; - } - if (inBitWidth == 8 && ouBitWidth == 16) { - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i + 0]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 7]); - ret.push_back(values[i + 12]); - ret.push_back(values[i + 13]); - ret.push_back(values[i + 14]); - ret.push_back(values[i + 15]); - } - return ret; - } - llvm_unreachable("unimplemented code path"); -} - -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - for (auto v : inValues) { - // cast i32 to appropriate eltType vector and extract elements - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecType); - for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); - auto vecType = vec_ty(eltType, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecType); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return numElemsPerThread; - auto structType = - dyn_cast(typeConverter->convertType(type)); - if (structType) { - numElemsPerThread = structType.getBody().size(); + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); } - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return numElemsPerThread; - auto eltType = tensorTy.getElementType(); - assert(eltType.getIntOrFloatBitWidth() <= 32 && - "Only support element type with bit width <= 32 in dot operand mma " - "layout"); - // dot operand data are packed into i32 elements so use the following formula - // to get the number of elements per thread. - return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; + return numElemsPerThread; } } // namespace mlir::triton::gpu @@ -442,13 +325,12 @@ struct ElementwiseInlineAsmOpConversion // asmResults is a flat struct; pack its values into // [return_value][op.getPackedElement()]. SmallVector> ret(op->getNumResults()); + int structIdx = 0; for (int i = 0; i < op->getNumResults(); i++) { - int structIdx = 0; for (int j = 0; j < op.getPackedElement(); j++) { Value val; if (asmRetTypes.size() > 1) { - val = - extract_val(asmResults, i * op.getPackedElement() + structIdx++); + val = extract_val(asmResults, structIdx++); } else { val = asmResults; } @@ -475,8 +357,7 @@ struct ElementwiseInlineAsmOpConversion for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - unpackedOperands.push_back( - unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackedOperands.push_back(subOperands); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -529,17 +410,8 @@ struct ElementwiseInlineAsmOpConversion // Reorder and pack the results. SmallVector outs; for (int i = 0; i < unpackedResults.size(); i++) { - // We reordered all the inputs so they match operand 0. Reorder the - // outputs accordingly. - if (op->getNumOperands() > 0) { - unpackedResults[i] = reorderValues( - unpackedResults[i], /*inType=*/op->getOperand(0).getType(), - /*ouType=*/op->getResult(i).getType()); - } - auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), - rewriter, loc, getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, - op->getResult(i).getType())); + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 8ffa9517e5a5..bee29c8f39d9 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "mlir/IR/BuiltinAttributes.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -13,6 +14,19 @@ namespace { using namespace mlir; using namespace mlir::triton; +// NOTE: [Additional Function Arguments] +// To support use of shared memory and global scratch memory inside of a +// function, the caller allocates a single large block of the relevant memory +// and calls the funciton with these extra arguments at the end. +// Specifically, the last argument is the global scratch memory allocation and +// the second to last is the shared memory allocation. +// +// For the kernel function itself, the shared memory base is a global symbol +// so no additional function argument is required but global scratch memory +// allocation is still passed in as the last argument. Though here the scratch +// memory is shared between all programs, so a linear offset based on the +// program id is required to get the local scratch base. + /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. @@ -41,30 +55,46 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { triton::FuncOp amendFuncOp(triton::FuncOp funcOp, ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) const { - // Push back a variable that indicates the current stack pointer of shared - // memory to the function arguments. + // Push back two new arguments that indicate the current pointer to shared + // memory and global scratch memory. auto loc = funcOp.getLoc(); auto ctx = funcOp->getContext(); - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), - targetInfo.getSharedAddressSpace()); - // 1. Modify the function type to add the new argument. + auto sharedPtrTy = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); + auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1); + + // 1. Modify the function type to add the new arguments. auto funcTy = funcOp.getFunctionType(); auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); - amendedInputTy.push_back(ptrTy); - auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, - funcTy.getResults()); + bool isKernel = LLVM::isKernel(funcOp); + if (!isKernel) { + amendedInputTy.push_back(sharedPtrTy); + } + amendedInputTy.push_back(globalPtrTy); + auto amendedFuncTy = + FunctionType::get(ctx, amendedInputTy, funcTy.getResults()); // 2. Modify the argument attributes to add the new argument. SmallVector amendedAttrs; filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); - auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); - // 3. Add a new argument to the region + if (auto argAttrs = funcOp.getAllArgAttrs()) { + llvm::SmallVector amendedArgAttrs(argAttrs.begin(), + argAttrs.end()); + while (amendedArgAttrs.size() < amendedInputTy.size()) { + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + } + amendedAttrs.push_back( + rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(), + rewriter.getArrayAttr(amendedArgAttrs))); + } + + // 3. Add the new arguments to the region auto amendedFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); auto ®ion = funcOp.getBody(); - region.addArgument(ptrTy, loc); + if (!isKernel) { + region.addArgument(sharedPtrTy, loc); + } + region.addArgument(globalPtrTy, loc); rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), amendedFuncOp.end()); return amendedFuncOp; @@ -110,9 +140,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prevent LLVM's inliner to inline this function - auto amendedFuncOp = funcOp; - if (!LLVM::isKernel(funcOp)) - amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); + auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); FailureOr maybeNewFuncOp = mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, @@ -136,14 +164,15 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 newFuncOp.setPassthroughAttr( ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); - rewriter.eraseOp(amendedFuncOp); newFuncOp.setLinkage(LLVM::Linkage::Internal); } // Set an attribute for reqntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. newFuncOp->setAttr("nvvm.reqntid", rewriter.getDenseI32ArrayAttr(32 * numWarps)); + rewriter.eraseOp(funcOp); + rewriter.eraseOp(amendedFuncOp); // Add attributes for by-value TMA descriptor args (nvidia) handleByvalTmaDescArgs(newFuncOp); diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp new file mode 100644 index 000000000000..5ab81eff819c --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -0,0 +1,112 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +class GatherOpConversion : public ConvertOpToLLVMPattern { +public: + GatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + const TargetInfoBase &targetInfo; +}; + +LogicalResult +GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + RankedTensorType srcType = op.getSrc().getType(); + + // Compute the src subtensor shape owned by this CTA. + SmallVector srcShapePerCTA = + convertType(triton::gpu::getShapePerCTA(srcType)); + + // Grab the src values in this thread. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // Emit the indices of the src values owned by this thread. + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(), + op.getSrc().getType(), /*withCTAOffset=*/true); + + // Store the src values owned by the thread into their respective location in + // the scratch memory. + assert(srcValues.size() == srcIndices.size()); + + // Get the base pointer to the scratch memory. + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // For each src element owned by the thread, index into the scratch memory and + // then store it. + Type elemType = getTypeConverter()->convertType(srcType.getElementType()); + for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) { + // Convert the index at each dim into a single offset given the shape of the + // tensor. + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + // Emit the offset into the shared memory and then store the value. + Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); + store(value, ptr); + } + + // Synchronize the whole CTA. + barrier(); + + // Grab the index values owned by this thread. + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // Apply the layout of the destination tensor to obtain the indices of the + // column to gather along, then for each column, replace the index along the + // gather axis with the appropriate index value. + // + // I = LL(pid) + // idx = indices[I] + // I_gather = [I[d] if d != axis else idx for d in range(len(I))] + // out[I] = src[I_gather] + RankedTensorType dstType = op.getType(); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, + /*withCTAOffset=*/true); + + unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); + unsigned axis = op.getAxis(); + SmallVector results(dstIndices.size()); + for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + idx = trunc(i32_ty, idx); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + idx = zext(i32_ty, idx); + } + indices[axis] = idx; + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); + results[i] = load(elemType, ptr); + } + + Value packed = + packLLElements(loc, getTypeConverter(), results, rewriter, dstType); + rewriter.replaceOp(op, packed); + return success(); +} + +} // namespace + +void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.insert(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp b/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp new file mode 100644 index 000000000000..3fcaf4197c6c --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp @@ -0,0 +1,108 @@ +#include "mlir/Analysis/Liveness.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +namespace { + +static int32_t roundUp(int32_t val, int32_t step) { + auto t = val + step - 1; + return t - (t % step); +} + +static void allocateGMem(Operation *parentOp, + llvm::SetVector &callStack) { + // Recursively visit any dependency functions + parentOp->walk([&](triton::CallOp call) { + auto callable = call.resolveCallable(); + if (!callable->hasAttr("ttg.global_scratch_memory_size")) { + auto inserted = callStack.insert(parentOp); + assert(inserted && "call cycle detected"); + allocateGMem(callable, callStack); + callStack.remove(parentOp); + } + }); + + MLIRContext *ctx = parentOp->getContext(); + OpBuilder builder(ctx); + int32_t offset = 0; + uint32_t largestAlignment = 1; + + // Dumb allocation that ignores liveness and makes no attempt to minimize + // padding + // TODO: Use a real algorithm + parentOp->walk([&](Operation *op) { + uint32_t nbytes = 0; + uint32_t align = 0; + if (auto alloc = dyn_cast(op)) { + nbytes = alloc.getNbytes(); + align = alloc.getAlignment(); + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto nbytes_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_size"); + auto align_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(nbytes_attr); + assert(align_attr); + + nbytes = nbytes_attr.getValue().getZExtValue(); + align = align_attr.getValue().getZExtValue(); + } + if (nbytes > 0) { + offset = roundUp(offset, align); + op->setAttr("ttg.global_scratch_memory_offset", + builder.getI32IntegerAttr(offset)); + offset += nbytes; + largestAlignment = std::max(largestAlignment, align); + } + }); + int32_t totalMemorySize = roundUp(offset, largestAlignment); + parentOp->setAttr("ttg.global_scratch_memory_size", + builder.getI32IntegerAttr(totalMemorySize)); + parentOp->setAttr("ttg.global_scratch_memory_alignment", + builder.getI32IntegerAttr(largestAlignment)); +} + +class TritonGPUGlobalScratchAllocationPass + : public TritonGPUGlobalScratchAllocationPassBase< + TritonGPUGlobalScratchAllocationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + + bool seenKernel = false; + + SetVector callStack; + mod->walk([&](triton::FuncOp func) { + allocateGMem(func, callStack); + + if (func.getVisibility() == SymbolTable::Visibility::Public) { + assert(!seenKernel); + seenKernel = true; + auto size = + func->getAttrOfType("ttg.global_scratch_memory_size"); + auto align = func->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(size); + assert(align); + mod->setAttr("ttg.global_scratch_memory_size", size); + mod->setAttr("ttg.global_scratch_memory_alignment", align); + } + }); + assert(seenKernel); + } +}; + +} // namespace + +std::unique_ptr +mlir::triton::gpu::createTritonGPUGlobalScratchAllocationPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 0ccd97970a55..1e6e1c1fd717 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -15,27 +15,51 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. -void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); - assert(srcTy.getShape().size() <= 2 || - (srcTy.getShape().size() == 3 && outOrd[2] == 0) && - "Unexpected rank of ConvertLayout(blocked->shared)"); auto elemTy = typeConverter->convertType(srcTy.getElementType()); auto smemBase = smemObj.getBase(); auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo); + loc, rewriter, targetInfo, llvmOpCount); } +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + GlobalScratchAllocOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto opOffsetAttr = op->getAttrOfType( + "ttg.global_scratch_memory_offset"); + assert(opOffsetAttr); + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) { + return failure(); + } + Value ptr = + LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, i32_val(opOffset)); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + struct LocalAllocOpConversion : public ConvertOpToLLVMPattern { LocalAllocOpConversion(const LLVMTypeConverter &converter, @@ -109,27 +133,54 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { } + // FIXME [Dot LL] + // Do for all DotOperandEncodingAttr once we have LLs for all of them + static bool isSupportedDotOpLayout(MemDescType srcTy, + RankedTensorType dstTy) { + auto srcLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + auto bitwidth = dstTy.getElementTypeBitWidth(); + auto rank = dstTy.getRank(); + if (auto dot = dyn_cast(dstLayout)) { + auto vecWidth = 32 / bitwidth; + auto kWidth = dot.getKWidth(); + auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2; + if (auto mma = dyn_cast(dot.getParent())) { + auto needTrans = kOrder != srcLayout.getOrder()[0]; + auto canUseLdmatrix = + (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth); + if (mma.isHopper()) { + // I think we should be able to remove this condition, but it's here + // as the legacy ldmatrix path does not support it + canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32; + } + // If we remove this one, ldmatrix will IMA. It can probably be relaxed + // though + canUseLdmatrix &= + srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; + // To be removed in https://github.com/triton-lang/triton/pull/5154 + bool legacyLoweringIsBuggy = + (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || + dstTy.getRank() == 3) && + mma.isAmpere(); + return (mma.isHopper() && !canUseLdmatrix) || + (mma.isAmpere() && legacyLoweringIsBuggy); + } + if (isa(dot.getParent())) + return true; + } + return false; + }; + LogicalResult matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - // FIXME [Dot LL] - // Do for all DotOperandEncodingAttr once we have LLs for all of them - auto isAmpereLargeKWidth = [](Attribute layout) { - if (auto dot = dyn_cast(layout)) { - if (auto mma = dyn_cast(dot.getParent())) { - return mma.isAmpere() && dot.getKWidth() == 8; - } - } - return false; - }; - if (isa(srcLayout) && - (isa( - dstLayout) || - isAmpereLargeKWidth(dstLayout))) { + if (isa(dstLayout) || + isSupportedDotOpLayout(srcTy, dstTy)) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } @@ -167,11 +218,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); - assert(dstShape.size() <= 2 && - "Unexpected rank of ConvertLayout(shared->blocked)"); auto srcSharedLayout = cast(srcTy.getEncoding()); - auto dstLayout = dstTy.getEncoding(); - auto inOrd = getOrder(srcSharedLayout); + assert((!isa(dstTy.getEncoding()) || + isSupportedDotOpLayout(srcTy, dstTy)) && + "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), @@ -181,37 +231,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - // FIXME [Dot LL] - // Ampere case - // In this case, we need to pack the outputs into i32 - if (isa(dstTy.getEncoding())) { - if (elemLlvmTy.isInteger(8)) { - auto concat = [&](Value a1, Value a2, Value a3, Value a4) { - return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), - or_(shl(zext(i32_ty, a3), i32_val(16)), - shl(zext(i32_ty, a4), i32_val(24)))); - }; - SmallVector outVals32(outVals.size() / 4); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], - outVals[4 * i + 2], outVals[4 * i + 3]); - } - outVals = outVals32; - } else { - assert(elemLlvmTy.isBF16() && "Unexpected element type"); - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - } - Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); @@ -227,12 +246,15 @@ struct LocalStoreOpConversion public: using ConvertOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -242,24 +264,37 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { + patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 8682706db899..088dbd997602 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -162,13 +162,6 @@ struct ReduceOpConversion auto mod = op->getParentOfType(); unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - if (iWarpSize > numLaneToReduce) { - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(iWarpSize); - Value laneId = urem(threadId, warpSize); - Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce)); - pred = pred ? and_(pred, lanePred) : lanePred; - } for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { SmallVector shfl(acc.size()); @@ -225,6 +218,46 @@ struct ReduceOpConversion rewriter.replaceOp(op, results); } + // For slice layout some ids are duplicated on multiple lanes, so we need to + // handle the delinearization of laneId in a special way. We need to + // generalize this part of the logic to work on any kind of linear layout + // uniformely. + SmallVector + getMultiDimLaneId(ReduceOpHelper &helper, Value &laneId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = triton::gpu::getThreadOrder(srcLayout); + SmallVector multiDimLaneId; + + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + SmallVector dims = {sliceLayout.getDim()}; + while (auto parentSliceLayout = + mlir::dyn_cast(parentLayout)) { + dims.push_back(parentSliceLayout.getDim()); + parentLayout = parentSliceLayout.getParent(); + } + + auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp(parentLayout); + auto parentOrder = triton::gpu::getThreadOrder(parentLayout); + multiDimLaneId = delinearize(rewriter, loc, laneId, parentThreadsPerWarps, + parentOrder); + for (unsigned dim : llvm::reverse(dims)) { + multiDimLaneId.erase(multiDimLaneId.begin() + dim); + } + } else { + SmallVector threadsPerWarps = + triton::gpu::getThreadsPerWarp(srcLayout); + threadsPerWarps[helper.getAxis()] = + triton::gpu::getThreadsPerWarpWithUniqueData( + srcLayout, srcShape)[helper.getAxis()]; + multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarps, order); + } + return multiDimLaneId; + } + SmallVector getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, ConversionPatternRewriter &rewriter) const { @@ -238,11 +271,20 @@ struct ReduceOpConversion // a way to properly delinearize warpId in the slice case if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { auto parentLayout = sliceLayout.getParent(); + SmallVector dims = {sliceLayout.getDim()}; + while (auto parentSliceLayout = + mlir::dyn_cast(parentLayout)) { + dims.push_back(parentSliceLayout.getDim()); + parentLayout = parentSliceLayout.getParent(); + } + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); auto parentOrder = triton::gpu::getWarpOrder(parentLayout); multiDimWarpId = delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); - multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + for (unsigned dim : llvm::reverse(dims)) { + multiDimWarpId.erase(multiDimWarpId.begin() + dim); + } } else { SmallVector warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); @@ -270,11 +312,8 @@ struct ReduceOpConversion unsigned axis = op.getAxis(); auto smemShape = helper.getScratchRepShape(); - auto threadsPerWarp = - triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto order = getThreadOrder(srcLayout); SmallVector multiDimLaneId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + getMultiDimLaneId(helper, laneId, loc, rewriter); Value laneIdAxis = multiDimLaneId[axis]; Value zero = i32_val(0); Value laneZero = icmp_eq(laneIdAxis, zero); @@ -382,7 +421,7 @@ struct ReduceOpConversion auto resultIndices = emitIndices(loc, rewriter, targetInfo, resultLayout, resultTy, true); auto resultShape = resultTy.getShape(); - auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + auto resultCTATile = getShapePerCTATile(resultLayout); assert(resultIndices.size() == resultElems); SmallVector resultVals(resultElems); diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index c69a6b32425e..64e6ca787780 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -187,7 +187,7 @@ static void AddPartialReduce(SmallVector> &srcValues, } Value mask = icmp_sge(warpId, i32_val(i + 1)); accumulator.acc = - accumulate(helper, rewriter, accumulator.acc, partialReduce, mask); + accumulate(helper, rewriter, accumulator.acc, partialReduce); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { accumulator.maskedAcc[j] = select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); @@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); - auto order = triton::gpu::getOrder(srcEncoding); + auto threadOrder = triton::gpu::getThreadOrder(srcEncoding); auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); SmallVector multiDimLaneId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); @@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, multiDimLaneId[axis] = i32_val(0); threadsPerWarp[axis] = 1; Value laneIdParallel = - linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder); multiDimWarpId[axis] = i32_val(0); warpsPerCTA[axis] = 1; Value warpIdParallel = diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index cc6d8875b5c7..1ed1f4063585 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -11,6 +11,7 @@ using namespace mlir::triton; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -28,6 +29,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](MemDescType type) -> std::optional { return convertMemDescType(type, targetInfo); }); + addConversion([](TensorDescType type) -> std::optional { + auto ctx = type.getContext(); + return LLVM::LLVMPointerType::get(ctx, 1); + }); addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { return convertAsyncToken(type); }); @@ -70,29 +75,12 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); } -Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( - TensorOrMemDesc type) { - auto ctx = type.getContext(); - Attribute layout = type.getEncoding(); - Type elemTy = convertType(type.getElementType()); - auto dotOpLayout = mlir::dyn_cast(layout); - if (!dotOpLayout) - return elemTy; - auto mmaParent = - mlir::dyn_cast(dotOpLayout.getParent()); - if (!mmaParent || mmaParent.isHopper()) - return elemTy; - int bitwidth = elemTy.getIntOrFloatBitWidth(); - assert(bitwidth <= 32); - return IntegerType::get(ctx, 32); -} - Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( RankedTensorType type, const TargetInfoBase &targetInfo) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); SmallVector shape(type.getShape().begin(), type.getShape().end()); - Type eltType = getElementTypeForStruct(cast(type)); + Type eltType = convertType(type.getElementType()); if (auto shared_layout = mlir::dyn_cast(layout)) { SmallVector types; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6cb..a310cdba5f4f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -4,113 +4,9 @@ #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/STLExtras.h" -namespace SharedToDotOperandMMAv1 { -using CoordTy = SmallVector; -using ValueTable = std::map, std::pair>; - -static SmallVector -getMNCoords(Value thread, Location loc, RewriterBase &rewriter, - ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, - ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, - bool isBVec4) { - static constexpr std::array fpw{{2, 2, 1}}; - - auto *ctx = thread.getContext(); - Value _1 = i32_val(1); - Value _2 = i32_val(2); - Value _4 = i32_val(4); - Value _16 = i32_val(16); - Value _32 = i32_val(32); - Value _fpw0 = i32_val(fpw[0]); - Value _fpw1 = i32_val(fpw[1]); - - // A info - auto aRep = mmaLayout.getMMAv1Rep(0); - auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); - // B info - auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); - auto bRep = mmaLayout.getMMAv1Rep(1); - - SmallVector rep({aRep[0], bRep[1]}); - SmallVector spw({aSpw[0], bSpw[1]}); - SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); - - Value lane = urem(thread, _32); - Value warp = udiv(thread, _32); - - Value warp0 = urem(warp, i32_val(wpt[0])); - Value warp12 = udiv(warp, i32_val(wpt[0])); - Value warp1 = urem(warp12, i32_val(wpt[1])); - - // warp offset - Value offWarpM = mul(warp0, i32_val(spw[0])); - Value offWarpN = mul(warp1, i32_val(spw[1])); - // quad offset - Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); - Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); - // pair offset - Value offPairM = udiv(urem(lane, _16), _4); - offPairM = urem(offPairM, _fpw0); - offPairM = mul(offPairM, _4); - Value offPairN = udiv(urem(lane, _16), _4); - offPairN = udiv(offPairN, _fpw0); - offPairN = urem(offPairN, _fpw1); - offPairN = mul(offPairN, _4); - - // sclare - offPairM = mul(offPairM, i32_val(rep[0] / 2)); - offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); - offPairN = mul(offPairN, i32_val(rep[1] / 2)); - offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); - - // quad pair offset - Value offLaneM = add(offPairM, offQuadM); - Value offLaneN = add(offPairN, offQuadN); - // a, b offset - Value offsetAM = add(offWarpM, offLaneM); - Value offsetBN = add(offWarpN, offLaneN); - // m indices - Value offsetCM = add(and_(lane, _1), offsetAM); - SmallVector idxM; - for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) - for (unsigned mm = 0; mm < rep[0]; ++mm) - idxM.push_back(add(offsetCM, i32_val(m + mm * 2))); - - // n indices - Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); - SmallVector idxN; - for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { - for (int nn = 0; nn < rep[1]; ++nn) { - idxN.push_back(add( - offsetCN, i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]))); - idxN.push_back( - add(offsetCN, - i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1))); - } - } - - SmallVector> axes({idxM, idxN}); - - // product the axis M and axis N to get coords, ported from - // generator::init_idx method from triton2.0 - - // TODO[Superjomn]: check the order. - SmallVector coords; - for (Value x1 : axes[1]) { // N - for (Value x0 : axes[0]) { // M - SmallVector idx(2); - idx[0] = x0; // M - idx[1] = x1; // N - coords.push_back(std::move(idx)); - } - } - - return coords; // {M,N} in row-major -} -} // namespace SharedToDotOperandMMAv1 - namespace mlir { namespace triton::gpu { @@ -203,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, return outIndices; } +std::tuple emitHardwareTuple(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + bool withCTAOffset, + unsigned threadsPerWarpCst) { + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(threadsPerWarpCst); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + return {blockId, warpId, laneId}; +} + SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset) { @@ -220,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - Value blockId = - withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; // Linear layout function is split in two parts below: @@ -264,8 +170,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, } bool emitTransferBetweenRegistersAndShared( - RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, - std::optional maxVecElems, Value shmemBase, + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, + Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback) { @@ -279,41 +185,17 @@ bool emitTransferBetweenRegistersAndShared( StringAttr kLane = str_attr("lane"); StringAttr kWarp = str_attr("warp"); - std::optional regLayout = - triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); - std::optional sharedLayout = triton::gpu::toLinearLayout( - shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth()); - if (!regLayout.has_value() || !sharedLayout.has_value()) { + auto regToSharedLayout = getRegToSharedLayout( + ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(), + elemLlvmTy.getIntOrFloatBitWidth()); + if (!regToSharedLayout.has_value()) return false; - } - auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); - - // sharedLayout's in-dims are currently (offset, block). Reshape to - // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional - // shmem strides. (The offsetX's appear in minor-to-major order.) - auto sharedLegacy = - cast(sharedTy.getEncoding()); - SmallVector> multiDimSharedSize; - for (int i = 0; i < rank; i++) { - int dim = sharedOrder[i]; - int64_t size = std::max( - int64_t{1}, - shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]); - multiDimSharedSize.push_back( - {str_attr("offset" + std::to_string(dim)), size}); - } - multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)}); - sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize); - - // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, - // ..., offsetXN, block), where the offsetX's are in minor-to-major order. - LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout); // TODO(jlebar): We don't currently support loading from shared memory in a // different CTA. We'd need to emit `mapa.shared::cluster` instructions. - for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock); + for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock); inBlock *= 2) { - auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply( + auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply( {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}}))); // offsetX1, ..., offsetXN must all be 0. if (!llvm::all_of(ArrayRef(idx).drop_back(1), @@ -339,15 +221,14 @@ bool emitTransferBetweenRegistersAndShared( // which have known strides. This would allow us to vectorize across multiple // shmem out dimensions where possible. const int vecElems = - std::min(regToSharedLayout.getNumConsecutiveInOut(), + std::min(regToSharedLayout->getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); + auto [blockId, warpId, laneId] = + emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, + regToSharedLayout->getInDimSize(kLane)); - int numElems = regToSharedLayout.getInDimSize(kRegister); + int numElems = regToSharedLayout->getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); auto ptrTy = shmemBase.getType(); Value zero = i32_val(0); @@ -358,7 +239,7 @@ bool emitTransferBetweenRegistersAndShared( // we drop_end to drop block, which we know from above will be 0. auto multiDimShmemOffset = llvm::to_vector(llvm::drop_end(llvm::make_second_range( - applyLinearLayout(loc, rewriter, regToSharedLayout, + applyLinearLayout(loc, rewriter, *regToSharedLayout, {{kRegister, i32_val(i * vecElems)}, {kLane, laneId}, {kWarp, warpId}, @@ -366,6 +247,7 @@ bool emitTransferBetweenRegistersAndShared( // Reorder strides according to `order`. This way they match the // multi-dimensional offsets in regToSharedLayout. + auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset, applyPermutation(shmemStrides, sharedOrder)); auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset); @@ -377,7 +259,8 @@ bool emitTransferBetweenRegistersAndShared( } SmallVector loadSharedToDistributed(RankedTensorType dstTy, - MemDescType srcTy, Type elemLlvmTy, + triton::gpu::MemDescType srcTy, + Type elemLlvmTy, SharedMemoryObject smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target) { @@ -400,11 +283,13 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, return ret; } -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target) { +void storeDistributedToShared(triton::gpu::MemDescType dstTy, + RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, + ArrayRef dstStrides, Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { @@ -418,7 +303,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } @@ -744,10 +634,8 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value laneId = urem(threadId, warpSize); - Value warpId = udiv(threadId, warpSize); + auto [blockId, warpId, laneId] = emitHardwareTuple( + loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32); // TODO: fix the bug in MMAEncodingAttr document SmallVector multiDimWarpId(2); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); @@ -779,8 +667,6 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1])); mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); - } else if (mmaLayout.isVolta()) { - // Volta doesn't follow the pattern here. } else { llvm_unreachable("Unexpected MMALayout version"); } @@ -809,13 +695,6 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, multiDimOffset[rank - 1] = add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] * shapePerCTATile[rank - 1])); - } else if (mmaLayout.isVolta()) { - auto [isARow, isBRow, isAVec4, isBVec4, _] = - mmaLayout.decodeVoltaLayoutStates(); - auto coords = SharedToDotOperandMMAv1::getMNCoords( - threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape, - isARow, isBRow, isAVec4, isBVec4); - return coords[elemId]; } else { llvm_unreachable("Unexpected MMALayout version"); } @@ -856,5 +735,49 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values) { + SmallVector results; + for (auto v : values) { + auto em0 = and_(v, i8_val(0x7)); + auto em1 = and_(v, i8_val(0x70)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0), + bf16_ty); + v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1), + bf16_ty); + // 3) x is zero, nothing to do + results.push_back(v0); + results.push_back(v1); + } + return results; +} + +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value vBf16 = bitcast(v, bf16_ty); + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + Value scaledBf16 = fmul(vBf16, scaleBf16); + // Account for NaN in the scale as per the mxfp specification. + return select(scaleIsNan, nanBf16, scaledBf16); +}; + } // namespace LLVM } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 297a94e851f6..ea05490c7a0a 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -1,7 +1,8 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" using namespace mlir; using namespace mlir::triton; @@ -269,27 +270,38 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { return success(); } }; +struct MemDescTransOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto enc = cast(resultTy.getEncoding()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.base, srcSmemObj.baseElemType, + /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()), + /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + struct TransOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(TransOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - auto resultTy = cast(op.getType()); - if (auto enc = dyn_cast(resultTy.getEncoding())) { - auto llvmElemTy = - getTypeConverter()->convertType(resultTy.getElementType()); - auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), - llvmElemTy, rewriter); - auto dstSmemObj = SharedMemoryObject( - srcSmemObj.base, srcSmemObj.baseElemType, - /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()), - /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder())); - auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); - } else if (auto enc = mlir::dyn_cast( - resultTy.getEncoding())) { + auto resultTy = cast(op.getType()); + if (auto enc = + mlir::dyn_cast(resultTy.getEncoding())) { // If the dst encoding is blocked, then TransOp::inferReturnTypes // ensures that: // - the src encoding is also blocked, and @@ -302,9 +314,10 @@ struct TransOpConversion : public ConvertOpToLLVMPattern { rewriter.replaceOp(op, ret); return success(); } - return emitOptionalError(loc, "unsupported encoding for TransOp"); + return emitOptionalError(loc, "unsupported encoding for MemDescTransOp"); } }; + struct BroadcastOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -336,7 +349,6 @@ struct BroadcastOpConversion unsigned rank = srcTy.getRank(); auto typeConverter = getTypeConverter(); assert(rank == resultTy.getRank()); - auto order = triton::gpu::getOrder(srcLayout); auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); SmallVector srcVals = unpackLLElements(loc, src, rewriter); @@ -408,6 +420,7 @@ void mlir::triton::populateViewOpToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 1b629ba1639f..51ca93025710 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -3,7 +3,7 @@ add_triton_library(TritonToTritonGPU TritonToTritonGPUPass.cpp DEPENDS - TritonConversionPassIncGen + TritonConversionToGPUPassIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 34fb8995430f..06e75ee18d59 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -56,20 +56,19 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // This will create newArg, and map(origArg, newArg) addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, - Location loc) -> std::optional { + Location loc) -> Value { llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); - return std::nullopt; + return {}; }); // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); - return std::nullopt; + return {}; }); // This will be called when (desiredType != newOperandType) @@ -79,7 +78,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, ValueRange inputs, Location loc) { auto cast = builder.create(loc, tensorType, inputs); - return std::optional(cast.getResult()); + return cast.getResult(); }); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index bd17e2d7c8b2..464b150dc1a2 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -2,12 +2,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -59,15 +56,10 @@ class ArithConstantPattern : public OpConversionPattern { Type retType = getTypeConverter()->convertType(op.getType()); auto retShapedType = cast(retType); auto value = dyn_cast(adaptor.getValue()); - if (dyn_cast(retShapedType)) { - assert(value); - if (value.getElementType().isInteger(1) && value.isSplat()) - // Workaround until https://reviews.llvm.org/D133743 is included. - value = - DenseElementsAttr::get(retShapedType, value.getSplatValue()); - else - // This is a hack. We just want to add encoding - value = value.reshape(retShapedType); + if (isa(retShapedType)) { + assert(value && "expected a dense elements attribute"); + // This is a hack. We just want to add encoding. + value = value.reshape(retShapedType); } addNamedAttrs(rewriter.replaceOpWithNewOp( op, retShapedType, value), @@ -545,6 +537,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, TritonExpandDimsPattern, TritonTransPattern, TritonDotPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 6ef40db00f52..c964bdcea534 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Triton) +add_subdirectory(TritonCPU) add_subdirectory(TritonGPU) add_subdirectory(TritonNvidiaGPU) diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 752daa7ff055..f9d1586441d4 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonIR Ops.cpp Traits.cpp Types.cpp + OpInterfaces.cpp DEPENDS TritonTableGen diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index dc24177125a6..2874a3f5649d 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/UB/IR/UBOps.h" -#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" @@ -12,8 +11,9 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc" #include "triton/Dialect/Triton/IR/Dialect.cpp.inc" -#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" using namespace mlir; using namespace mlir::triton; @@ -77,44 +77,6 @@ struct TritonInlinerInterface : public DialectInlinerInterface { } }; -struct TensorModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getRank(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementTypeBitWidth(); - } -}; - -struct MemDescModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getShape().size(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementType().getIntOrFloatBitWidth(); - } -}; - } // namespace void TritonDialect::initialize() { @@ -127,9 +89,6 @@ void TritonDialect::initialize() { // We can also add interface here. addInterfaces(); - - RankedTensorType::attachInterface(*getContext()); - MemDescType::attachInterface(*getContext()); } Operation *TritonDialect::materializeConstant(OpBuilder &builder, diff --git a/lib/Dialect/Triton/IR/OpInterfaces.cpp b/lib/Dialect/Triton/IR/OpInterfaces.cpp new file mode 100644 index 000000000000..7f3a966bffdb --- /dev/null +++ b/lib/Dialect/Triton/IR/OpInterfaces.cpp @@ -0,0 +1,34 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace triton { +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op) { + TransposeOpInterface transposeOp = cast(op); + auto rank = cast(transposeOp.getSrc().getType()).getRank(); + auto order = transposeOp.getOrder(); + if (rank != order.size()) { + return op->emitError( + "order must have the same size as the rank of the operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return op->emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +} // namespace impl +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index c2c057f42c51..a5d8dc3646ee 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" namespace mlir { namespace triton { @@ -198,6 +199,11 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) { return getResult(); } + // Eliminate splat constant transpose ops. + if (auto attr = + llvm::dyn_cast_if_present(adaptor.getSrc())) + return attr.reshape(getType()); + return {}; } @@ -206,7 +212,7 @@ LogicalResult TransOp::inferReturnTypes( DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // type is the same as the input - auto argTy = cast(operands[0].getType()); + auto argTy = cast(operands[0].getType()); auto order = properties.as()->order.asArrayRef(); SmallVector retShape = applyPermutation(argTy.getShape(), order); @@ -222,35 +228,8 @@ LogicalResult TransOp::inferReturnTypes( return failure(); } } - if (auto memDescTy = dyn_cast(argTy)) { - inferredReturnTypes.push_back(MemDescType::get( - retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), - memDescTy.getMutableMemory())); - } else { - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, retEltTy, retEncoding)); - } - return success(); -} - -LogicalResult TransOp::verify() { - // Check that the op's `order` attribute is a permutation of the right length. - auto srcTy = getSrc().getType(); - - ArrayRef order = getOrder(); - if (order.size() != srcTy.getRank()) { - return emitError("order must have the same size as the rank of the " - "operand and result"); - } - - SmallVector sortedOrder(order); - llvm::sort(sortedOrder); - for (int32_t i = 0; i < sortedOrder.size(); i++) { - if (sortedOrder[i] != i) { - return emitError("order must be a permutation of [0, ..., rank - 1]"); - } - } - + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); return success(); } @@ -265,8 +244,8 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional location, inferredReturnTypes.push_back(accTy); // verify encodings - auto aEnc = cast(operands[0].getType()).getEncoding(); - auto bEnc = cast(operands[1].getType()).getEncoding(); + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); auto retEnc = accTy.getEncoding(); if (aEnc) { assert(bEnc && retEnc); @@ -503,6 +482,22 @@ llvm::SmallVector ReduceOp::getElementTypes() { return getElementTypesImpl(this->getOperands()); } +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- ScanOp -- @@ -728,9 +723,40 @@ LogicalResult ReshapeOp::verify() { } //-- FpToFpOp -- + +// Fold FpToFpOp when the input operand is a constant zero. +OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) { + auto srcVal = getSrc(); + auto dstTy = getType(); + + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &semantic = resElemType.getFloatSemantics(); + + if (matchPattern(srcVal, m_PosZeroFloat())) { + llvm::APFloat posZero = + llvm::APFloat::getZero(semantic, /*negative=*/false); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, posZero); + return Builder(getContext()).getFloatAttr(resElemType, posZero); + } + + if (matchPattern(srcVal, m_NegZeroFloat())) { + llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, negZero); + return Builder(getContext()).getFloatAttr(resElemType, negZero); + } + + return {}; +} + LogicalResult FpToFpOp::verify() { - auto dstType = getType().getElementType(); - auto srcType = getSrc().getType().getElementType(); + auto dstType = getType(); + auto srcType = getSrc().getType(); + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && (!getRounding().has_value())) { return emitError("Rounding mode is required for FP downcast"); @@ -800,6 +826,15 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, builder.getDenseI32ArrayAttr(order)); } +//-- AddPtrOp -- +OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { + // addptr(ptr, 0) -> ptr + if (matchPattern(adaptor.getOffset(), m_Zero())) { + return getPtr(); + } + return {}; +} + //-- AdvanceOp -- OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { // advance(ptr, 0, 0) -> ptr @@ -813,6 +848,22 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { return getPtr(); } +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape) { + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = TensorDescType::get(builder.getContext(), blockTy); + return build(builder, state, descTy, base, shape, strides); +} + // The following ops, including `call`, `func`, and `return` are copied and // modified from // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -1016,6 +1067,60 @@ void ExternElementwiseOp::getEffects( SideEffects::DefaultResource::get()); } +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. + inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + // -- ExperimentalTensormapCreateOp -- LogicalResult ExperimentalTensormapCreateOp::verify() { auto rank = getBoxDim().size(); diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index 19729aee5c1b..690826f4efaf 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -5,11 +5,9 @@ #include "mlir/IR/TypeUtilities.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; -namespace ttg = mlir::triton::gpu; static LogicalResult verifySameEncoding(Type typeA, Type typeB, bool allowTensorPointerType) { @@ -118,53 +116,12 @@ LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { if (!layout) return success(); - if (isa(layout)) - return makeErr() << "Shared layout is not allowed on tensor type."; - // TODO(jlebar): Currently this only checks blocked layouts, but other - // layouts also have invariants! - - // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. - if (auto blocked = dyn_cast(layout)) { - // A different verifier should have checked that the layout itself is - // valid, including that threads-per-warp has the same rank as - // warps-per-block etc. - auto layoutRank = blocked.getThreadsPerWarp().size(); - if (layoutRank != rankedTy.getRank()) { - return makeErr() << layout << ".\nLayout has rank " << layoutRank - << ", but the tensor it's attached to has rank " - << rankedTy.getRank() << "."; - } - - int moduleThreadsPerWarp = - ttg::TritonGPUDialect::getThreadsPerWarp(module); - int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); - if (layoutThreadsPerWarp != moduleThreadsPerWarp) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutThreadsPerWarp - << " threads per warp, but the module specifies " - << moduleThreadsPerWarp << " threads per warp."; - } - - int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); - int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); - if (layoutWarpsPerCTA != moduleWarpsPerCTA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutWarpsPerCTA - << " warps per CTA, but the module specifies " - << moduleWarpsPerCTA << " warps per CTA."; - } - - if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { - int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); - int64_t layoutCTAsPerCGA = - product(blocked.getCTALayout().getCTAsPerCGA()); - if (layoutCTAsPerCGA != moduleCTAsPerCGA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutCTAsPerCGA - << " CTAs per CGA, but the module specifies " - << moduleCTAsPerCGA << " CTAs per CGA."; - } - } + Dialect &dialect = layout.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, module, + makeErr); } return success(); diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 6e41e70a8e39..de8925cbffd7 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -50,61 +50,6 @@ void PointerType::print(AsmPrinter &printer) const { } } -static constexpr llvm::StringRef kMutableMemory = "mutable"; - -Type MemDescType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - - SmallVector dimensions; - if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) - return Type(); - - // Parse the element type. - Type elementType; - if (parser.parseType(elementType)) - return Type(); - - Attribute encoding; - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(encoding)) - return Type(); - } - bool mutableMemory = false; - Attribute memorySpace; - if (succeeded(parser.parseOptionalComma())) { - if (failed(parser.parseOptionalKeyword(kMutableMemory))) { - if (parser.parseAttribute(memorySpace)) - return Type(); - } else { - mutableMemory = true; - } - } - if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { - if (parser.parseOptionalKeyword(kMutableMemory)) - return Type(); - mutableMemory = true; - } - if (parser.parseGreater()) - return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, memorySpace, mutableMemory); -} - -void MemDescType::print(AsmPrinter &printer) const { - printer << "<"; - for (auto dim : getShape()) - printer << dim << "x"; - printer << getElementType(); - if (getEncoding()) - printer << ", " << getEncoding(); - if (getMemorySpace()) - printer << ", " << getMemorySpace(); - if (getMutableMemory()) - printer << ", " << kMutableMemory; - printer << ">"; -} - namespace mlir { namespace triton { diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 33c4516b47f5..fa909d4df94c 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -7,7 +7,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" @@ -18,35 +17,7 @@ namespace mlir::triton { namespace { bool isZero(Value val) { - if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())) - return true; - // broadcast(constant_0) - if (auto bc = val.getDefiningOp()) { - if (matchPattern(bc.getSrc(), m_Zero()) || - matchPattern(bc.getSrc(), m_AnyZeroFloat())) - return true; - } - return false; -} - -bool isBroadcastConstantCombinable(Attribute value) { - if (auto denseValue = dyn_cast(value)) { - return denseValue.isSplat(); - } - return isa(value); -} - -DenseElementsAttr getConstantValue(Builder &builder, Attribute value, - Value bcast_res) { - auto resType = cast(bcast_res.getType()); - DenseElementsAttr res; - if (auto denseValue = dyn_cast(value)) { - res = - DenseElementsAttr::get(resType, denseValue.getSplatValue()); - } else { - res = DenseElementsAttr::get(resType, value); - } - return res; + return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())); } bool isAddPtrOffsetCombinable(Value first, Value second) { @@ -231,7 +202,6 @@ class CombineOpsPass : public TritonCombineOpsBase { // %} patterns.add(context); patterns.add(context); - patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 5a2fcecfa949..e3588f587757 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -44,11 +44,4 @@ def CombineAddPtrPattern : Pat< (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), [(Constraint> $idx0, $idx1)]>; -// broadcast(cst) => cst -def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; -def CombineBroadcastConstantPattern : Pat< - (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), - (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)), - [(Constraint> $value)]>; - #endif diff --git a/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp index 257e734b7f88..cb25d41a2548 100644 --- a/lib/Dialect/Triton/Transforms/LoopUnroll.cpp +++ b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -22,8 +22,6 @@ namespace mlir::triton { -static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; - namespace { class LoopUnrollPass : public TritonLoopUnrollBase { @@ -31,12 +29,15 @@ class LoopUnrollPass : public TritonLoopUnrollBase { int getUnrollFactorOrDefault(scf::ForOp forOp) { // Use the attribute attached to the loop if it exists otherwise set the // factor to 1 to suppress the unrolling. - if (auto factor = forOp->getAttrOfType( - mlir::triton::loopUnrollFactorAttrName)) + if (auto factor = + forOp->getAttrOfType(loopUnrollFactorAttrName)) return factor.getInt(); return 1; } + const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + const char *pipelineStagesAttrName = "tt.num_stages"; + public: LoopUnrollPass() = default; LoopUnrollPass(const LoopUnrollPass &) {} @@ -49,11 +50,18 @@ class LoopUnrollPass : public TritonLoopUnrollBase { loops.push_back(forOp); }); + auto ctx = getOperation()->getContext(); for (auto loop : loops) { auto unrollFactor = getUnrollFactorOrDefault(loop); - loop->removeAttr(mlir::triton::loopUnrollFactorAttrName); + loop->removeAttr(loopUnrollFactorAttrName); LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); - (void)loopUnrollByFactor(loop, unrollFactor); + auto resultLoops = loopUnrollByFactor(loop, unrollFactor); + // Do not pipeline the epilog loop. + if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) { + (*resultLoops->epilogueLoopOp) + ->setAttr(pipelineStagesAttrName, + mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1)); + } } } }; diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 43479a3d9f9a..486fc1c7b9da 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -206,18 +206,6 @@ struct MoveBroadcastAfterElementwisePattern } }; -template -class CanonicalizePattern : public OpRewritePattern { -public: - explicit CanonicalizePattern(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(OpType op, - PatternRewriter &rewriter) const override { - return OpType::canonicalize(op, rewriter); - } -}; - class ReorderBroadcastPass : public ::impl::TritonReorderBroadcastBase { public: @@ -226,8 +214,8 @@ class ReorderBroadcastPass RewritePatternSet patterns(context); ModuleOp m = getOperation(); - patterns.add>(context); - patterns.add>(context); + BroadcastOp::getCanonicalizationPatterns(patterns, context); + ExpandDimsOp::getCanonicalizationPatterns(patterns, context); // elementwise(broadcast(a)) => broadcast(elementwise(a)) patterns.add(context); // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) diff --git a/lib/Dialect/TritonCPU/CMakeLists.txt b/lib/Dialect/TritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/lib/Dialect/TritonCPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/lib/Dialect/TritonCPU/IR/CMakeLists.txt b/lib/Dialect/TritonCPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..c0b6f0f7be24 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonCPUIR + Dialect.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonCPUTableGen + TritonCPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR +) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp new file mode 100644 index 000000000000..41a4c62bda45 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -0,0 +1,75 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" + +void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +/// Parse an attribute registered to this dialect. +::mlir::Attribute +TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const { + llvm_unreachable("parse stub called"); +} + +/// Print an attribute registered to this dialect. +void TritonCPUDialect::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &os) const { + llvm_unreachable("print stub called"); +} + +void ExtractIndicesOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, Value src) { + assert(triton::isTensorPointerType(src.getType()) && + "Unexecpeted source type"); + auto tensorTy = dyn_cast( + dyn_cast(src.getType()).getPointeeType()); + SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); + build(builder, state, resTypes, src); +} + +void TritonCPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc" + >(); +} + +// verify TritonCPU ops +LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/lib/Dialect/TritonCPU/IR/Ops.cpp b/lib/Dialect/TritonCPU/IR/Ops.cpp new file mode 100644 index 000000000000..b8523ebcd8ac --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Ops.cpp @@ -0,0 +1,40 @@ +#include "mlir/IR/Builders.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::cpu { + +LogicalResult PrintOp::verify() { + if (getOperands().size() > 1) + return emitOpError("expects at most one operand"); + return success(); +} + +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + return success(); +} + +} // namespace mlir::triton::cpu diff --git a/lib/Dialect/TritonCPU/IR/Types.cpp b/lib/Dialect/TritonCPU/IR/Types.cpp new file mode 100644 index 000000000000..b6a17786bac2 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonCPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::cpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::cpu::TritonCPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonCPU/IR/Types.cpp.inc" + >(); +} diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index 98831f0db8ac..7486d72f36e3 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -7,6 +7,7 @@ add_triton_library(TritonGPUIR DEPENDS TritonGPUTableGen TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen LINK_LIBS PUBLIC MLIRGPUDialect diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 71506ecbb9f0..dec78c2e41a4 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -11,6 +11,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" @@ -19,6 +20,7 @@ // Include TableGen'erated code #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc" using namespace mlir; using namespace mlir::triton; @@ -201,12 +203,25 @@ SmallVector getUniqueContigPerThread(Attribute layout, } return ret; } - -SmallVector getShapePerCTATile(Attribute layout, - ArrayRef tensorShape) { +SmallVector getShapePerCTATile(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { - return distributedLayout.getShapePerCTATile(tensorShape); + auto sizePerThread = distributedLayout.getSizePerThread(); + auto threadsPerWarp = distributedLayout.getThreadsPerWarp(); + // ThreadsPerWarp does not align with this function for slice layout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent()); + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + } + auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); + assert(sizePerThread.size() == threadsPerWarp.size() && + sizePerThread.size() == warpsPerCTA.size()); + SmallVector shape; + for (auto [size, thread, warp] : + llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) { + shape.push_back(size * thread * warp); + } + return shape; } else { llvm::report_fatal_error("getShapePerCTATile not implemented"); return SmallVector(); @@ -217,7 +232,7 @@ bool isExpensiveView(Type srcType, Type dstType) { return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); } -/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. +/* Utility function used by get.*Order methods of SliceEncodingAttr. * Erase dim and decrease all values larger than dim by 1. * Example: order = [0, 2, 4, 3, 1], dim = 2 * resOrder = [0, 3, 2, 1] @@ -235,6 +250,19 @@ static SmallVector eraseOrder(ArrayRef order, return resOrder; } +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + assert(rank >= 2); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kMajor) { // kMajor: if true, the matrix is fastest-running on k, @@ -244,42 +272,28 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, // batch (if rank == 3) is always the slowest running dimension assert(rank == 2 || rank == 3); assert(opIdx == 0 || opIdx == 1); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - // If opIdx is 1 and kMajor is true, the order is [0, 1] - // (resp. [1, 2, 0] if rank == 3) - // Same if opIdx is 0 and kMajor is false - if (bool(opIdx) == kMajor) { - std::swap(order[0], order[1]); - } - return order; + auto rowMajor = bool(opIdx) != kMajor; + return getMatrixOrder(rank, rowMajor); +} + +SmallVector getRepOrder(Attribute layout) { + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getRepOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getRepOrder"); + return {}; } SmallVector getWarpOrder(Attribute layout) { - if (auto dotLayout = dyn_cast(layout)) { - if (isa(dotLayout.getParent())) { - return getWarpOrder(dotLayout.getParent()); - } - } - auto order = getOrder(layout); - // FIXME: This mmaLayout if should just return - // getOrderForDotOperand(0, order.size(), kMajor=false) - // as mma has the same order as DotOperand(opIdx=0) - if (auto mmaLayout = dyn_cast(layout)) { - if (mmaLayout.isHopper()) { - // Hopper MMA instructions force a warp order of [0, 1]. See docs: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 - auto it = std::find(order.begin(), order.end(), 0); - order.erase(it); - order.insert(order.begin(), 0); - } - } else if (auto dotOpLayout = dyn_cast(layout)) { - order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(), - /*kMajor*/ false); - } - return order; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getWarpOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); + return {}; } +// Returns the order of the elements in a layout from the fastest running +// dimension to the slowest SmallVector getOrder(Attribute layout) { if (auto blockedLayout = dyn_cast(layout)) { return llvm::to_vector(blockedLayout.getOrder()); @@ -287,9 +301,7 @@ SmallVector getOrder(Attribute layout) { if (auto mmaLayout = dyn_cast(layout)) { auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); - SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); - return order; + return getMatrixOrder(rank, /*rowMajor*/ true); } if (auto dotLayout = dyn_cast(layout)) { auto rank = dotLayout.getWarpsPerCTA().size(); @@ -308,10 +320,13 @@ SmallVector getOrder(Attribute layout) { if (auto sharedLayout = mlir::dyn_cast(layout)) { return llvm::to_vector(sharedLayout.getOrder()); } + if (auto linearLayout = mlir::dyn_cast(layout)) { + return linearLayout.getOrder(); + } llvm::report_fatal_error("Unimplemented usage of getOrder"); return {}; -}; +} SmallVector getThreadOrder(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) @@ -319,7 +334,7 @@ SmallVector getThreadOrder(Attribute layout) { else llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); return {}; -}; +} CTALayoutAttr getCTALayout(Attribute layout) { if (auto distributedLayout = @@ -421,7 +436,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto wmmaLayout = dyn_cast(layout)) warpsPerCTA = wmmaLayout.getWarpsPerCTA(); else if (auto dotLayout = dyn_cast(layout)) - return getNumWarpsPerCTA(dotLayout.getParent()); + warpsPerCTA = dotLayout.getWarpsPerCTA(); else if (auto sharedLayout = dyn_cast(layout)) llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); else @@ -531,6 +546,132 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + auto bases = layout.getBases(); + + auto kRegister = StringAttr::get(ctx, "register"); + std::set broadcastedDims; + + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); + } + } + // From the largest basis to the smallest. + llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + if (!broadcastRegisters && inDimName == kRegister) { + broadcastedDims.insert(basisIdx); + } else { + bases[inDimName][basisIdx][outDim.index()] = 0; + } + actualSize >>= 1; + } + } + if (!broadcastRegisters) { + // Remove broadcasted registers + std::vector> newBasesRegister; + for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) { + // Remove if it's broadcasted + if (broadcastedDims.find(idx) == broadcastedDims.end()) { + newBasesRegister.push_back(std::move(basis)); + } + } + bases[kRegister] = std::move(newBasesRegister); + } + + return LinearLayout(std::move(bases), + llvm::to_vector(layout.getOutDimNames())); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Returns ["dim0", "dim1", ..., "dim"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to +// creating a 1D -> 1D mapping of size product(shape) and then reshaping to +// permute(shape, order). +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order) { + assert(shape.size() == order.size()); + MLIRContext *ctx = inDimName.getContext(); + auto rank = shape.size(); + + // The order in triton is written wrt. [dim0, dim1, ...]. + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + } // namespace gpu } // namespace triton } // namespace mlir @@ -621,10 +762,10 @@ static void maybePrintCTALayout(mlir::MLIRContext *context, //===----------------------------------------------------------------------===// // Attribute methods //===----------------------------------------------------------------------===// -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" #define GET_ATTRDEF_CLASSES -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { return SliceEncodingAttr::get(getContext(), axis, *this); @@ -654,6 +795,9 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, // If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. // But we need to have a consistent interface with e.g. SliceEncodingAttr, which // computes some of these fields. +SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getOrder()); +} SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -678,14 +822,6 @@ SmallVector BlockedEncodingAttr::getThreadOrder() const { SmallVector BlockedEncodingAttr::getSizePerThread() const { return SmallVector(getSizePerThread__()); } -SmallVector -BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape; - for (unsigned d = 0, n = getOrder().size(); d < n; ++d) - shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * - getWarpsPerCTA()[d]); - return shape; -} template SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { @@ -720,6 +856,10 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); } +SmallVector SliceEncodingAttr::getRepOrder() const { + auto parentRepOrder = ::getRepOrder(getParent()); + return eraseOrder(parentRepOrder, getDim()); +} SmallVector SliceEncodingAttr::getCTASplitNum() const { SmallVector res = ::getCTASplitNum(getParent()); res.erase(res.begin() + getDim()); @@ -762,7 +902,8 @@ SmallVector SliceEncodingAttr::getWarpsPerCTA() const { return warpsPerCTA; } SmallVector SliceEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto parentWarpOrder = ::getWarpOrder(getParent()); + return eraseOrder(parentWarpOrder, getDim()); } SmallVector SliceEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); @@ -774,19 +915,14 @@ SmallVector SliceEncodingAttr::getThreadsPerWarp() const { return threadsPerWarp; } SmallVector SliceEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto parentThreadOrder = ::getThreadOrder(getParent()); + return eraseOrder(parentThreadOrder, getDim()); } SmallVector SliceEncodingAttr::getSizePerThread() const { auto sizePerThread = ::getSizePerThread(getParent()); sizePerThread.erase(sizePerThread.begin() + getDim()); return sizePerThread; } -SmallVector -SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); - shape.erase(shape.begin() + getDim()); - return shape; -} // @@ -861,28 +997,13 @@ NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, size_t rank = shape.size(); assert(rank == 2 || (rank == 3 && isAmpere()) && "Unexpected rank of mma layout"); - assert((isVolta() || isAmpere() || isHopper()) && + assert((isAmpere() || isHopper()) && "For NvidiaMmaEncodingAttr only version 1~3 is supported"); auto shapePerCTA = getShapePerCTA(getCTALayout().getCTASplitNum(), shape); SmallVector elemsPerThread(rank); - if (isVolta()) { - auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates(); - static constexpr std::array fpw{{2, 2}}; - unsigned packSize0 = (isARow || isAVec4) ? 1 : 2; - unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1; - unsigned repM = 2 * packSize0; - unsigned repN = 2 * packSize1; - unsigned spwM = fpw[0] * 4 * repM; - unsigned spwN = fpw[1] * 4 * repN; - unsigned wptM = getWarpsPerCTA()[0]; - unsigned wptN = getWarpsPerCTA()[1]; - unsigned resM = repM * std::max(1, shapePerCTA[0] / (spwM * wptM)); - unsigned resN = 2 * repN * std::max(1, shapePerCTA[1] / (spwN * wptN)); - elemsPerThread[0] = resM; - elemsPerThread[1] = resN; - } else if (isAmpere()) { + if (isAmpere()) { unsigned elemsRow = ceil(shapePerCTA[rank - 2], 16 * getWarpsPerCTA()[rank - 2]) * 2; @@ -907,36 +1028,6 @@ NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, return elemsPerThread; } -unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( - int opIdx, ArrayRef shape) const { - size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of mma layout"); - auto shapePerCTA = getShapePerCTA(*this, shape); - int res = 0; - if (isVolta()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 1"); - } else if (isAmpere()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 2"); - } else if (isHopper()) { - auto wpt = getWarpsPerCTA(); - auto instrMNK = getInstrShape(); - if (opIdx == 0) { - int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); - int repK = ceil(shapePerCTA[1], instrMNK[2]); - return 8 * repM * repK; - - } else if (opIdx == 1) { - int repK = ceil(shapePerCTA[0], instrMNK[2]); - int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); - // benzh@ here need more check - return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; - } - } - return res; -} - unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); @@ -959,25 +1050,36 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, SmallVector DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { + auto rank = shape.size(); + assert(rank == 2 || rank == 3); - if (auto parent = mlir::dyn_cast(getParent())) { - auto rank = shape.size(); - assert(rank == 2 || rank == 3); - - auto idx = getOpIdx(); - assert(idx == 0 || idx == 1); - - SmallVector elemsPerThread(rank); + auto idx = getOpIdx(); + assert(idx == 0 || idx == 1); - auto kWidth = getKWidth(); - auto rep = parent.getRepForOperand(shape, kWidth, idx); + SmallVector elemsPerThread(rank); + auto parent = getParent(); + auto kWidth = getKWidth(); + if (auto mfma = mlir::dyn_cast(parent)) { + auto rep = mfma.getRepForOperand(shape, kWidth, idx); if (rank == 3) elemsPerThread[0] = rep[0]; elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth; elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2]; - return elemsPerThread; + } else if (auto mma = mlir::dyn_cast(parent)) { + assert(getCTALayout(*this) == + CTALayoutAttr::getDefault(getContext(), rank) && + "NYI"); + auto sizePerThread = getSizePerThread(); + auto threadsPerWarp = getThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector regs; + for (auto [n, nsize, nThread, nWarp] : + llvm::zip(shape, sizePerThread, threadsPerWarp, warpsPerCTA)) { + regs.push_back(std::max(nsize, n / (nThread * nWarp))); + } + return regs; } llvm_unreachable("getElemsPerThread is not supported for dot operand"); @@ -987,14 +1089,24 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { if (auto mmaParent = mlir::dyn_cast(getParent())) { - return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(), - getOpIdx()); + if (auto nvidiaMmaParent = + mlir::dyn_cast(mmaParent)) { + return product(getElemsPerThread(shape, eltTy)); + } + if (auto amdMfmaParent = mlir::dyn_cast(getParent())) { + return amdMfmaParent.getTotalElemsPerThreadForOperand( + shape, eltTy, getKWidth(), getOpIdx()); + } + if (auto amdWmmaParent = mlir::dyn_cast(getParent())) { + return amdWmmaParent.getTotalElemsPerThreadForOperand( + shape, eltTy, getKWidth(), getOpIdx()); + } } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); auto order = blockedLayout.getOrder(); - auto sizePerThread = ::getSizePerThread(blockedLayout); + auto sizePerThread = blockedLayout.getSizePerThread(); int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; @@ -1030,7 +1142,8 @@ SmallVector DotOperandEncodingAttr::getCTASplitNum() const { assert(rank == 2 || rank == 3 && "Invalid dotLayout"); // Do not split CTA in K dimension - getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1; + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + res[kDim] = 1; return res; } SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { @@ -1042,74 +1155,68 @@ SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { return warps; } SmallVector DotOperandEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + // FIXME(Lezcano): Preexisting. Do we want to have this path at all? + if (mlir::isa(getParent())) { + return ::getWarpOrder(getParent()); + } + // It's quite weird to talk about warp order when that the warps + // are broadcasted along the K dimension + llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented"); + return {}; } SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), /*kMajor*/ true); } -SmallVector DotOperandEncodingAttr::getShapePerCTATile( - ArrayRef tensorShape) const { - auto parentLayout = getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getShapePerCTATileForOperand( - tensorShape, getKWidth(), getOpIdx()); - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " - "supported yet"); - } -} LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned opIdx, Attribute parent, unsigned kWidth) { if (opIdx != 0 && opIdx != 1) { - return emitError() - << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " - << opIdx; + return emitError() << "ttg.dot_op opIdx paramenter can be 0 or 1, got: " + << opIdx; } if (!parent) { - return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; + return emitError() << "ttg.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !parentAttr.isAmpere()) - return emitError() << "triton_gpu.dot_op kWidth parameter can only be " - "non-zero for Ampere MMA parent"; - if (kWidth == 0 && parentAttr.isAmpere()) + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) return emitError() - << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere MMA parent"; + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth != 16 && parentAttr.getVersion() == 1 || kWidth != 8 && parentAttr.getVersion() == 2) - return emitError() << "triton_gpu.dot_op kWidth parameter must be 16 for " + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " "gfx11 and 8 for gfx12"; return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth == 0) - return emitError() - << "triton_gpu.dot_op kWidth parameter is mandatory for " - "MFMA parent"; + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; return success(); } if (auto parentAttr = mlir::dyn_cast(parent)) { if (kWidth != 0) - return emitError() - << "triton_gpu.dot_op kWidth parameter is not supported " - "when the parent is a blocked layout"; + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; return success(); } - return emitError() << "triton_gpu.dot_op unexpected parent layout: " - << parent; + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; } //===----------------------------------------------------------------------===// @@ -1211,6 +1318,360 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "}>"; } +// FIXME Can we take the LinearLayout by const&? +LogicalResult +LinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + // Example of LinearEncodingAttr + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + static const auto expectedInDims = + SmallVector({"register", "lane", "warp", "block"}); + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + return success(); +} + +void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + // We don't use the default implementation as it's a bit too verbose + // This prints in the following format that is shape agnostic, in the sense + // that we don't print explicitly the outShape of the LL + // We always assume LLs to be surjective + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + auto ll = getLinearLayout(); + printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }) << "}>"; +} + +Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + std::vector inDimNames = {"register", "lane", "warp", "block"}; + + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of '") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + // To implement this we'd need to serialise the rank as well. + // We can do this if we ever need it + if (rank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + // Generate standared outDimNames (dim0, dim1, ...) + SmallVector outDimNames; + for (int i = 0; i < rank; ++i) { + outDimNames.push_back( + StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); + } + + // Create LinearLayout + LinearLayout linearLayout(std::move(bases), std::move(outDimNames)); + + // Create and return the LinearEncodingAttr + return parser.getChecked(parser.getContext(), + std::move(linearLayout)); +} + +SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, + StringAttr dimName, size_t rank, + bool skipBroadcast = true) { + const auto &bases = namedBases.find(dimName)->second; + + if (bases.empty()) { + return SmallVector(rank, 1); + } + + SmallVector ret(rank, 1); + auto nonZero = [](auto val) { return val != 0; }; + int nonZeroIdx = -1; + for (const auto &basis : bases) { + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + if (it != basis.end()) { + nonZeroIdx = it - basis.begin(); + ret[nonZeroIdx] *= 2; + } else if (!skipBroadcast) { + // If we've seen a non-zero basis, we double the size of the previous dim + // This is just needed to count the CTAsPerCGA + assert(nonZeroIdx != -1); + ret[nonZeroIdx] *= 2; + } + } + return ret; +} + +SmallVector basesPerDim(const LinearLayout &ll, StringAttr dimName, + bool skipBroadcast = true) { + auto shapeIter = ll.getOutDimSizes(); + auto rank = std::distance(shapeIter.begin(), shapeIter.end()); + return basesPerDim(ll.getBases(), dimName, rank, skipBroadcast); +} + +SmallVector orderPerDim(const LinearLayout &ll, StringAttr dimName, + ArrayRef defaultOrder) { + const auto &bases = ll.getBases().find(dimName)->second; + llvm::SetVector order; + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &basis : bases) { + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + if (it != basis.end()) { + auto i = it - basis.begin(); + order.insert(i); + } + } + // If any dim is missing, we add them in the defaultOrder + for (auto i : defaultOrder) { + order.insert(i); + } + return SmallVector(order.begin(), order.end()); +} + +// [Note. Divergence of methods wrt. legacy layouts] +// For smaller shapes where the CTATile is larger than the output +// tensor, some methods return different values than the legacy layouts. I think +// this is benign tho. An example: what is the the vector of `warpsPerCTA` if +// all the warps hold the same data? I think it should be [1, 1], even if we +// have 4 warps. But perhaps for this we have to add some masking in some +// places... We'll see +SmallVector LinearEncodingAttr::getRepOrder() const { + // This is not correct, but: + // - It happens to agree in most places with the legacy layout + // - getRepOrder does not make sense for LinearEncodingAttr as it already has + // the same shape as the tensor that uses it + return getOrder(); +} +SmallVector LinearEncodingAttr::getCTAsPerCGA() const { + // CTAs are split into an identity part (SplitNum) and a broadcast part + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), + /*skipBroadcast=*/false); +} +SmallVector LinearEncodingAttr::getCTAOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), + getOrder()); +} +SmallVector LinearEncodingAttr::getCTASplitNum() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block")); +} +SmallVector LinearEncodingAttr::getWarpsPerCTA() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp")); +} +SmallVector LinearEncodingAttr::getWarpOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp"), + getOrder()); +} +SmallVector LinearEncodingAttr::getThreadsPerWarp() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane")); +} +SmallVector LinearEncodingAttr::getThreadOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane"), + getOrder()); +} +SmallVector LinearEncodingAttr::getSizePerThread() const { + auto rank = getRepOrder().size(); + auto ll = getLinearLayout(); + auto ctx = getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + + // We canonicalize on the spot, as if we use CGAs the regs are not in + // canonical form The order is [reg, lane, warp, rep, block], so we first + // remove the blocks + llvm::SmallVector ctaShape; + for (auto [shape, cgaNum] : + llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { + ctaShape.push_back(shape / cgaNum); + } + LinearLayout::BasesT bases = ll.getBases(); + + llvm::SetVector reverseRepOrder; + auto nonZero = [](auto val) { return val != 0; }; + auto ®isters = bases[StringAttr::get(ctx, "register")]; + while (!registers.empty()) { + auto &basis = registers.back(); + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // If there's broadcasting (base == zeros) there are no more reps + if (it == basis.end()) { + break; + } + auto dim = it - basis.begin(); + reverseRepOrder.insert(dim); + // As soon as we stop finding reps, we stop + if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { + break; + } + ctaShape[dim] /= 2; + registers.pop_back(); + } + return basesPerDim(bases, kRegister, rank); +} + +SmallVector LinearEncodingAttr::getOrder() const { + auto rank = getLinearLayout().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 0] as the default order in case + // there are dims that do not move in the register + // This order is as good as any really + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(getLinearLayout(), + StringAttr::get(getContext(), "register"), order); +} + +std::optional +LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ll = getLinearLayout(); + auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); + llvm::SmallDenseMap namedShape; + llvm::SmallVector permutedDims; + for (auto dim : getRepOrder()) { + permutedDims.push_back(canonicalDims[dim]); + namedShape[canonicalDims[dim]] = shape[dim]; + } + ll = ll.transposeOuts(permutedDims); + ll = ensureLayoutNotSmallerThan(ll, namedShape); + ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); + ll = ll.transposeOuts(canonicalDims); + return ll; +} + +SmallVector +LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { + // We can relax this assert by calling toLinearLayout rather than + // getLinearLayout + SmallVector shapeVec(shape.begin(), shape.end()); + assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); + auto ll = getLinearLayout(); + return basesPerDim(ll, StringAttr::get(getContext(), "register")); +} + +// Start of Selection +SmallVector LinearEncodingAttr::getContigPerThread() const { + auto ll = getLinearLayout(); + const auto ®s = + ll.getBases().find(StringAttr::get(getContext(), "register"))->second; + auto order = getOrder(); + auto rank = order.size(); + + SmallVector contig(rank, 1); + auto regIt = regs.begin(); + for (unsigned dim : order) { + std::vector basis(rank, 0); + basis[dim] = 1; + + while (regIt != regs.end() && *regIt == basis) { + contig[dim] *= 2; + basis[dim] *= 2; + ++regIt; + } + } + return contig; +} + +unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + //===----------------------------------------------------------------------===// // MMA encoding //===----------------------------------------------------------------------===// @@ -1562,16 +2023,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { //===----------------------------------------------------------------------===// // TODO: there is a lot of common code with MmaEncoding here -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= getMDim(); - shapePerCTATile[rank - 2] *= getNDim(); - return shapePerCTATile; -} - SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1585,7 +2036,7 @@ SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { auto order = ::getOrder(*this); @@ -1658,6 +2109,17 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { return {kDim, nDim}; } +SmallVector AMDMfmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + SmallVector AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const { @@ -1704,43 +2166,21 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - assert(getMDim() == 32 || getMDim() == 16); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], 32}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; - } else if (opIdx == 1) { - if (rank == 2) - return {32, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } - llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); -} - //===----------------------------------------------------------------------===// // Wmma encoding //===----------------------------------------------------------------------===// -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); +SmallVector AMDWmmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} - auto mnkDim = getMNKDimPerInstr(); - shapePerCTATile[rank - 2] *= mnkDim[0]; - shapePerCTATile[rank - 1] *= mnkDim[1]; - return shapePerCTATile; +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); } + SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1754,7 +2194,7 @@ SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); @@ -1794,21 +2234,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - auto parentShapePerCTA = getShapePerCTATile(shape); - auto rank = shape.size(); - assert(rank == 2); - if (opIdx == 0) { - return {parentShapePerCTA[0], static_cast(shape[1])}; - } else if (opIdx == 1) { - return {static_cast(shape[0]), parentShapePerCTA[1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } -} - unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx); @@ -1865,6 +2290,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1878,16 +2307,15 @@ SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto rank = getWarpsPerCTA().size(); + // Hopper (wgmma) uses column-major as this is embeded in the instruction + // For Ampere we can choose either row-major or column-major. + // We choose row-major as the legacy path did so + return getMatrixOrder(rank, /*rowMajor*/ !isHopper()); } SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { auto rank = getWarpsPerCTA().size(); SmallVector res(rank, 1); - if (isVolta()) { - res[rank - 2] = 4; - res[rank - 1] = 8; - return res; - } if (isAmpere()) { res[rank - 2] = 8; res[rank - 1] = 4; @@ -1902,21 +2330,17 @@ SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { "getThreadsPerWarp not implemented for unknown Mma version "); } SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); } SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { - auto rank = ::getOrder(*this).size(); + auto rank = getWarpsPerCTA().size(); SmallVector res(rank, 1); if (isAmpere()) { res[rank - 2] = 2; res[rank - 1] = 2; return res; } - if (isVolta()) { - res[rank - 2] = 1; - res[rank - 1] = 2; - return res; - } if (isHopper()) { auto instrShape = getInstrShape(); // WGMMA instructions have an order of [0, 1] with 4 warps, each with 8 @@ -1929,231 +2353,62 @@ SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { } SmallVector -NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - if (isAmpere()) { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), - warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= 8; - shapePerCTATile[rank - 2] *= 16; - return shapePerCTATile; - } - if (isVolta()) { - assert(!tensorShape.empty() && "Volta needs the tensorShape"); - if (tensorShape.size() == 1) // must be SliceEncoding - return {static_cast(tensorShape[0]), - static_cast(tensorShape[0])}; - return {static_cast(tensorShape[0]), - static_cast(tensorShape[1])}; - } - if (isHopper()) { - auto instrShape = getInstrShape(); - return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; - } - llvm::report_fatal_error("Unexpected MMA layout version found"); -} - -// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor -std::tuple -NvidiaMmaEncodingAttr::decodeVoltaLayoutStates() const { - unsigned versionMinor = getVersionMinor(); - bool isARow = versionMinor & (1 << 0); - bool isBRow = versionMinor & (1 << 1); - bool isAVec4 = versionMinor & (1 << 2); - bool isBVec4 = versionMinor & (1 << 3); - - int id = 0; - for (int i = numBitsToHoldMmaV1ID - 1; i >= 0; --i) - id = (id << 1) + static_cast(versionMinor & (1 << (4 + i))); - - return std::make_tuple(isARow, isBRow, isAVec4, isBVec4, id); +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); } -bool NvidiaMmaEncodingAttr::getMMAv1IsRow(int opIdx) const { - auto [isARow, isBRow, _0, _1, _2] = decodeVoltaLayoutStates(); - return opIdx == 0 ? isARow : isBRow; -} -bool NvidiaMmaEncodingAttr::getMMAv1IsVec4(int opIdx) const { - auto [_0, _1, isAVec4, isBVec4, _2] = decodeVoltaLayoutStates(); - return opIdx == 0 ? isAVec4 : isBVec4; -} -int NvidiaMmaEncodingAttr::getMMAv1NumOuter(ArrayRef shape, - int opIdx) const { - auto spw = getMMAv1ShapePerWarp(opIdx); - auto rep = getMMAv1Rep(opIdx); - auto warpsPerCTA = getWarpsPerCTA(); - if (opIdx == 0) { - return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]); - } else { - return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]); - } -} -SmallVector NvidiaMmaEncodingAttr::getMMAv1Rep(int opIdx) const { - auto [isARow, isBRow, isAVec4, isBVec4, _] = decodeVoltaLayoutStates(); - // A - if (opIdx == 0) { - int packSize = (isARow || isAVec4) ? 1 : 2; - return {2 * packSize, 0, 1}; - } - // B - else { - int packSize = (isBRow && !isBVec4) ? 2 : 1; - return {0, 2 * packSize, 1}; - } -} -SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { - auto rep = getMMAv1Rep(opIdx); - if (opIdx == 0) { - return {8 * rep[0], 0, 1}; - } else { - return {0, 8 * rep[1], 1}; - } -} -int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { - return 2 * getMMAv1Rep(opIdx)[opIdx]; -} -SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( - ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int kWidth, int opIdx) const { + assert( + kWidth >= 32 / bitwidth && + "kWidth must be >= 32 / bitwidth for this function to be well-defined"); auto rank = shape.size(); + // Broadcast long K auto warpsPerCTA = getWarpsPerCTA(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + warpsPerCTA[kDim] = 1; - SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; - int numRepBatch = - rank == 3 - ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) - : 1; - assert(isAmpere()); - - if (opIdx == 0) - return {numRepBatch, - std::max(1, shape[rank - 2] / - (shapePerWarp[1] * warpsPerCTA[rank - 2])), - std::max(1, shape[rank - 1] / shapePerWarp[3])}; - else { - assert(opIdx == 1); - return {numRepBatch, - std::max(1, shape[rank - 2] / shapePerWarp[3]), - std::max(1, shape[rank - 1] / (shapePerWarp[2] * - warpsPerCTA[rank - 1]))}; + SmallVector tileSize; + if (rank == 3) { + tileSize.push_back(1); } -} -unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( - ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - auto shapePerCTA = getShapePerCTA(*this, shape); - int warpsPerCTAM = getWarpsPerCTA()[0]; - int warpsPerCTAN = getWarpsPerCTA()[1]; - // H100 - if (isHopper()) { - return getTotalElemsPerThread(shape, eltTy); + if (opIdx == 0) { + // m x k + tileSize.push_back(16); + tileSize.push_back(4 * 64 / bitwidth); + } else { + // k x n + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF + // so it's fine if the n is incorrect here + tileSize.push_back(4 * 64 / bitwidth); + tileSize.push_back(8); } - // A100 - if (isAmpere()) { - auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(), - kWidth, opIdx); - if (opIdx == 0) - return 4 * rep[0] * rep[1] * rep[2]; - if (opIdx == 1) - return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); - } - // V100 - if (isVolta()) { - bool isRow = getMMAv1IsRow(opIdx); - bool isVec4 = getMMAv1IsVec4(opIdx); - if (opIdx == 0) { - int packSizeM = (isRow || isVec4) ? 1 : 2; - int repM = 2 * packSizeM; - int spwM = 2 * 4 * repM; - int numM = getMMAv1NumOuter(shape, opIdx); - int NK = shape[1]; - int vec = 2 * repM; - // Here we mimic the logic in loadA, the result cannot be calculated - // directly. - llvm::DenseSet> visited; - auto ld = [&](int m, int k) { - visited.insert({m, k}); - if (vec > 4) { - if (isRow) - visited.insert({m, k + 4}); - else - visited.insert({m + 1, k}); - } - }; - for (unsigned k = 0; k < NK; k += 4) - for (unsigned m = 0; m < numM / 2; ++m) - if (!visited.count({m, k})) - ld(m, k); - return visited.size() * 2; - } - if (opIdx == 1) { - int packSizeN = (isRow && !isVec4) ? 2 : 1; - int repN = 2 * packSizeN; - int spwN = 2 * 4 * repN; - int numN = getMMAv1NumOuter(shape, opIdx); - int vec = 2 * repN; - - int NK = shape[0]; - // Here we mimic the logic in loadA, the result cannot be calculated - // directly. - llvm::DenseSet> visited; - int elemsPerLd = vec > 4 ? 4 : 2; - auto ld = [&](int n, int k) { - visited.insert({n, k}); - if (vec > 4) { - if (isRow) - visited.insert({n + 1, k}); - else - visited.insert({n, k + 4}); - } - }; - for (unsigned k = 0; k < NK; k += 4) - for (unsigned n = 0; n < numN / 2; ++n) { - if (!visited.count({n, k})) - ld(n, k); - } - - return visited.size() * 2; - } + SmallVector numRep; + // Lezcano: This is odd. Why do we always return a vector of size 3? + if (rank != 3) { + numRep.push_back(1); } - llvm_unreachable("unknown mma layout"); -} -SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( - ArrayRef shape, int kWidth, int opIdx) const { - assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); - // 4 threads * 2 subtiles - unsigned kWidthTile = kWidth * 2 * 4; - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], kWidthTile}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], - kWidthTile}; - } else if (opIdx == 1) { - if (rank == 2) - return {kWidthTile, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], kWidthTile, - parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) { + numRep.push_back(std::max(1, s / (size * warp))); } + return numRep; } + SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { - assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); auto rank = getWarpsPerCTA().size(); auto sizePerThread = SmallVector(rank, 1); if (opIdx == 0) { sizePerThread[rank - 2] = 2; sizePerThread[rank - 1] = 2 * kWidth; - } else if (opIdx == 1) { + } else { + assert(opIdx == 1); sizePerThread[rank - 2] = 2 * kWidth; sizePerThread[rank - 1] = 1; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } return sizePerThread; } @@ -2161,6 +2416,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); if (auto mma = mlir::dyn_cast(parent)) { @@ -2195,6 +2459,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + // Encoding attributes if (auto mmaAttr = mlir::dyn_cast(attr)) { os << "mma"; return AliasResult::FinalAlias; @@ -2204,10 +2469,18 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { } else if (auto blockedAttr = mlir::dyn_cast(attr)) { os << "blocked"; return AliasResult::FinalAlias; + } else if (auto linearAttr = mlir::dyn_cast(attr)) { + os << "linear"; + return AliasResult::FinalAlias; } /* else if (auto sliceAttr = dyn_cast(attr)) { os << "slice"; return AliasResult::FinalAlias; } */ + // Memory space attributes + if (auto smem = mlir::dyn_cast(attr)) { + os << "smem"; + return AliasResult::FinalAlias; + } return OpAsmDialectInterface::getAlias(attr, os); } }; @@ -2736,6 +3009,68 @@ struct TritonGPUInferLayoutInterface } }; +struct TritonGPUVerifyTensorLayoutInterface + : public triton::DialectVerifyTensorLayoutInterface { + using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface; + + LogicalResult verifyTensorLayout( + Attribute layout, RankedTensorType rankedTy, ModuleOp module, + function_ref makeErr) const override { + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = + triton::gpu::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = + triton::gpu::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Canonicalizer //===----------------------------------------------------------------------===// @@ -3275,10 +3610,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, if (!layout) return ""; - unsigned threadsPerWarp = getWarpSize(layout); - unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout); - unsigned numBlocks = getNumCTAs(layout); - int numElementsPerThreads = getTotalElemsPerThread(tensorType); StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register"); StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane"); StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); @@ -3291,6 +3622,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, int64_t tensorSize = product(tensorType.getShape()); std::vector elementMapping(tensorSize); std::vector threadMapping; + unsigned threadsPerWarp = ll->getInDimSize(kLane); + unsigned numWarpsPerCTA = ll->getInDimSize(kWarp); + unsigned numBlocks = ll->getInDimSize(kBlock); + int numElementsPerThreads = ll->getInDimSize(kRegister); for (int blockId = 0; blockId < numBlocks; ++blockId) { for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { for (int tid = 0; tid < threadsPerWarp; ++tid) { @@ -3421,12 +3756,52 @@ void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) { llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/true); } +struct TensorModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + void TritonGPUDialect::initialize() { registerTypes(); addAttributes< #define GET_ATTRDEF_LIST -#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" >(); addOperations< #define GET_OP_LIST @@ -3435,6 +3810,10 @@ void TritonGPUDialect::initialize() { >(); addInterfaces(); addInterfaces(); + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); } // verify TritonGPU ops diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 9bc3be036c4e..9f6bc4d61f51 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -32,47 +32,13 @@ namespace { #define S(v) StringAttr::get(ctx, (v)) -// Returns ["out0", "out1", ..., "out"]. -SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { +// TODO Have order be a mandatory argument of standardOutDimNames. +SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); SmallVector ret; - for (int i = 0; i < rank; i++) { - ret.push_back(S("dim" + llvm::Twine(i))); - } - return ret; -} - -void assertIsRegisterLayout(const LinearLayout &layout) { - assert(layout.getNumInDims() > 0); - MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); - StringAttr kRegister = S("register"); - StringAttr kLane = S("lane"); - StringAttr kWarp = S("warp"); - StringAttr kBlock = S("block"); - - const auto &ins = layout.getInDimNames(); - assert(llvm::SmallVector(ins.begin(), ins.end()) == - llvm::SmallVector({kRegister, kLane, kWarp, kBlock})); - - const auto &outs = layout.getOutDimNames(); - const auto &expectedOuts = standardOutDimNames(ctx, layout.getNumOutDims()); - assert(llvm::SmallDenseSet(outs.begin(), outs.end()) == - llvm::SmallDenseSet(expectedOuts.begin(), - expectedOuts.end())); -} - -// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of -// size product(shape) and then reshaping to permute(shape, order). -LinearLayout identityND(StringAttr inDimName, ArrayRef shape, - ArrayRef order, - ArrayRef outDimNames) { - assert(shape.size() == order.size()); - - MLIRContext *ctx = inDimName.getContext(); - LinearLayout ret = LinearLayout::empty(); - for (int i = 0; i < shape.size(); i++) { - // Start with the most-minor dimension, which is order[0]. - int dim = order[i]; - ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + for (unsigned i : order) { + ret.push_back(names[i]); } return ret; } @@ -106,124 +72,6 @@ LinearLayout makeCgaLayout(CTALayoutAttr layout) { return ret.transposeOuts(outDimNames); } -// For each output dimension d, ensure that the layout's output size (i.e., its -// codomain) does not exceed shape[d]. Do this without changing the size of the -// layout's inputs (i.e., leave its domain unchanged). -// -// This function is invariant to the order of the layout's input and output -// dimensions. -// -// We achieve this by setting the largest value in each output dimension d to 0 -// because bases that map to a location larger than shape[d] -// effectively duplicate along that dimension. For example, consider a layout -// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to -// shrink the output dimension size to 8: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 16 -// -// In the first step, we shrink the output dimension size to 16 by setting -// L(lane=2) to 0: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// This means that lane=2 has the same data as lane=0. -// -// Now the output dimension of this layout has a size of 16, which is still -// larger than 8. We find the current largest value in the output dimension, -// which is L(register=1) = 8, and we set L(register=1) to 0: -// -// L(register=1) = 0 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// Now the output dimension of this layout has a size of 8, which is the desired -// size. Note that this method works only because the bases are powers of two. -// It is unclear what to do when they are not. -LinearLayout ensureLayoutNotLargerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - MLIRContext *ctx = shape.begin()->first.getContext(); - - auto bases = layout.getBases(); - for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { - auto outDimName = outDim.value(); - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - if (actualSize <= desiredSize) { - continue; - } - assert(actualSize % desiredSize == 0); - // - std::vector> sortedBases; - for (auto [inDimName, basis] : bases) { - for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { - auto outValue = basis[basisIdx][outDim.index()]; - if (outValue == 0) { - continue; - } - assert(llvm::isPowerOf2_32(outValue)); - sortedBases.emplace_back(inDimName, basisIdx, outValue); - } - } - // From the largest basis to the smallest. - llvm::sort(sortedBases, - [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); - for (auto [inDimName, basisIdx, outValue] : sortedBases) { - if (actualSize <= desiredSize) { - break; - } - bases[inDimName][basisIdx][outDim.index()] = 0; - actualSize >>= 1; - } - } - return LinearLayout(std::move(bases), - llvm::to_vector(layout.getOutDimNames())); -} - -// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no -// smaller than shape[d]. Do this by increasing the size of the layout's inputs -// along its most-minor dimension ("register" for register layouts, "offset" for -// shared layouts). -// -// This function is invariant to the order of the layout's input dimensions, but -// it cares about the order of the output dims, which should be minor-to-major. -LinearLayout ensureLayoutNotSmallerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - - MLIRContext *ctx = shape.begin()->first.getContext(); - StringAttr kDim = *layout.getInDimNames().begin(); - assert(kDim == "register" || kDim == "offset"); - - LinearLayout ret = layout; - for (StringAttr outDimName : layout.getOutDimNames()) { - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - assert(actualSize > desiredSize || desiredSize % actualSize == 0); - ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); - assert(ret.getOutDimSize(outDimName) >= desiredSize); - } - return ret; -} - // Combines the layout of a CTA (input dims [register, lane, warp]) with the // layout of a CGA (i.e. a block), and ensures that the resulting layout has the // given shape. @@ -269,70 +117,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, return ret; } -LinearLayout ampereMmaToLinearLayout(ArrayRef shape, - NvidiaMmaEncodingAttr mma) { - int rank = shape.size(); - - assert(mma.isAmpere()); - assert(rank == 2 || rank == 3); - assert(mma.getInstrShape().size() == rank); - assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || - (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); - - MLIRContext *ctx = mma.getContext(); - SmallVector dimNames = standardOutDimNames(ctx, rank); - - LinearLayout ctaLayout( - {{S("register"), {{1, 0}, {0, 8}}}, - {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - ctaLayout *= identityND( - S("warp"), mma.getWarpsPerCTA(), - llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); - - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); -} - -LinearLayout hopperMmaToLinearLayout(ArrayRef shape, - NvidiaMmaEncodingAttr mma) { - int rank = shape.size(); - assert(mma.isHopper()); - assert(rank == 2); - - // wgmma operates on groups of 4 warps. - assert(product(mma.getWarpsPerCTA()) % 4 == 0); - - // Check that it's a known MMA layout. - assert(mma.getInstrShape().size() == 3); - int m = mma.getInstrShape()[0]; - int n = mma.getInstrShape()[1]; - int k = mma.getInstrShape()[2]; - assert(m == 16); - assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); - assert(k == 8 || k == 16 || k == 32); - - MLIRContext *ctx = mma.getContext(); - LinearLayout ctaLayout( - {{S("register"), {{1, 0}, {0, 8}}}, - {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, - {S("dim1"), S("dim0")}); - - // Expand the `register` dimension so the size of dim1 matches `n`. - ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), - S("register"), S("dim1")); - - // Expand the `warp` dimension according to warpsPerCTA. - // - // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but - // this really does seem to be correct. - ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, - {S("dim0"), S("dim1")}) - .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); - - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); -} - LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, SharedEncodingAttr shared) { assert(!shared.getHasLeadingOffset()); @@ -427,12 +211,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, llvm::report_fatal_error("Illegal shared layout"); } - int vec = 8 * 16 / elemBitWidth; - if (vec != shared.getVec()) { - llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec - << ": " << shared << "\n"; - llvm::report_fatal_error("Illegal shared layout"); - } + int vec = shared.getVec(); StringAttr colDimName = outDimNames[colDim]; StringAttr rowDimName = outDimNames[rowDim]; @@ -544,15 +323,15 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. LinearLayout warpLayout = - identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + identityStandardND(S("warp"), getWarpsPerCTA(), order); LinearLayout ctaLayout = tileLayout * warpLayout; return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } std::optional -dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. To achieve @@ -654,8 +433,7 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); } - LinearLayout warpLayout = - identityND(kWarp, warpsPerCTA, warpOrder, outDimNames); + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * warpLayout.transposeOuts(outDimNames); @@ -737,7 +515,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. LinearLayout warpLayout = - identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + identityStandardND(S("warp"), getWarpsPerCTA(), order); LinearLayout ctaLayout = tileLayout * warpLayout; return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); @@ -746,27 +524,158 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { assert(shape.size() == getOrder().size()); - - int rank = shape.size(); MLIRContext *ctx = getContext(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); const auto &order = getOrder(); LinearLayout ctaLayout = - identityND(S("register"), getSizePerThread(), order, outDimNames) * - identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) * - identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + identityStandardND(S("register"), getSizePerThread(), order) * + identityStandardND(S("lane"), getThreadsPerWarp(), order) * + identityStandardND(S("warp"), getWarpsPerCTA(), order); return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder) { + // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder + // Like LinearLayout::empty() but with a rank and an order + int rank = repOrder.size(); + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout ctaLayout = + identityStandardND(S("register"), trivialShape, repOrder); + + assert(rank >= 2); + auto inner = order[0]; + auto outer = order[1]; + + assert(tileShape.size() == rank); + int m = tileShape[outer]; + int n = tileShape[inner]; + + // The relative order of registers and lanes is given by: + // - Inner dim: kWidth registers + // - Inner dim: 4 lanes + // - Outer dim: 8 lanes + // - Outer dim: repeat m / 8 times + // - Inner dim: repeat n / (kWidth * 4) times + assert(m % 8 == 0); + assert(n % (kWidth * 4) == 0); + // There is at least one subtile on the inner-most dimension + // FIXME. We should implement operator* in terms of operator*= + // and chain *= instead of using * + auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames()); + ctaLayout = ctaLayout * + LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) * + LinearLayout::identity1D(4, S("lane"), dimNames[inner]) * + LinearLayout::identity1D(8, S("lane"), dimNames[outer]) * + LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) * + LinearLayout::identity1D(n / (kWidth * 4), S("register"), + dimNames[inner]); + return ctaLayout; +} + std::optional NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ctx = getContext(); + int rank = shape.size(); + + SmallVector tileShape; if (isAmpere()) { - return ampereMmaToLinearLayout(shape, *this); + // Ampere.getInstrShape() returns the tile shape + tileShape = SmallVector(getInstrShape()); + } else { + assert(isHopper()); + auto instrShapeMNK = getInstrShape(); + tileShape = SmallVector({instrShapeMNK[0], instrShapeMNK[1]}); + } + // nvidiamma layout always assumes kWidth = 2 + constexpr auto kWidth = 2; + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(*this), getRepOrder()); + + // The triton orders are defined on [dim0, dim1, ...], so we need to pass + // those dims Then, for some reason, operator* requires the orders to match + // so we need to reorder the outs to match + ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder()) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef mmaWarpShape, + ArrayRef mmaWarpOrder, bool isA) { + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {1, 0} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In other words, we need to broadcast along K + auto rank = mmaWarpOrder.size(); + auto inner = isA ? rank - 1 : rank - 2; + auto dimNames = standardOutDimNames(ctx, rank); + LinearLayout warpLayout = LinearLayout::empty(); + + // We have to broadcast along the inner dimension + // For A, when moving along M we go from 0 to 2. + // For B, when moving along N we go from 0 to 1. + // As such, choosing the order of A {1, 0}, gives us the correct broadcasting + // Same happens if the mmaWarpOrder is {0, 1}, like in Hopper + for (auto d : mmaWarpOrder) { + if (d == inner) { + warpLayout *= + LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]); + } else { + warpLayout *= + LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]); + } } - if (isHopper()) { - return hopperMmaToLinearLayout(shape, *this); + return warpLayout; +} + +LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + MLIRContext *ctx = mma.getContext(); + + SmallVector tileShape(rank, 1); + if (isA) { + tileShape[rank - 2] = 16; + tileShape[rank - 1] = kWidth * 8; + } else { + // Hopper takes the rhs via shared memory + assert(mma.isAmpere()); + tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 1] = 8; + } + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder()); + ctaLayout *= + warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); +} + +std::optional +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto parent = getParent(); + if (auto mfmaLayout = llvm::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(parent)) { + return nvidiaDotToLinearLayout(shape, *this); } return std::nullopt; } @@ -841,88 +750,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { return ret; } -LinearLayout ampereDotToLinearLayout(ArrayRef shape, - DotOperandEncodingAttr dot) { - // TODO,BE. Implement ampereMMA in terms of this one - int rank = shape.size(); - auto mma = cast(dot.getParent()); - int kWidth = dot.getKWidth(); - bool isA = dot.getOpIdx() == 0; - - assert(mma.isAmpere()); - assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || - (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); - - MLIRContext *ctx = mma.getContext(); - SmallVector dimNames = standardOutDimNames(ctx, rank); - - // Implement A. For B transpose in the end - std::vector> registers; - std::vector> lanes; - int32_t i = 1; - // kWidth contiguous elements - while (i < kWidth) { - registers.push_back({i, 0}); - i *= 2; - } - // 4 threads per chunk - for (int j = 0; j < 2; j++) { - lanes.push_back({i, 0}); - i *= 2; - } - // 8 threads going down - lanes.push_back({0, 1}); - lanes.push_back({0, 2}); - lanes.push_back({0, 4}); - // 2 tiles in column-major order - // Just one if it's the B operand - if (isA) { - registers.push_back({0, 8}); - } - registers.push_back({i, 0}); - - if (!isA) { - for (auto &r : registers) { - std::swap(r[0], r[1]); - } - for (auto &l : lanes) { - std::swap(l[0], l[1]); - } - } - - LinearLayout ctaLayout( - {{S("register"), registers}, {S("lane"), lanes}}, - llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); - - auto order = dot.getCTAOrder(); - assert(order[0] == 1 && order[1] == 0); - ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); - - return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); -} - -std::optional -DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - if (auto mfmaLayout = llvm::dyn_cast(getParent())) { - return dotOperandMfmaToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(getParent())) { - // FIXME [Dot LL] - // Do this unconditionally - auto largeKWidth = getKWidth() == 8; - if (mma.isAmpere() && largeKWidth) { - return ampereDotToLinearLayout(shape, *this); - } - } - return std::nullopt; -} - std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { + // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { return distributed.toLinearLayout(shape); - } - if (auto shared = dyn_cast(layout)) { + } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); @@ -931,29 +765,10 @@ toLinearLayout(ArrayRef shape, Attribute layout, } } - // TODO(jlebar): Other layouts + // Third party layouts return std::nullopt; } -bool isCrossCTAConversion(const LinearLayout &layout) { - assert(!layout.getInDimNames().empty()); - MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); - - StringAttr kBlock = S("block"); - assert(layout.hasInDim(kBlock)); - assert(layout.hasOutDim(kBlock)); - - SetVector nonBlockInDims(layout.getInDimNames().begin(), - layout.getInDimNames().end()); - nonBlockInDims.remove(kBlock); - - // This layout moves data between CTAs if - // - the value for any input dim other than block affects the output block, or - // - input (0, ..., block=i) does not map to output (0, ..., block=i). - return !layout.sublayoutIsZero(nonBlockInDims.getArrayRef(), {kBlock}) || - !layout.sublayoutIsIdentity({kBlock}, {kBlock}); -} - LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { assert(!layout.getInDimNames().empty()); MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); @@ -1005,40 +820,7 @@ LinearLayout chooseShemLayoutForRegToRegConversion( } namespace { - -// TODO (Keren): Currently, we have more restrictions than necessary when using -// stmatrix. These restrictions are retained from legacy code, and we could -// relax some of them in the future. -bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, - ArrayRef paddedRepShape, ArrayRef order, - int swizzleByteSize) { - auto mmaLayout = - mlir::dyn_cast(tensorTy.getEncoding()); - if (!mmaLayout || !mmaLayout.isHopper()) - return false; - if (isa(tensorTy.getElementType())) - return false; - if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) - return false; - if (order[0] != 1) - return false; - - auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); - if (tensorShapePerCTA.size() != 2) - return false; - auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * - ceil(tensorShapePerCTA[0], repShape[0]); - if (numIterations > 1) - return false; - if (paddedRepShape[1] % 8 != 0) - return false; - if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && - swizzleByteSize != 128) - return false; - return true; -} - -std::optional chooseStMatrixLayoutLeadingOffset( +LinearLayout chooseStMatrixLayoutLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order, int swizzleByteSize) { @@ -1093,9 +875,8 @@ std::optional chooseStMatrixLayoutLeadingOffset( // Expand the `warp` dimension according to warpsPerCTA. auto mma = cast(tensorTy.getEncoding()); - layout *= - identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol}) - .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); // Expand the `register` dimension so the size of columns matches `n`. int n = mma.getInstrShape()[1]; @@ -1110,7 +891,7 @@ std::optional chooseStMatrixLayoutLeadingOffset( .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}}); } -std::optional chooseStMatrixLayoutNoLeadingOffset( +LinearLayout chooseStMatrixLayoutNoLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order) { StringAttr kReg = S("register"); @@ -1133,9 +914,8 @@ std::optional chooseStMatrixLayoutNoLeadingOffset( LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol); // Expand the `warp` dimension according to warpsPerCTA. - layout *= - identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol}) - .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape()); auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape()); @@ -1145,22 +925,17 @@ std::optional chooseStMatrixLayoutNoLeadingOffset( ret = ensureLayoutNotSmallerThan(ret, namedTensorShape); ret = ensureLayoutNotLargerThan(ret, namedTensorShape); return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) - .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()}, - {S("iteration"), 1}}) * - identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")}); + .reshapeOuts( + {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}}); } } // anonymous namespace -std::optional -chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, - ArrayRef repShape, - ArrayRef paddedRepShape, - ArrayRef order, int swizzleByteSize) { - if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order, - swizzleByteSize)) - return std::nullopt; - +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) { if (swizzleByteSize == 0) return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape, paddedRepShape, order); diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index e61fe096e10b..068965468eeb 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -1,10 +1,8 @@ #include "mlir/IR/BuiltinTypes.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" -#include "llvm/Support/raw_ostream.h" #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" @@ -34,39 +32,63 @@ LogicalResult UpcastMXFPOp::verify() { "operands must have the same number of dimensions, at least 2"); } - if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 || - fpType == F8F6F4Type::E5M2)) { + if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) { return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2"); } - // Change to support fp8 types - const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1; - - if (xShape.back() != (32 / elems_packed) * scaleShape.back()) { - return emitOpError("last dimension of first operand must be 16 times " - "larger than that of the second operand"); + auto layoutX = xTy.getEncoding(); + auto layoutScale = scaleTy.getEncoding(); + if (bool(layoutX) != bool(layoutScale)) { + return emitOpError( + "Expected either both or neither operands to have an encoding"); + } + // Nothing to check if no encoding. This is used to infer the return type in + // AccelerateMatmul.cpp + if (!layoutX) { + return success(); } - if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) { + auto dotEncoding = dyn_cast(layoutX); + if (!dotEncoding) { + return emitOpError("Expected a DotOperandEncodingAttr for values"); + } + if (!isa(layoutScale)) { return emitOpError( - "all dimensions except the last must match between operands"); + "Expected a BlockOperandEncoding or LinearOperandEncoding " + "for scales"); } - auto layoutX = xTy.getEncoding(); - if (!layoutX || !isa(layoutX)) { - return emitOpError("Expected a DotOperandEncodingAttr for values"); + if (isa(dotEncoding.getParent())) { + // Necessary to keep all of the scales of a given block of values in the + // same warp + auto threadsPerWarp = + cast(layoutScale).getThreadsPerWarp(); + if (threadsPerWarp != ArrayRef({16, 2})) { + return emitOpError("Expected threads per warp to be {16, 2}"); + } } - auto layoutScale = scaleTy.getEncoding(); - if (!layoutScale || !isa(layoutScale)) { - return emitOpError("Expected a BlockOperandEncoding for scales"); + + // Change to support fp8 types + const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1; + // Figure out the K dimension for the input A/B. For A/B scale, the K + // dimension is always the last dimension. + const int opIdx = dotEncoding.getOpIdx(); + const bool hasBatch = xShape.size() == 3; + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; + + if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) { + return emitOpError("K dimension of first operand must be 16 times " + "larger than last/K dimension of the second operand"); } - auto blockedScale = cast(layoutScale); - // Necessary to keep all of the scales of a given block of values in the same - // warp - auto threadsPerWarp = blockedScale.getThreadsPerWarp(); - if (threadsPerWarp != ArrayRef({16, 2})) { - return emitOpError("Expected threads per warp to be {16, 2}"); + // Check other dimensions match too. For input A/B, we need to figure out the + // index for the M/N dimension. For scale, it's always {(batch), M/N, K}. + const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch; + if (hasBatch && xShape[0] != scaleShape[0]) + return emitOpError("batch dimension must match between operands"); + if (xShape[mnIdx] != scaleShape[hasBatch]) { + return emitOpError("M/N dimension must match between operands"); } return success(); @@ -82,22 +104,29 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( auto xShape = xTy.getShape(); auto encoding = xTy.getEncoding(); - if (!encoding) { - return emitOptionalError(loc, "expected an encoding"); - } - if (!mlir::isa(encoding)) { - return emitOptionalError(loc, "expected a dotOperand encoding"); - } - if (typeEncoded == F8F6F4Type::E2M1) { - auto oldEncoding = cast(encoding); - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), - oldEncoding.getKWidth() * 2); + if (typeEncoded == ScaleDotElemType::E2M1) { + RankedTensorType retTy; + auto newShape = SmallVector(xShape); - newShape.back() *= 2; - inferredReturnTypes.push_back( - RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding)); + if (!encoding) { + newShape.back() *= 2; + retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx)); + } else { + auto oldEncoding = cast(encoding); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), + oldEncoding.getKWidth() * 2); + // Figure out the K dimension for the input A/B, given that the return + // type is upcasted A/B type so we need to update the proper dim size. + const int opIdx = oldEncoding.getOpIdx(); + const bool hasBatch = xShape.size() == 3; + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; + newShape[kIdx] *= 2; + retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx), + newVEncoding); + } + inferredReturnTypes.push_back(retTy); } else { inferredReturnTypes.push_back(xTy); } @@ -105,4 +134,48 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( return success(); } +OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult MemDescTransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + auto memDescTy = cast(argTy); + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), + memDescTy.getMutableMemory())); + return success(); +} + } // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/IR/Types.cpp b/lib/Dialect/TritonGPU/IR/Types.cpp index 77f673cc2766..ef9c6c4a3067 100644 --- a/lib/Dialect/TritonGPU/IR/Types.cpp +++ b/lib/Dialect/TritonGPU/IR/Types.cpp @@ -27,6 +27,85 @@ void TokenType::print(AsmPrinter &printer) const { printer << "<" << getType() << ">"; } +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + SmallVector dimensions; // required + if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false))) + return Type(); + + Type elementType; // required + if (failed(parser.parseType(elementType))) + return Type(); + + Attribute encoding; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding))) + return Type(); + + Attribute memorySpace; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace))) + return Type(); + + bool mutableMemory = false; // optional + SmallVector allocShape; // optional + if (succeeded(parser.parseOptionalComma())) { + if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) { + mutableMemory = true; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + } else if (failed(parser.parseDimensionList(allocShape, + /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + + if (parser.parseGreater()) + return Type(); + + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, memorySpace, mutableMemory, dimensions); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + auto shape = getShape(); + for (auto dim : shape) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + auto allocShape = getAllocShape(); + if (allocShape != shape) { + printer << ", " << allocShape[0]; + for (auto dim : allocShape.drop_front(1)) { + printer << "x" << dim; + } + } + printer << ">"; +} + +LogicalResult MemDescType::verify(function_ref emitError, + ArrayRef shape, Type elementType, + Attribute encoding, Attribute memorySpace, + bool mutableMemory, + ArrayRef allocShape) { + if (allocShape.size() < shape.size()) + emitError() << "alloc shape must have at least as many dimensions as shape"; + return success(); +} + //===----------------------------------------------------------------------===// // Triton Dialect //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a2d4012bf23e..3b29f73e1d7a 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -7,10 +7,12 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -77,28 +79,33 @@ SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, } } - SmallVector ret(rank, 1); - SmallVector shapePerWarp(rank, 1); - shapePerWarp[rank - 1] = 8; - shapePerWarp[rank - 2] = 16; - // TODO (@daadaada): double-check. - // original logic in - // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252 - // seems buggy for shape = [32, 16] ? - do { - if (ret[0] * ret[1] >= numWarps) - break; - if (shape[0] / shapePerWarp[0] / ret[0] >= - shape[1] / (shapePerWarp[1] * 2) / ret[1]) { - if (ret[0] < shape[0] / shapePerWarp[0]) { - ret[0] *= 2; - } else - ret[1] *= 2; + assert(rank == 2); + SmallVector shapePerWarp = {16, 8}; + SmallVector warps = {1, 1}; + // Compute repM and repN + SmallVector reps = {ceil(shape[0], shapePerWarp[0]), + ceil(shape[1], shapePerWarp[1])}; + // The formula for the number of registers given the reps is + // repM * 4 * repK + repN * 2 * repK + regsC + // where regsC = repM * repN * 4, which does not depend on the warp shape + // + // As such, to minimize the register pressure, we need to balance + // repM and repN. We then untie towards M, as the lhs tile has 4 elements, + // and the rhs tile has just 2. + while (product(warps) < numWarps) { + if (reps[0] >= reps[1]) { + warps[0] *= 2; + // Too many warps for this mma (repM == repN == 1). + // We allocate the remainin warps to the left (arbitrary choice) + if (reps[0] != 1) { + reps[0] /= 2; + } } else { - ret[1] *= 2; + warps[1] *= 2; + reps[1] /= 2; } - } while (true); - return ret; + } + return {(unsigned)warps[0], (unsigned)warps[1]}; } SmallVector @@ -106,8 +113,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { SetVector slices; mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) + // Contains a chained dot. We prefer to assign warps to one axis + // to facilitate use cases like flash attention, allowing reductions within + // the same warp. + if (llvm::find_if(slices, [](Operation *op) { + return op->hasTrait(); + }) != slices.end()) return {(unsigned)numWarps, 1}; // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). @@ -159,9 +170,22 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, return rewriter.create(arg.getLoc(), newType, arg); } +SmallVector +getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } +} + class BlockedToMMA : public mlir::OpRewritePattern { int computeCapability; - mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { @@ -183,7 +207,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { // elements distribution to the order of higher precision primitives. As a // result, kwidth can be the bitwidth of the lower precision primitive. // Conversely, in the downcasting scenario, no reordering is performed, - // making it directory use the lower precision primitive. + // making it directly use the lower precision primitive. static int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; @@ -210,25 +234,17 @@ class BlockedToMMA : public mlir::OpRewritePattern { : OpRewritePattern(context), computeCapability(computeCapability) { } - static SmallVector - getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, - int numWarps, const SmallVector &instrShape) { - switch (version) { - case 2: - return warpsPerTileV2(dotOp, shape, numWarps); - case 3: - return warpsPerTileV3(dotOp, shape, numWarps, instrShape); - default: - assert(false && "not supported version"); - return {0, 0}; - } - } - mlir::LogicalResult matchAndRewrite(triton::DotOp dotOp, mlir::PatternRewriter &rewriter) const override { if (computeCapability < 70) return failure(); + if (computeCapability < 80) { + dotOp.emitRemark() + << "Dot op using MMA for compute capability " << computeCapability + << " has been deprecated. It falls back to the FMA path."; + return failure(); + } // TODO: Check data-types and SM compatibility RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || @@ -254,47 +270,13 @@ class BlockedToMMA : public mlir::OpRewritePattern { auto oldAType = dotOp.getA().getType(); auto oldBType = dotOp.getB().getType(); - NvidiaMmaEncodingAttr mmaEnc; - if (versionMajor == 1) { - SetVector aBwdSlices, bBwdSlices; - auto isCvt = [](Operation *op) { return isa(op); }; - mlir::BackwardSliceOptions opt; - opt.omitBlockArguments = true; - opt.filter = isCvt; - getBackwardSlice(a, &aBwdSlices, opt); - getBackwardSlice(b, &bBwdSlices, opt); - // get the source of the first conversion found in slices - auto getCvtArgOrder = [](Operation *op) { - return mlir::cast( - cast(op).getSrc().getType().getEncoding()) - .getOrder(); - }; - bool isARow = true; - bool isBRow = true; - Operation *aOp = a.getDefiningOp(); - Operation *bOp = b.getDefiningOp(); - if (!aBwdSlices.empty()) - aOp = aBwdSlices[0]; - if (!bBwdSlices.empty()) - bOp = bBwdSlices[0]; - if (aOp) - isARow = getCvtArgOrder(aOp)[0] == 1; - if (bOp) - isBRow = getCvtArgOrder(bOp)[0] == 1; - - mmaEnc = NvidiaMmaEncodingAttr::get( - oldRetType.getContext(), versionMajor, numWarps, CTALayout, - instrShape, oldAType.getShape(), oldBType.getShape(), retShapePerCTA, - isARow, isBRow, mmaV1Counter++); - } else { - assert(versionMajor == 2 || versionMajor == 3); - int versionMinor = computeCapability == 75 ? 1 : 0; - auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, - numWarps, instrShape); - mmaEnc = NvidiaMmaEncodingAttr::get(oldRetType.getContext(), versionMajor, - versionMinor, warpsPerTile, CTALayout, - instrShape); - } + assert(versionMajor == 2 || versionMajor == 3); + int versionMinor = computeCapability == 75 ? 1 : 0; + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + auto mmaEnc = NvidiaMmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, versionMinor, warpsPerTile, + CTALayout, instrShape); auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator @@ -384,154 +366,416 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { }); } -class ScaledBlockedToMMAv2 +class DecomposeScaledBlocked : public mlir::OpRewritePattern { int computeCapability; public: - ScaledBlockedToMMAv2(mlir::MLIRContext *context, int computeCapability) + DecomposeScaledBlocked(mlir::MLIRContext *context, int computeCapability) : mlir::OpRewritePattern(context), computeCapability(computeCapability) {} mlir::LogicalResult - matchAndRewrite(triton::DotScaledOp dotOp, + matchAndRewrite(triton::DotScaledOp scaledDotOp, mlir::PatternRewriter &rewriter) const override { - if (computeCapability >= 100) - return failure(); + if (computeCapability >= 100 || computeCapability < 80) + return rewriter.notifyMatchFailure( + scaledDotOp, "DotScaledOp just supported on Ampere and Hopper"); - auto oldRetType = dotOp.getType(); + auto oldRetType = scaledDotOp.getType(); if (!oldRetType.getEncoding() || mlir::isa(oldRetType.getEncoding())) return failure(); - auto ctx = dotOp.getContext(); + + auto ctx = scaledDotOp.getContext(); // Check that rhs scale is null - assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI"); + assert(scaledDotOp.getRhsScale() == nullptr && "rhs scale NYI"); // operands - auto a = dotOp.getLhs(); - auto b = dotOp.getRhs(); - auto scale = dotOp.getLhsScale(); - auto aType = dotOp.getLhsType(); - auto bType = dotOp.getRhsType(); - - auto enumToType = [&rewriter](F8F6F4Type type) { - switch (type) { - case F8F6F4Type::E4M3: - return rewriter.getFloat8E4M3FNType(); - case F8F6F4Type::E5M2: - return rewriter.getFloat8E5M2Type(); - default: - llvm_unreachable("unexpected type"); - } - }; - - assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 || - aType == F8F6F4Type::E2M1) && + auto a = scaledDotOp.getLhs(); + auto b = scaledDotOp.getRhs(); + auto scale = scaledDotOp.getLhsScale(); + auto aType = scaledDotOp.getLhsType(); + auto bType = scaledDotOp.getRhsType(); + + auto rank = oldRetType.getShape().size(); + if (rank != 2) + return rewriter.notifyMatchFailure(scaledDotOp, "NYI: rank==3"); + + assert((aType == ScaleDotElemType::E4M3 || + aType == ScaleDotElemType::E5M2 || + aType == ScaleDotElemType::E2M1) && "NYI: lhs supports fp4 or fp8"); - assert(bType == F8F6F4Type::E4M3 || - bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8"); + assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || + bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); + bool isFp4 = aType == ScaleDotElemType::E2M1; + + auto mmaEnc = getMMAEncoding(rewriter, scaledDotOp); + auto versionMajor = mmaEnc.getVersionMajor(); + assert(versionMajor == 2 || + versionMajor == 3 && "NYI: MMAV2 and MMAV3 only"); - // TODO run accelerate matmul on A and B first to choose their layouts - // Set return type - auto versionMajor = 2; - auto retShapePerCTA = getShapePerCTA(oldRetType); - auto mod = dotOp->getParentOfType(); - unsigned numWarps = TritonGPUDialect::getNumWarps(mod); - auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, - rewriter.getBF16Type(), numWarps); - auto CTALayout = getCTALayout(oldRetType.getEncoding()); - // TODO Use warpsPerTileV2 - SmallVector warpsPerCTA = {numWarps, 1}; - auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor, - /*versionMinor=*/0, warpsPerCTA, - CTALayout, instrShape); auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = scaledDotOp.getC(); auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); - auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType]( - TypedValue v, int idx, - F8F6F4Type type) -> TypedValue { - auto vType = v.getType(); - if (type == F8F6F4Type::E2M1) { - // A bit too dynamically typed... - // perhaps return ints in both cases? - - auto retEnc = dyn_cast(newRetType.getEncoding()); - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, idx, newRetType.getEncoding(), /*kWidth=*/4); - auto newVType = RankedTensorType::get( - vType.getShape(), vType.getElementType(), newVEncoding); - return rewriter.create(v.getLoc(), newVType, v); - } else { - assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, idx, newRetType.getEncoding(), /*kWidth=*/8); - auto newVType = RankedTensorType::get( - vType.getShape(), vType.getElementType(), newVEncoding); - v = rewriter.create(v.getLoc(), newVType, v); - - // Bitcast - auto vTypeFp8 = RankedTensorType::get(vType.getShape(), - enumToType(type), newVEncoding); - v = cast>( - rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); - - // Convert to bf16 - auto vTypeBf16 = RankedTensorType::get( - vType.getShape(), rewriter.getBF16Type(), newVEncoding); - return rewriter.create(v.getLoc(), vTypeBf16, v); - } - }; - a = toMMABf16(a, 0, aType); - b = toMMABf16(b, 1, bType); - + // TODO: This should be kWidth = 2 once MMAv2 supports kWidth=1 for 1 byte + // types + auto aKWidth = mmaEnc.isHopper() ? 2 : 8; + auto bKWidth = mmaEnc.isHopper() ? 2 : 8; + if (isFp4) { + // Load 2x4-bit elements per thread + aKWidth /= 2; + } // [Note: A trick to avoid warp shuffles in the lowering] - // FIXME: Implement this when we can set general layouts on a tensor - - // For bf16, we have 4 threads per row - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-a-f16 - // and each of them needs to get every scale in that row. - // It turns out that the layout for the output of type bf16 gives us exactly - // this layout when the number of mxfp vectors is equal to two (K = 64) - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c - // This can be generalised to other K with linear layouts, but the general - // layout cannot cannot be represented with the predefined layouts :( - // With this trick, we could do the full lowering here and remove the - // UpcastMXFPOp altogether - - assert(instrShape == ArrayRef({16, 8}) || - instrShape == ArrayRef({1, 16, 8})); - auto shapeTileA = std::array{instrShape[0], instrShape[0]}; + // Once we fully support LLs in the IR, we can craft an LL so that + // broadcasting happens effectively in the convertLayoutOp lowering. For + // this, we would just need to create an LL with + // `bases[warps] = {(0, 0), (0, 0), ...}` + + auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, aKWidth); + + // MMAv3 uses the first dimension for the M dimension, while MMAv2 uses the + // penultimate (ugh) + auto instrShapeM = + mmaEnc.getInstrShape()[versionMajor == 3 + ? 0 + : mmaEnc.getInstrShape().size() - 2]; + auto warpSize = getWarpSize(newAEncoding); + assert(instrShapeM <= warpSize); // Necessary choice to leave all the scales of the tile in that given warp auto threadsPerWarp = - SmallVector{shapeTileA[0], 32 / shapeTileA[0]}; + SmallVector{instrShapeM, warpSize / instrShapeM}; + + // This has to align with the order in UpcastMXFPOp + auto order = getMatrixOrder(rank, /*rowMajor=*/true); + Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), order, + mmaEnc.getCTALayout()); + + // Lezcano: In the future we could just use the LLs unconditionally + // Not doing it now as they are not as performant as Blocked encoding at + // times E.g., we bail on them in the backwardMaterialization pass + auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1; + if (dotBroadcastsWarpLevel) { + auto kRegister = StringAttr::get(ctx, "register"); + auto regs = identityStandardND(kRegister, {1, 1}, order); + auto lanes = + identityStandardND(StringAttr::get(ctx, "lane"), {16, 2}, order); + + // Extract warp layout from dotAEncoding + // In the future we'll have some nice division utils, but until then... + auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape()); + LinearLayout::BasesT scaleBases = dotLL.getBases(); + auto kWarp = StringAttr::get(ctx, "warp"); + auto &warpBases = scaleBases[kWarp]; + // The tile shape was [16, 2 * 4 * kWidth] with broadcasting in K + // We divide the M dimension by 16 + auto div = 16; + for (auto &warpBase : warpBases) { + if (warpBase[rank - 2] != 0) { + assert(warpBase[rank - 2] % div == 0); + warpBase[rank - 2] /= div; + } + } - auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( - ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout); + LinearLayout::BasesT warpBlockBases; + auto standardOutDims = llvm::to_vector(dotLL.getOutDimNames()); + warpBlockBases[kWarp] = warpBases; + auto kBlock = StringAttr::get(ctx, "block"); + assert(scaleBases[kBlock].empty() && "NYI: CGAs"); + warpBlockBases[kBlock] = {}; + auto warpBlock = LinearLayout(std::move(warpBlockBases), standardOutDims); + + auto newLL = + (regs * lanes) * + warpBlock.transposeOuts(llvm::to_vector(lanes.getOutDimNames())); + auto shape = scale.getType().getShape(); + + // Broadcast to the correct shape Equivalent to + // newLL = ensureLayoutNotSmallerThan(newLL.transposeOuts(getRepOrder), + // shape); + for (auto d : newAEncoding.getRepOrder()) { + auto outDim = standardOutDims[d]; + auto dimSize = newLL.getOutDimSize(outDim); + newLL *= + LinearLayout::identity1D(shape[d] / dimSize, kRegister, outDim); + } + newLL = newLL.transposeOuts(standardOutDims); + newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); + } - auto newScaleType = RankedTensorType::get(scale.getType().getShape(), - scale.getType().getElementType(), - newScaleEncoding); - scale = - rewriter.create(scale.getLoc(), newScaleType, scale); + a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); - auto scaledA = rewriter.create( - dotOp.getLoc(), a, scale, dotOp.getLhsType()); + Operation *newDot = nullptr; + if (versionMajor == 2) { + // Upcast B operand + assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); + auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth); + b = createArg(rewriter, b, 1, bType, newBEncoding, + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); + newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, a, b, + newAcc); + } else { + assert(versionMajor == 3); + // At the time of this writing, this is always true + auto allowTranspose = b.getType().getElementType().isBF16(); + auto bShmem = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); + newDot = rewriter.create( + scaledDotOp.getLoc(), newRetType, a, bShmem, newAcc, nullptr); + } // convert dot instruction - auto newDot = - rewriter.create(dotOp.getLoc(), newRetType, scaledA, b, newAcc); - rewriter.replaceOpWithNewOp(dotOp, oldRetType, newDot); + rewriter.replaceOpWithNewOp(scaledDotOp, oldRetType, + newDot->getResult(0)); return success(); } + +private: + TypedValue + createArg(mlir::PatternRewriter &rewriter, TypedValue v, + int idx, ScaleDotElemType type, std::optional vEncoding, + std::optional> opt_scale, + std::optional scaleEncoding) const { + auto ctx = rewriter.getContext(); + // Create a new tensor with a given encoding or remove the encoding + auto maybeWithEncoding = + [](RankedTensorType ty, + std::optional enc) -> RankedTensorType { + if (enc.has_value()) { + return RankedTensorType::get(ty.getShape(), ty.getElementType(), *enc); + } else { + return RankedTensorType::get(ty.getShape(), ty.getElementType()); + } + }; + + auto newVType = maybeWithEncoding(v.getType(), vEncoding); + TypedValue ret = + rewriter.create(v.getLoc(), newVType, v); + + // convert to bf16 + if (type != ScaleDotElemType::E2M1 && type != ScaleDotElemType::BF16) { + assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3); + auto vTypeBf16 = RankedTensorType::get( + newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding()); + ret = cast>( + rewriter.create(v.getLoc(), vTypeBf16, ret).getResult()); + } + if (opt_scale.has_value()) { + auto scale = *opt_scale; + assert(idx == 0 && "NYI: rhs scale"); + auto newScaleDotElemType = + maybeWithEncoding(scale.getType(), scaleEncoding); + scale = rewriter.create(scale.getLoc(), + newScaleDotElemType, scale); + ret = rewriter.create(v.getLoc(), ret, scale, + type); + } + return ret; + } + + NvidiaMmaEncodingAttr getMMAEncoding(mlir::PatternRewriter &rewriter, + DotScaledOp scaledDotOp) const { + auto ctx = rewriter.getContext(); + auto a = scaledDotOp.getLhs(); + auto b = scaledDotOp.getRhs(); + auto scale = scaledDotOp.getLhsScale(); + auto aType = scaledDotOp.getLhsType(); + auto bType = scaledDotOp.getRhsType(); + + // create a DotOp to be passed in to getMMAVersionSafe + // We don't pass encodings as we just want to get the type and shape + // to create a DotOp to be passed in to getMMAVersionSafe. We use the + // rewriter to avoid duplicating createArg, but these ops are not going to + // end up in the graph + RankedTensorType aTType = + createArg(rewriter, a, 0, aType, /*vEncoding=*/std::nullopt, scale, + /*scaleEncoding=*/std::nullopt) + .getType(); + auto aTypeNoEnc = + RankedTensorType::get(aTType.getShape(), aTType.getElementType()); + a = rewriter.create(scaledDotOp.getLoc(), aTypeNoEnc, a); + + RankedTensorType bTType = + createArg(rewriter, b, 1, bType, /*vEncoding=*/std::nullopt, + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt) + .getType(); + auto bTypeNoEnc = + RankedTensorType::get(bTType.getShape(), bTType.getElementType()); + b = rewriter.create(scaledDotOp.getLoc(), bTypeNoEnc, b); + auto dotOp = rewriter.create( + scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC()); + + int versionMajor = 2; + // We just support bf16 for MMAv3 on the rhs + if (bType == ScaleDotElemType::BF16) { + versionMajor = getMMAVersionSafe(computeCapability, dotOp); + } + int versionMinor = computeCapability == 75 ? 1 : 0; + + RankedTensorType oldRetType = dotOp.getType(); + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto mod = dotOp->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + + auto instrShape = mmaVersionToInstrShape( + versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(), + numWarps); + + auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + return NvidiaMmaEncodingAttr::get(ctx, versionMajor, versionMinor, + warpsPerCTA, CTALayout, instrShape); + } }; +static void updateValueType(Value v, Attribute encoding, + ArrayRef shape) { + auto tensorType = cast(v.getType()); + auto newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + v.setType(newType); +} + +static TransOp updateUsers(Value result, const SetVector &slice) { + TransOp transOp; + if (llvm::any_of(result.getUsers(), + [&](Operation *user) { return slice.count(user) == 0; })) { + OpBuilder builder(result.getContext()); + builder.setInsertionPointAfterValue(result); + transOp = + builder.create(result.getLoc(), result, ArrayRef({1, 0})); + result.replaceUsesWithIf(transOp.getResult(), [&](OpOperand &operand) { + return operand.getOwner() != transOp.getOperation() && + slice.count(operand.getOwner()) == 0; + }); + } + return transOp; +} + +// Sync the transpose in the IR, this is done to avoid generating convert layout +// when we have a transpose right after a dot as mma layout cannot be propagated +// through transpose op. Once we have layouts that can represent transposed MMA +// we can remove this transformation. +static void sinkTransposeOp(TransOp input) { + SmallVector queue = {input}; + while (!queue.empty()) { + TransOp transOp = queue.back(); + Value currentValue = transOp.getResult(); + queue.pop_back(); + mlir::ForwardSliceOptions options; + options.filter = [](Operation *op) { + if (op->hasTrait() && op->getNumOperands() == 1) + return true; + if (isa(op)) + return isa(op->getParentOp()); + if (isa(op)) + return true; + return false; + }; + SetVector slice; + mlir::getForwardSlice(currentValue, &slice, options); + for (Operation *op : slice) { + if (op->hasTrait()) { + // Update users of transpose op. + if (op->getOperand(0) == transOp.getResult()) + op->setOperand(0, transOp.getOperand()); + // Update the type of the result. + for (Value result : op->getResults()) { + auto srcType = cast(op->getOperand(0).getType()); + updateValueType(result, srcType.getEncoding(), srcType.getShape()); + updateUsers(result, slice); + } + continue; + } + if (auto cvtOp = dyn_cast(op)) { + // Update users of transpose op. + if (op->getOperand(0) == transOp.getResult()) + op->setOperand(0, transOp.getOperand()); + auto resultEncoding = cvtOp.getType().getEncoding(); + auto newDstEncoding = inferSrcEncoding(transOp, resultEncoding); + auto srcType = cast(cvtOp.getOperand().getType()); + updateValueType(cvtOp.getResult(), *newDstEncoding, srcType.getShape()); + updateUsers(cvtOp.getResult(), slice); + continue; + } + assert(isa(op)); + auto forOp = dyn_cast(op->getParentOp()); + assert(forOp); + for (OpOperand &operand : op->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && (slice.count(def)) || def == transOp.getOperation()) { + if (def == transOp.getOperation()) + operand.set(transOp.getOperand()); + Type newType = operand.get().getType(); + forOp.getResult(operand.getOperandNumber()).setType(newType); + TransOp retTrans = + updateUsers(forOp.getResult(operand.getOperandNumber()), slice); + // Recursively try to propagate the new transpose inserted. + if (retTrans) + queue.push_back(retTrans); + forOp.getRegionIterArg(operand.getOperandNumber()).setType(newType); + TransOp argTrans = updateUsers( + forOp.getRegionIterArg(operand.getOperandNumber()), slice); + if (argTrans) + queue.push_back(argTrans); + OpBuilder builder(forOp); + OpOperand &init = forOp.getInitsMutable()[operand.getOperandNumber()]; + Value initTranspose = builder.create( + forOp.getLoc(), init.get(), ArrayRef({1, 0})); + init.set(initTranspose); + } + } + } + } +} + +// Transpose scaled_dot ops that have a scale on lhs. +static Operation *transposeDotOp(DotScaledOp dotOp) { + OpBuilder builder(dotOp); + Value lhs = dotOp.getLhs(); + std::array transOrder = {1, 0}; + Value lhsTransposed = builder.create(lhs.getLoc(), lhs, transOrder); + Value rhs = dotOp.getRhs(); + Value rhsTransposed = builder.create(rhs.getLoc(), rhs, transOrder); + Value c = dotOp.getC(); + Value cTransposed = builder.create(c.getLoc(), c, transOrder); + Value result = builder.create( + dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed, + cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(), + dotOp.getLhsType()); + Operation *transposedResult = + builder.create(result.getLoc(), result, transOrder); + dotOp.replaceAllUsesWith(transposedResult); + dotOp.erase(); + return transposedResult; +} + +static void transposeDots(ModuleOp m) { + // TODO: extend to regular dot when it is profitable. For instance when we may + // want to use rhs from register for mmav3. + SmallVector toTranspose; + m.walk([&](DotScaledOp dotOp) -> void { + if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr) + toTranspose.push_back(dotOp); + }); + SmallVector transposes; + for (DotScaledOp dotOp : toTranspose) { + Operation *transpose = transposeDotOp(dotOp); + transposes.push_back(transpose); + } + + for (Operation *transpose : transposes) { + sinkTransposeOp(cast(transpose)); + } +} + #define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -547,10 +791,11 @@ class TritonGPUAccelerateMatmulPass ModuleOp m = getOperation(); auto computeCapability = getNVIDIAComputeCapability(m); + transposeDots(m); mlir::RewritePatternSet patterns(context); - patterns.add(context, - computeCapability); + patterns.add(context, + computeCapability); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 99e2ac3c9660..da176b0fd1a8 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -3,13 +3,17 @@ add_triton_library(TritonGPUTransforms Coalesce.cpp F32DotTC.cpp CombineTensorSelectAndIf.cpp + LoopScheduling.cpp ReduceDataDuplication.cpp OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp + Pipeliner/AssignLatencies.cpp Pipeliner/MatmulLoopPipeline.cpp Pipeliner/OuterLoopPipeline.cpp Pipeliner/PipelineExpander.cpp + Pipeliner/TestPipelineAssignLatencies.cpp + Pipeliner/TestPipelineScheduleLoop.cpp Pipeliner/SoftwarePipeliner.cpp Pipeliner/TMAStoresPipeline.cpp Pipeliner/PipeliningUtility.cpp @@ -17,6 +21,7 @@ add_triton_library(TritonGPUTransforms Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp + CoalesceAsyncCopy.cpp Utility.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp new file mode 100644 index 000000000000..2d634fc6fa7b --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp @@ -0,0 +1,124 @@ +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass currently only applies if the following are all true... +// 1) Operand A for WGMMA is to be loaded in registers +// 2) We upcast operand A in registers before the WGMMA +// (downcasting is not yet supported) +// 3) Pipelining is enabled for loading A +// +// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding +// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if +// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread +// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two +// 8-byte-cp.async's for each contiguous 16B global data owned by each +// thread. This breaks coalescing (i.e. results 2x the minimum required +// transactions). +// +// This issue occurs for cp.async because it combines load and store into one +// instruction. The fix is to clip each dim of sizePerThread by shared vec, so +// that the vectorization of load and store are equal along the contiguous +// dimension. In the above example, each thread will then only own 8B contiguous +// global data. +struct ClipAsyncCopySizePerThread + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp, + PatternRewriter &rewriter) const override { + Value src = copyOp.getSrc(); + Value mask = copyOp.getMask(); + Value other = copyOp.getOther(); + auto srcTy = cast(src.getType()); + auto dstTy = cast(copyOp.getResult().getType()); + auto blockEnc = dyn_cast(srcTy.getEncoding()); + if (!blockEnc) + return rewriter.notifyMatchFailure(copyOp, + "src must be of blocked encoding"); + auto sharedEnc = cast(dstTy.getEncoding()); + auto sharedVec = sharedEnc.getVec(); + + // obtain max contiguous copy size + // Note this can be further optimized, as copyContigSize can be even + // smaller when lowering, depending on contiguity and mask alignment + // (see AsyncCopyGlobalToLocalOpConversion) + auto elemBitWidth = dstTy.getElementTypeBitWidth(); + auto regToSharedLayout = + getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc, + sharedEnc, elemBitWidth); + auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut(); + + // obtain block sizePerThread along contig dim + auto sizePerThread = blockEnc.getSizePerThread(); + auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]]; + + if (blockContigSize <= copyContigSize) + return rewriter.notifyMatchFailure( + copyOp, + "blocked sizePerThread along contiguous dim must be greater than the " + "max contiguous copy size "); + + sizePerThread[blockEnc.getOrder()[0]] = copyContigSize; + + // obtain new blockedEnc based on clipped sizePerThread + auto mod = copyOp->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto newBlockEnc = BlockedEncodingAttr::get( + copyOp.getContext(), srcTy.getShape(), sizePerThread, + blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout()); + + // insert cvt's after src, mask, and other + auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = rewriter.create(copyOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src, newBlockEnc); + if (mask) + mask = convertBlockLayout(mask, newBlockEnc); + if (other) + other = convertBlockLayout(other, newBlockEnc); + + rewriter.modifyOpInPlace(copyOp, [&]() { + copyOp.getSrcMutable().assign(src); + if (mask) + copyOp.getMaskMutable().assign(mask); + if (other) + copyOp.getOtherMutable().assign(other); + }); + + return success(); + } +}; + +class CoalesceAsyncCopyPass + : public impl::TritonGPUCoalesceAsyncCopyBase { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + MLIRContext *context = &getContext(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp index 16183b1af46e..203fe01ba626 100644 --- a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -1,4 +1,5 @@ #include "mlir/IR/Dominance.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -14,8 +15,52 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -// Return true if the select could be merged into the If without breaking SSA -// rules. +/// The user of select maybe inside either the ThenRegion or ElseRegion of +/// the scf.if. So, canonicalize user of select in scf.if first. +static void canonicalizeSelectUsersInSCFIf(ModuleOp input) { + llvm::MapVector, SmallVector> + usersNeedreplaced; + input.walk([&](arith::SelectOp selectOp) { + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + Value trueVal = selectOp.getOperand(1); + Value falseVal = selectOp.getOperand(2); + Value resVal = selectOp.getResult(); + for (auto *condUser : condition.getUsers()) { + if (!llvm::isa(condUser)) + continue; + scf::IfOp ifOp = llvm::cast(condUser); + for (auto *resUser : resVal.getUsers()) { + if (ifOp->isProperAncestor(resUser)) { + if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) != + nullptr) { + // The user is inside the ThenRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back( + resUser); + } else { + // The user is inside the ElseRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back( + resUser); + } + } + } + } + }); + + // Replace the operand of user. + for (auto [replacedSrcAndDst, users] : + llvm::make_early_inc_range(usersNeedreplaced)) { + Value srcVal = replacedSrcAndDst.first; + Value dstVal = replacedSrcAndDst.second; + for (Operation *user : llvm::make_early_inc_range(users)) { + srcVal.replaceUsesWithIf( + dstVal, [&](OpOperand &use) { return use.getOwner() == user; }); + } + } +} + +/// Return true if the select could be merged into the If without breaking SSA +/// rules. static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, DominanceInfo &dom) { // If needs to be dominated by the select. @@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - DominanceInfo dom(m); + canonicalizeSelectUsersInSCFIf(m); // Go over the arith.select ops, look if there is an if // with the same condition. + DominanceInfo dom(m); llvm::MapVector> selectToIf; m.walk([&](arith::SelectOp selectOp) { // Look if there is an if in the same block, with the same condition. diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp new file mode 100644 index 000000000000..e15b43960031 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -0,0 +1,345 @@ +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-schedule" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPULOOPSCHEDULING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +bool hasLatenciesAssigned(scf::ForOp forOp, + const DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + return true; + } + return false; +} + +CoarseSchedule scheduleKeyOps(scf::ForOp forOp, + const DenseMap &opLatency) { + llvm::MapVector opToStage; + // Find terminator for later reference + auto terminator = cast(forOp.getBody()->getTerminator()); + // Determine all operations that have a non-zero latency + SmallVector latOps; + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + latOps.push_back(&op); + } + // If no latency ops, nothing to schedule + if (latOps.empty()) + return CoarseSchedule(0); + + // Compute the longest path to the yield for each operation reachable + // from any latency operation. + DenseMap distance; + std::function computeDistance = [&](Operation *op) -> int { + auto it = distance.find(op); + if (it != distance.end()) + return it->second; + // Compute max distance among all users that are inside the loop body + int maxDist = -1; + for (Operation *user : op->getUsers()) { + // Only consider users inside the same block and not the terminator + Operation *inBlockUser = forOp.getBody()->findAncestorOpInBlock(*user); + if (!inBlockUser || inBlockUser == terminator) + continue; + int distUser = computeDistance(inBlockUser); + if (distUser > maxDist) + maxDist = distUser; + } + int lat = 0; + if (opLatency.count(op)) + lat = opLatency.lookup(op); + // If an op has no users (maxDist == -1) but has latency, we include its + // latency otherwise it contributes 0 to the distance. + int d = lat + (maxDist < 0 ? 0 : maxDist); + distance[op] = d; + return d; + }; + + // Compute distances for all latency-starting ops + int maxDistance = 0; + for (Operation *latOp : latOps) { + int d = computeDistance(latOp); + if (d > maxDistance) + maxDistance = d; + } + + // Assign stage to each op reachable from a latency op + for (auto &kv : distance) { + Operation *op = kv.first; + int dist = kv.second; + // We only schedule ops that are downstream of a latency op + // (had a non-negative distance due to a latency op). + if (dist >= 0) + opToStage[op] = maxDistance - dist; + } + + auto stages = llvm::make_second_range(opToStage); + int maxStage = *std::max_element(stages.begin(), stages.end()); + CoarseSchedule schedule(maxStage + 1); + SmallVector clusters(maxStage + 1); + for (int i = 0; i <= maxStage; i++) { + clusters[i] = schedule.clusters.newAtBack(); + } + CoarseSchedule::Cluster epilogue = schedule.clusters.newAtBack(); + // Assign ops to the clusters in reverse-stage order; + // ops with higher stage numbers are assigned first. This way we will + // end up with roughly reverse program order in the clusters. + for (auto [op, stage] : opToStage) { + if (isa(op)) { + schedule.insert(op, stage, epilogue); + continue; + } + schedule.insert(op, stage, clusters[maxStage - stage]); + } + + return schedule; +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.numStages; + auto getNestedOperands = [](Operation *op) -> SmallVector { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + true); + } + } + } + } + } + } +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.numStages; + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + if (!ifsToStage.empty()) { + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + } + + // Other IfOps should be pushed to the end. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto ifOp = dyn_cast(op)) { + if (ifsToStage.count(ifOp) == 0) { + schedule.insertIfAbsent(ifOp, numStages - 1, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue) { + int numStages = schedule.numStages; + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster; + if (schedule.count(op)) + opCluster = schedule[op].second; + else + opCluster = opToCluster[op]; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +}; // namespace + +void scheduleLoop(scf::ForOp forOp, + const DenseMap &opLatency) { + if (!hasLatenciesAssigned(forOp, opLatency)) + return; + // Based on the latencies, schedule the key ops to the stages. + CoarseSchedule schedule = scheduleKeyOps(forOp, opLatency); + if (schedule.empty()) + return; + LLVM_DEBUG({ + LDBG("Initial coarse schedule:"); + schedule.dump(); + }); + // Schedule the dependencies + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + schedule.dump(); + }); + scheduleDependencies(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + scheduleDistanceOneDependencies(forOp, schedule); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + scheduleRemainingToLastStage(forOp, schedule, afterPrologue); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Write the schedule to the IR + schedule.serialize(forOp); +} + +class TritonGPULoopSchedulingPass + : public impl::TritonGPULoopSchedulingBase { +public: + using impl::TritonGPULoopSchedulingBase< + TritonGPULoopSchedulingPass>::TritonGPULoopSchedulingBase; + + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + void runOnOperation() override { + // Go over the interesting ops and assign latencies (based on the + // numStages) to the them, trying to populate the allowed stages. This + // step will be at some point extracted to separate pass that will be run + // only for loops missing the latency information. + DenseMap opLatency = + assignLatencies(getOperation(), numStages); + // numStages should not be used below this point. We should know everything + // based on the assigned stages + + // Schedule the loops + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + if (loops.empty()) + return; + + for (auto forOp : loops) { + scheduleLoop(forOp, opLatency); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp index dd9b4ad139f5..96dc5112f1b4 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -38,13 +38,7 @@ void setUseAccFlag(Operation *op, Value useAcc) { } bool isConstantZeroTensor(Value v) { - auto constOp = v.getDefiningOp(); - if (!constOp) - return false; - auto splat = mlir::dyn_cast(constOp.getValue()); - if (!splat) - return false; - return splat.getSplatValue().getValue().convertToFloat() == 0.0f; + return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat())); } std::optional> findZeroInitOp(Value accUse, @@ -65,6 +59,8 @@ std::optional> findZeroInitOp(Value accUse, return std::nullopt; } if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; if (isConstantZeroTensor(selOp.getTrueValue()) || isConstantZeroTensor(selOp.getFalseValue())) { return std::make_pair(selOp, 0); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..01e8acf25842 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -1,9 +1,11 @@ +#include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -15,6 +17,125 @@ namespace gpu { namespace { +// Helpers + +// Returns whether we can hoist DotOp Encoding through `op`. +// Roughly, whether op is elementwise and thus threads don't need +// to exchange elements. But some ops are not currently supported even though +// they meet that criterion. +bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) { + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. + if (isa(op) && dotOpEnc.getKWidth() == 1) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op)) { + Type opType = getElementTypeOrSelf(op->getOperand(0)); + if (opType.isInteger(1)) + return false; + } + + return true; +} + +// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A +// is in registers). +bool canHoistDotOpEncV3(Operation *op) { + // Must have exactly one result and at least one operand + if (op->getNumOperands() == 0 || op->getNumResults() != 1) + return false; + + auto isBlockedOrDotOpRankedTensor = [](Type ty) { + auto tensorTy = dyn_cast(ty); + if (!tensorTy) + return false; + return isa( + tensorTy.getEncoding()); + }; + + // Operands and results must be of RankedTensorType and Blocked or DotOp + if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) && + all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor))) + return false; + + // Only consider custom conversions or arith ops. + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Downcasting not currently supported; it will likely require minor + // adjustments in sharedToDotOperandMMv2 + auto oprType = getElementTypeOrSelf(op->getOperand(0)); + auto resType = getElementTypeOrSelf(op->getResult(0)); + if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth()) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op) && oprType.isInteger(1)) + return false; + + return true; +} + +// Helper to perform a "deep" clone of the given slice (i.e., set of ops), +// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice, +// and sliceMap the IRMapping that maps the ops and result values of the +// original slice to those in the cloned slice. +auto cloneSlice(PatternRewriter &rewriter, + const SetVector &slice) { + IRMapping sliceMap; + SetVector newSlice; + + // First pass: clone ops; the result values are cloned as well, but the + // operands still refer to the original result values + for (Operation *op : slice) { + rewriter.setInsertionPoint(op); + auto newOp = rewriter.clone(*op); + newSlice.insert(newOp); + sliceMap.map(op, newOp); + for (auto [result, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + assert(result != newResult); + sliceMap.map(result, newResult); + } + } + + // Second pass: replace operand references in cloned ops to point to cloned + // values + for (auto [op, newOp] : sliceMap.getOperationMap()) + for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) + continue; + + newOp->setOperand(oprIdx, sliceMap.lookup(operand)); + } + + return std::make_tuple(newSlice, sliceMap); +} + // Given // convert(trans(src)) #dot_operand -> // convert(local_load(trans(alloc(src)))) @@ -31,12 +152,12 @@ class SwizzleShmemConvert : public OpRewritePattern { if (!trans || trans.getOrder() != ArrayRef{1, 0}) return failure(); - auto srcTy = dyn_cast(trans.getSrc().getType()); + RankedTensorType srcTy = trans.getSrc().getType(); if (auto srcCvt = trans.getSrc().getDefiningOp()) { srcTy = srcCvt.getSrc().getType(); } - auto sharedLoadTy = cast(cvtOp.getType()); + RankedTensorType sharedLoadTy = cvtOp.getType(); auto cvtEncoding = dyn_cast(sharedLoadTy.getEncoding()); if (!cvtEncoding) @@ -49,9 +170,9 @@ class SwizzleShmemConvert : public OpRewritePattern { // Set needTrans to true here. newInnerCvtEnc is computed based on // argEncoding which is before the transpose. Without needTrans we will // compute vec and maxPhase based on incorrect m, n and k size of mma. The - // type inference of TransOp simply swap the order but doesn't fix the vec - // and maxPhase for the YType, hence it would causing incorrect swizzling - // code. + // type inference of MemDescTransOp simply swap the order but doesn't fix + // the vec and maxPhase for the YType, hence it would causing incorrect + // swizzling code. auto newInnerCvtEnc = SharedEncodingAttr::get(getContext(), cvtEncoding, srcTy.getShape(), /*order=*/getOrder(srcTy.getEncoding()), @@ -66,8 +187,8 @@ class SwizzleShmemConvert : public OpRewritePattern { MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerCvtEnc, sharedMemorySpace), trans.getSrc()); - auto newTrans = rewriter.create(trans.getLoc(), alloc, - ArrayRef({1, 0})); + auto newTrans = rewriter.create(trans.getLoc(), alloc, + ArrayRef({1, 0})); rewriter.replaceOpWithNewOp(trans, sharedLoadTy, newTrans); return success(); } @@ -111,7 +232,8 @@ class HoistLayoutConversion : public OpRewritePattern { PatternRewriter &rewriter) const override { // Only consider conversions to dot operand. auto cvtTy = cast(cvt.getType()); - if (!isa(cvtTy.getEncoding())) + auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); + if (!dotOpEnc) return failure(); auto src = cvt.getSrc().getDefiningOp(); @@ -126,16 +248,7 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(src) && !isPureUnaryInlineAsm(src) && - src->getDialect()->getTypeID() != TypeID::get()) - return failure(); - - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(src)) + if (!canHoistDotOpEncV2(src, dotOpEnc)) return failure(); // Check that the conversion is transitively dependent on a load, and all @@ -165,12 +278,7 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(currOp)) { foundLoad = true; } else if (foundLoad) { - // Bail out if there exists an op after Load that is not FpToFp, - // Bitcast, or Arith. - if (!isa(currOp) && - !isPureUnaryInlineAsm(currOp) && - currOp->getDialect()->getTypeID() != - TypeID::get()) + if (!canHoistDotOpEncV2(currOp, dotOpEnc)) return failure(); } } @@ -224,7 +332,7 @@ class FuseTransHopper : public OpRewritePattern { MemDescType allocType = allocOp.getType(); auto allocEncoding = cast(allocType.getEncoding()); - TensorOrMemDesc srcTy = trans.getSrc().getType(); + RankedTensorType srcTy = trans.getSrc().getType(); // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 // without transpose for other data types.) @@ -253,8 +361,8 @@ class FuseTransHopper : public OpRewritePattern { allocType.getMemorySpace()); auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, trans.getSrc()); - rewriter.replaceOpWithNewOp(allocOp, newAlloc, - ArrayRef({1, 0})); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); return success(); } }; @@ -286,11 +394,12 @@ struct MMAV3UseRegOperand dstEnc.getVersionMajor() != 3) return failure(); auto srcTy = cast(alloc.getSrc().getType()); + auto kWidth = 32 / srcTy.getElementTypeBitWidth(); auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!isMmaToDotShortcut(srcTy, newTy)) + if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) return failure(); Value newOperand = @@ -300,6 +409,150 @@ struct MMAV3UseRegOperand } }; +// MMAV3's analog of HoistLayoutConversion, for operand A only; will make +// WarpGroupDot accept operand A in registers instead of shmem. +// +// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot +// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; +// warp_group_dot +// +// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a +// time and requires multiple passes, this pattern will directly hoist the +// convert to the right place in one pass. +// +// Or, to be more precise, this pattern deletes the local_alloc op and inserts a +// convert_layout op after each load that warp_group_dot uses; so this is not +// simply hoisting a convert_layout op up as in V2, but can be considered as +// first changing local_alloc to convert_layout and then hoisting, which results +// in WGMMA now accepting operand A in DotOp layout rather than Shared. +struct MMAV3HoistLayoutConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, + PatternRewriter &rewriter) const override { + // Can only hoist operand 0 + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return rewriter.notifyMatchFailure( + dotOp, "operand A must be produced by local_alloc"); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return rewriter.notifyMatchFailure( + dotOp, "requires Shared encoding for operand A"); + + // Step 1: Performs checks for early stop + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + if (!srcEnc) + return rewriter.notifyMatchFailure( + alloc, "requires src to have Blocked encoding"); + + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 3) + return rewriter.notifyMatchFailure( + dotOp, "requires result in NvidiaMma encoding"); + + // Step 2: Obtain slice of ops between load/constant and local_alloc + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *op) { + // Stop before Load, ConstantOp, or LocalLoad + return (op->getParentRegion() == alloc->getParentRegion()) && + !isa(op) && + (op->getNumOperands() != 0); + }; + getBackwardSlice(alloc.getOperation(), &slice, opt); + + // Step 3: Verify slice can be hoisted through + if (slice.empty()) + return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through"); + + // We define frontierOp as an op outside this slice whose result is used by + // an op in this slice. We must eventually convert the result of all + // frontierOps to DotOperandEncoding. This is done via the insertion of + // ConvertLayout after each frontierOp. We currently support frontierOp to + // be load or constant. + for (Operation *currOp : slice) { + if (!canHoistDotOpEncV3(currOp)) + return rewriter.notifyMatchFailure(currOp, "cannot hoist through"); + + // We previously ensured that all ops in slice have at least one operand + for (auto operand : currOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) { + // ensure frontierOp is load or constant + if (!isa(defOp)) + return rewriter.notifyMatchFailure(defOp, + "must be load or constant"); + } + } + } + + // Step 4: Clone slice + auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); + + // Step 5: Modify the cloned slice to have dotOp encoding. + // Before: load #blocked; (elementwise #blocked)+; local_alloc; + // warp_group_dot After: load #blocked; convert_layout #dot_op; + // (elementwise #dot_op)+; warp_group_dot + // + // Specifically, this step will change all value types from #blocked to + // #dot_op encoding in the cloned slice, and for those values produced by + // frontierOps (i.e., outside the slice), we will insert convert_layout's + // after the frontierOp. + auto srcTy = cast(alloc.getSrc().getType()); + Type inputEltTy = srcTy.getElementType(); + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); + + for (auto op : newSlice) { + // Step 5a: If any operand is defined by a frontierOp, we must insert a + // convert_layout(#dot_op) after the frontierOp and before currOp + for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) { + + auto defOp = operand.getDefiningOp(); + + // defOp is not frontier (i.e. it's within slice); no need to convert + // the layout of its result + if (newSlice.contains(defOp)) + continue; + + // We checked earlier that all operands are ranked tensors + auto operandTy = cast(operand.getType()); + auto operandEltTy = operandTy.getElementType(); + + Type cvtTy = RankedTensorType::get( + operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + rewriter.setInsertionPoint(op); + auto cvt = + rewriter.create(defOp->getLoc(), cvtTy, operand); + + op->setOperand(oprIdx, cvt); + } + + // Step 5b: Change the result to have DotOp rather than Blocked encoding + auto resTy = cast(op->getResult(0).getType()); + op->getResult(0).setType(RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), dotOperandEnc)); + } + + // Step 6: replace LHS operand with alloc's parent in the cloned slice + // This changes the warpGroupDot to accept a DotOp tensor as operand A + // instead of a Shared memdesc. + auto newDotOperand = sliceMap.lookup(alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, + [&]() { dotOp.setOperand(0, newDotOperand); }); + + return success(); + } +}; + } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -321,6 +574,7 @@ class TritonGPUOptimizeDotOperandsPass auto ret = pm.run(m); mlir::RewritePatternSet patterns(context); + patterns.add(context); patterns.add(context); if (this->hoistLayoutConversion.getValue()) patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp new file mode 100644 index 000000000000..f274363730c4 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -0,0 +1,252 @@ +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-pipeline-schedule" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Return true if the preconditions for pipelining the loop are met. +bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + return true; +} + +bool canHaveSharedEncoding(tt::LoadOp op) { + // If used by an user with DotOp encoding, all the uses must be compatible. + bool incompatible = false; + getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); + if (incompatible) + return false; + // If the load is used by a LocalAllocOp, all the users need to have the same + // encoding. + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return false; + } + return true; + } + return true; +} + +bool isSmallLoad(tt::LoadOp loadOp, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return true; + auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + return width < 32; +} + +bool isPipeliningBeneficial(Operation *op, Operation *finalUser, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + if (auto loadOp = dyn_cast(op)) { + if (isSmallLoad(loadOp, axisInfoAnalysis)) { + LDBG("Load " << *loadOp << " is too small for pipelining"); + return false; + } + } + if (isa(op)) + return true; + if (isa(finalUser) && + getMMALoadType(op) == MMALoadType::DoNotPipeline) { + LDBG("Load " << *op << " used by WarpGroupDotOp with incompatible layout"); + return false; + } + if (!canHaveSharedEncoding(cast(op))) { + LDBG("Load " << *op << " cannot have shared encoding"); + return false; + } + + return true; +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +llvm::MapVector +loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadOpToIndLevel; + DenseSet seen; + DenseSet excluded; + + std::function dfs = + [&](Operation *op, Operation *finalUser, int distance) { + if (!seen.insert(op).second || excluded.count(op)) + return; + if (isa(op)) { + if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis)) + return; + if (loadOpToIndLevel.count(op)) { + int level = loadOpToIndLevel[op]; + if (level != distance) { + // If we have multiple uses at different distances, we don't know + // which one to pick. + LDBG("Load " << *op + << " has multiple uses at different distances:" + << level << " and " << distance); + loadOpToIndLevel.erase(op); + excluded.insert(op); + return; + } + } else { + LDBG("Load " << *op << " considered for pipelining with distance " + << distance); + loadOpToIndLevel[op] = distance; + } + finalUser = op; + distance++; + } + for (Value operand : op->getOperands()) { + if (op->hasTrait()) { + // Heuristic: only pipeline A and B operands of the dot op. + if (operand == op->getOperand(2)) + continue; + } + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, finalUser, distance); + } + } + }; + + bool seenDot = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seenDot = true; + seen.clear(); + dfs(&op, &op, 0); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (pipelineWithoutDot && !seenDot) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, &op, 0); + } + } + + return loadOpToIndLevel; +} + +} // namespace + +// Look for load ops that directly or indirectly feed into dot ops. Based +// on the requested number of stages assign the latencies in a way that +// cover all the stages with the sum of latencies in the chain from the first +// load to the final dot op. +DenseMap assignLatencies(ModuleOp moduleOp, + int defaultNumStages) { + auto getNumStagesOrDefault = [defaultNumStages](scf::ForOp forOp) -> int { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return defaultNumStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + }; + + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (preCondition(forOp) && getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + if (loops.empty()) + return DenseMap(); + + DenseMap opLatency; + for (auto forOp : loops) { + int numStages = getNumStagesOrDefault(forOp); + bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + llvm::MapVector loadOpToIndLevel = + loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis); + if (loadOpToIndLevel.empty()) + continue; + + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + for (auto iter = loadOpToIndLevel.begin(); + iter != loadOpToIndLevel.end();) { + if (iter->second >= numStages - 1) + iter = loadOpToIndLevel.erase(iter); + else + ++iter; + } + + // Calculate the stage distance between applicable loads. + auto vals = llvm::make_second_range(loadOpToIndLevel); + int maxIndirectionLevel = + vals.empty() ? 0 : *std::max_element(vals.begin(), vals.end()); + unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1); + + for (auto [loadOp, dist] : loadOpToIndLevel) { + opLatency[loadOp] = loadLatency; + } + } + return opLatency; +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index e946735e2374..f0fe8d43f438 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -34,94 +34,144 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; -// TODO: We can extra some helpers into common utilities once we add more +// TODO: We can extract some helpers into common utilities once we add more // schedules. namespace { struct LoadInfo { - // Layout of the data in the shared memory. + // Layout of the data in shared memory. ttg::SharedEncodingAttr sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. ttg::BlockedEncodingAttr blockedEncoding = nullptr; - bool loadIsMMAV3 = false; + bool isMMAv3Shared = false; + bool isMMAv3Registers = false; int distToUse = 0; bool usedByDot = false; }; } // namespace -static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, - Value insertIdx, Value extractIdx, - tt::CoarseSchedule &schedule, - tt::CoarseSchedule::Cluster prefetchCluster, - llvm::MapVector &loadToInfo, - int numStages) { - OpBuilder builder(forOp); - Value zero = builder.create(forOp.getLoc(), 0, 32); +class OpBuilderWithStage : public OpBuilder { +public: + explicit OpBuilderWithStage(Operation *op, + OpBuilder::Listener *listener = nullptr) + : OpBuilder(op, listener) {} + explicit OpBuilderWithStage(Region ®ion, Listener *listener = nullptr) + : OpBuilder(region, listener) {} + + template + OpTy createWithStage(Location location, int stage, int cluster, + Args &&...args) { + OpTy op = OpBuilder::create(location, std::forward(args)...); + tt::setStageCluster(op, stage, cluster); + return op; + } + using OpBuilder::create; +}; + +static bool sameStageCluster(Operation *op1, Operation *op2) { + auto [s1, c1] = tt::getStageCluster(op1); + auto [s2, c2] = tt::getStageCluster(op2); + return s1 == s2 && c1 == c2; +} + +// Return user of a loadOp with the lowest stage, if two users have the +// same stage, return the user with lower cluster. +static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) { + Operation *firstUser = nullptr; + for (Operation *user : loadOp->getUsers()) { + if (user->getBlock() == loadOp->getBlock()) { + auto [stage, clusterId] = tt::getStageCluster(user); + // Update FirstUse if this use has lower stage or lower cluster. + if (!firstUser) + firstUser = user; + else { + auto [stageForFirstUse, clusterForFirstUse] = + tt::getStageCluster(firstUser); + if (stage < stageForFirstUse || + (stage == stageForFirstUse && clusterId < clusterForFirstUse)) + firstUser = user; + } + } + } + return firstUser; +} + +static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + llvm::MapVector &loadToInfo, + int numStages, int maxClusterId) { + int retCode = -1; + OpBuilderWithStage builder(forOp); + auto opPair = tt::getStageCluster(loadOp); + auto *firstUse = getFirstUseOfPipelinedLoad(loadOp); + auto [stageForFirstUse, clusterForFirstUse] = tt::getStageCluster(firstUse); + int stage = opPair.first, clusterId = opPair.second; + + Value zero = builder.createWithStage( + forOp.getLoc(), stage, clusterId, 0, 32); // Replace the load with insert/extract slice. builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); Value other = loadOp.getOther(); + ttg::MemDescType allocTy = cast(alloc.getType()); + + auto convertBlockLayout = [&](Value src, ttg::BlockedEncodingAttr enc) { + auto ty = cast(src.getType()); + auto newTy = RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = builder.createWithStage( + loadOp->getLoc(), stage, clusterId, newTy, src); + return cvt.getResult(); + }; + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { // For inexpensive loads that do not directly feed into dot ops // we want to use optimal layout for the data. ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; - auto convertBlockLayout = [&](Value src) { - auto ty = cast(src.getType()); - auto newTy = - RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); - auto cvt = - builder.create(loadOp->getLoc(), newTy, src); - return cvt.getResult(); - }; - src = convertBlockLayout(src); + src = convertBlockLayout(src, encoding); if (mask) - mask = convertBlockLayout(mask); + mask = convertBlockLayout(mask, encoding); if (other) - other = convertBlockLayout(other); + other = convertBlockLayout(other, encoding); } - tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); - tt::MemDescType subviewTy = tt::MemDescType::get( + ttg::MemDescType subviewTy = ttg::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); - auto view = - builder.create(loc, subviewTy, alloc, copyOffsets); - Operation *copy = builder.create( - loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), - loadOp.getIsVolatile()); - Operation *commmit = - builder.create(loc, copy->getResult(0)); - Operation *wait = - builder.create(loc, commmit->getResult(0), 0); - - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(copy, stage, cluster); - schedule.insert(commmit, stage, cluster); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); + auto view = builder.createWithStage( + loc, stage, clusterId, subviewTy, alloc, copyOffsets); + Operation *copy = builder.createWithStage( + loc, stage, clusterId, src, view, mask, other, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + Operation *commmit = builder.createWithStage( + loc, stage, clusterId, copy->getResult(0)); + Operation *wait = builder.createWithStage( + loc, stageForFirstUse, clusterForFirstUse, commmit->getResult(0), 0); + + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; // Extract part. SmallVector loadOffsets(allocTy.getRank(), zero); loadOffsets[0] = extractIdx; - auto viewLoad = - builder.create(loc, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + auto viewLoad = builder.createWithStage( + loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets); + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); - replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); } else { SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { if (auto alloc = dyn_cast(user)) { - replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); allocsToErase.push_back(alloc); } } @@ -129,16 +179,20 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, alloc.erase(); } - auto sharedLoad = builder.create( - loc, loadOp.getType(), viewLoad, wait->getResult(0)); + auto sharedLoad = builder.createWithStage( + loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), viewLoad, + wait->getResult(0)); auto result = sharedLoad->getResults(); // Create a select for non-zero other values as they are not handled by // AsyncCopyGlobalToLocalOp for now. Value other = loadOp.getOther(); if (other && !isZeroConst(other)) { - auto select = builder.create( - loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + auto select = builder.createWithStage( + loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), + // Use the mask operand from the original load, not the one with a + // potentially transformed layout. + loadOp.getMask(), sharedLoad.getResult(), other); result = select->getResults(); } @@ -146,58 +200,69 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, // Prefetch load if is not MMAV3 and is used by the dot. if (loadToInfo[loadOp].usedByDot) { - schedule.insert(wait, numStages - 2, prefetchCluster); - schedule.insert(viewLoad, numStages - 2, prefetchCluster); + assert(stageForFirstUse >= 1); + tt::setStageCluster(wait, stageForFirstUse - 1, maxClusterId + 1); + tt::setStageCluster(viewLoad, stageForFirstUse - 1, maxClusterId + 1); + retCode = stageForFirstUse - 1; } } loadOp.erase(); + return retCode; } -static void createTMAAsyncCopy( - scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, - Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, - Value phase, tt::CoarseSchedule &schedule, - llvm::MapVector &loadToInfo, int numStages) { +static void +createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, + Value alloc, Value insertIdx, Value extractIdx, + Value barrier, Operation *waitOp, Value phase, + llvm::MapVector &loadToInfo, + int numStages) { assert(phase && "Phase value is required for TMA async copy."); - OpBuilder builder(forOp); + OpBuilderWithStage builder(forOp); + auto [stage, clusterId] = tt::getStageCluster(loadOp); + auto *firstUse = getFirstUseOfPipelinedLoad(loadOp); + auto [stageForFirstUse, clusterForFirstUse] = tt::getStageCluster(firstUse); + Attribute sharedMemorySpace = - triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); - Value zero = builder.create(forOp.getLoc(), 0, 32); + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + Value zero = builder.createWithStage( + forOp.getLoc(), stage, clusterId, 0, 32); builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); - tt::MemDescType allocTy = cast(alloc.getType()); + ttg::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; - tt::MemDescType subviewTy = tt::MemDescType::get( + ttg::MemDescType subviewTy = ttg::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); - auto view = - builder.create(loc, subviewTy, alloc, copyOffsets); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); + auto view = builder.createWithStage( + loc, stage, clusterId, subviewTy, alloc, copyOffsets); - Value pred = builder.create(loc, 1, 1); - Operation *copy = builder.create( - loc, loadOp.getDescPtr(), loadOp.getIndices(), barrier, view, pred); + Value pred = builder.createWithStage(loc, stage, + clusterId, 1, 1); + Value tmaPtr = + builder.createWithStage( + loc, stage, clusterId, loadOp.getDesc()); + Operation *copy = builder.createWithStage( + loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred); - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; - auto [stage, cluster] = schedule[loadOp]; - schedule.erase(loadOp); - schedule.insert(copy, stage, cluster); + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; builder.setInsertionPointAfter(waitOp); // Extract part. SmallVector loadOffsets(allocTy.getRank(), zero); loadOffsets[0] = extractIdx; - auto viewLoad = - builder.create(loc, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + auto viewLoad = builder.createWithStage( + loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets); + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); - replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); } else { SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { if (auto alloc = dyn_cast(user)) { - replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); allocsToErase.push_back(alloc); } } @@ -205,60 +270,16 @@ static void createTMAAsyncCopy( alloc.erase(); } - auto sharedLoad = builder.create( - loc, loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + builder.setInsertionPointAfter(viewLoad); + auto sharedLoad = builder.createWithStage( + loc, stage, clusterId, loadOp.getType(), + viewLoad /*,wait->getResult(0)*/); auto result = sharedLoad->getResults(); loadOp->replaceAllUsesWith(result); } loadOp.erase(); } -// If all the transitive uses of the given value have are used by a convert to -// the same dot operand encoding, return the shared encoding that needs to be -// used to be compatible with users' layouts. If there are imcompatible shared -// encodings, raise assertion, since incompatible shared encoding has been -// handled in splitLoadsForIncompatible. -static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { - ttg::SharedEncodingAttr attr; - incompatible = false; - for (Operation *user : val.getUsers()) { - ttg::SharedEncodingAttr tempAttr; - if (user->getNumResults() != 1) - return std::nullopt; - if (auto memDesc = - dyn_cast(user->getResult(0).getType())) { - // First time we find a shared encoding in the chain, save it and try to - // use it if it is compatible with the other users. - tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) - .has_value()) - return std::nullopt; - } else { - if (!isa(user)) - return std::nullopt; - auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()).getEncoding()); - if (!dotOpEnc) - return std::nullopt; - auto srcTy = cast(val.getType()); - auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, - bitWidth, /*needTrans=*/false); - } - // Check that the shared encodings needed by the users are compatible. - if (attr != nullptr && attr != tempAttr) { - incompatible = true; - return std::nullopt; - } - attr = tempAttr; - } - return attr; -} - static ttg::BlockedEncodingAttr getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { Value src = loadOp.getPtr(); @@ -279,7 +300,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { } static std::optional -getSharedEncoding(Operation *loadOp, bool isMMAV3) { +getSharedEncoding(Operation *loadOp, bool isMMAV3Shared) { auto ty = cast(loadOp->getResultTypes()[0]); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); auto blockedOrder = ttg::getOrder(ty.getEncoding()); @@ -294,7 +315,7 @@ getSharedEncoding(Operation *loadOp, bool isMMAV3) { } else { order = blockedOrder; } - if (isMMAV3) { + if (isMMAV3Shared) { return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } @@ -326,450 +347,175 @@ getSharedEncoding(Operation *loadOp, bool isMMAV3) { ctaLayout); } -// Create a map from load ops to their indirection level and the -// final use of the load op (another load op, or a dot op). -// Indirection level is "0" for the load op directly used by the dot op, -// "1" for the load op used by the load op used by the dot op, and so on. -static llvm::SmallVector> -loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { - llvm::SmallVector> - loadOpToIndLevelAndUse; +static bool hasSharedEncodingHelper(Operation *loadOp) { + // If the load is used by a LocalAllocOp, use the same encoding as the allocs. + // If the allocs don't all have the same encoding, bail. + if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : loadOp->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return false; + } + return true; + } + return true; +} + +static llvm::SmallVector getDirectUserInBlock(Operation *loadOp) { + llvm::SmallVector users; DenseSet seen; + for (Operation *user : loadOp->getUsers()) { + if (!seen.insert(user).second) + continue; + if (user->getBlock() == loadOp->getBlock()) + users.push_back(user); + } + return users; +} - std::function dfs = - [&](Operation *op, int distance, Operation *use) { +// When loop doesn't have num_stages attributes, we will look for any load or +// dot (only the first one in the chain). With the attribute we should look for +// any op, but also only the first one. +static llvm::SmallVector +getTransitiveUserInBlock(Operation *baseOp, scf::ForOp &forOp) { + llvm::SmallVector users; + DenseSet seen; + bool loopHasAttribute = forOp->hasAttr(tt::kNumStagesAttrName); + std::function dfs = + [&](Operation *op, Operation *baseOp, bool anyOp) { if (!seen.insert(op).second) return; - if (isa(op)) { - // TODO: What if there are multiple uses at different distances? - loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); - use = op; - distance++; - } - for (Value operand : op->getOperands()) { - Value v = operand; - Operation *defOp = v.getDefiningOp(); - if (defOp && defOp->getBlock() == op->getBlock()) { - dfs(defOp, distance, use); + if (op != baseOp) { + if (anyOp) { + // Only track the first op in the dependence chain. + users.push_back(op); + return; + } + if (isa(op) || + op->hasTrait()) { + // Stop recursion when hitting a LoadOp or a DotOp. + users.push_back(op); + return; } } + for (Operation *user : op->getUsers()) + if (user->getBlock() == op->getBlock()) + dfs(user, baseOp, anyOp); }; - - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!op.hasTrait()) - continue; + // We are matching the behavior before refactoring: + // For loops without num_stage attributes, we check for dot users. + // For loops with num_stage attributes, we check for dot users, if there are + // no dot users, we check for direct users. + dfs(baseOp, baseOp, false /*anyOp*/); + if (loopHasAttribute) { seen.clear(); - dfs(&op, 0, &op); - } - - // If the loop has numStages attribute, also consider pipelining other loads - // that are not directly used by dot ops. - if (forOp->hasAttr(tt::kNumStagesAttrName)) { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!isa(op)) - dfs(&op, 0, &op); - } + dfs(baseOp, baseOp, true /*anyOp*/); } - - return loadOpToIndLevelAndUse; -} - -static bool loadIsMMAv3(Operation *loadOp) { - if (!loadOp->hasOneUse()) - return false; - auto alloc = dyn_cast(*loadOp->getUsers().begin()); - if (!alloc) - return false; - auto sharedEnc = cast(alloc.getType().getEncoding()); - if (!sharedEnc.getHasLeadingOffset()) - return false; - - // MMA V3 case. - auto newOrder = sharedEnc.getOrder(); - auto ty = cast(loadOp->getResultTypes()[0]); - auto oldOrder = ttg::getOrder(ty.getEncoding()); - - // The operand of MMAv3 is in SharedEncoding and its order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. - return oldOrder == newOrder; + return users; } static llvm::MapVector -assignMemoryLayouts(llvm::SmallVector> - &loadOpToIndLevelAndUse, +assignMemoryLayouts(scf::ForOp &forOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { llvm::MapVector loadToInfo; - for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { - if (loadToInfo.count(op)) - // TODO pawel: err, we'd need to verify that the distance is the same - continue; - LoadInfo loadInfo; - - if (auto loadOp = dyn_cast(op)) { - assert(!isLoadFromTensorPtr(loadOp) && - "Block ptr should have been lowered before this pass."); - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - if (auto mask = loadOp.getMask()) - vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - continue; - auto ty = - cast(tensorTy.getElementType()).getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. - LDBG("Load " << *loadOp << " has width " << width); - if (width < 32) - continue; - } - - if (use->hasTrait()) { - loadInfo.usedByDot = true; - if (loadIsMMAv3(op)) { - loadInfo.loadIsMMAV3 = true; - loadInfo.sharedEncoding = - getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else if (isa(op)) { - loadInfo.sharedEncoding = - getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else if (auto dot = dyn_cast(use)) { - bool incompatible = false; - loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) - .value_or(nullptr); - // If we can't agree on a shared encoding skip pipelinig the load. - if (incompatible) - continue; - - // HACK: Triton LLVM codegen has a bug where local_loads from #shared to - // #mma layout can lead to invalid code if the loaded shape is smaller - // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with - // tile {16,8} is bad because 1 < 8). To work around this, don't - // pipeline such loads. - // - // The codegen bug is caught by an assertion, so if you think you've - // fixed it, feel free to delete this code and see if the assert still - // fails. :) - if (!loadInfo.sharedEncoding) { - if (auto dotEnc = dyn_cast( - dot.getResult().getType().getEncoding())) { - auto loadTy = cast(op->getResultTypes()[0]); - auto mmaInstrShape = dotEnc.getInstrShape(); - if (loadTy.getRank() < mmaInstrShape.size()) - continue; - bool ok = true; - for (int i = 0; i < mmaInstrShape.size(); i++) { - if (loadTy.getShape()[loadTy.getRank() - mmaInstrShape.size() + - i] < mmaInstrShape[i]) { - ok = false; - break; - } - } - // If this load might trigger the bug, don't do the fallback logic - // below, which might allow the load to be pipelined. - if (!ok) - continue; - } - } - } - } else if (auto loadOp = dyn_cast(use)) { - // The use of this loadOp is another loadOp. If the use is not in the - // loadsToPipeline already, it means that the use is not valid for - // pipelining for some reason. We should skip this loadOp, too. Note that - // we have an assumption that distAndUse.second (i.e. the use of this - // loadOp) has already be processed in a previous loop iteration. This - // assumption is held by how loadOpsToIndirectionLevelAndUse recursively - // collects loadOpToIndLevelAndUse using DFS. - if (loadToInfo.count(loadOp) == 0) { - continue; - } - } - - // If we still don't have a shared encoding, try a "generic" shared - // encoding. - if (!loadInfo.sharedEncoding && !isa(use)) { - loadInfo.sharedEncoding = - getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) - .value_or(nullptr); - if (auto loadOp = dyn_cast(op)) { - loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); - } - } - - // If that still didn't work, bail on pipelining this load. - if (!loadInfo.sharedEncoding) { - continue; - } - loadToInfo[op] = loadInfo; - } - - return loadToInfo; -} - -static llvm::MapVector -scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, - DenseSet &rootUsers, int numStages) { - - ModuleOp moduleOp = forOp->getParentOfType(); - tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // Get all loads that are (transitively) used by dot ops and their distance - // to the dot op. - llvm::SmallVector> - loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); - LLVM_DEBUG({ - LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); - for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { - LDBG(" - load: " << *l); - LDBG(" at indirection level: " << i); - LDBG(" used by op: " << *u); - } - }); - if (loadOpToIndLevelAndUse.empty()) - return {}; - - // We assume loads with different dist are assigned to different stages. - // If numStages is 2, we will have no stage available for indirect loads - // with dist >= 1. In general, when dist is equal to numStages - 1, we - // should not pipeline it. - auto it = llvm::remove_if(loadOpToIndLevelAndUse, [=](auto op) { - return std::get<1>(op) >= numStages - 1; - }); - loadOpToIndLevelAndUse.erase(it, loadOpToIndLevelAndUse.end()); - - // Check which loads are good for pipelining, and assign them - // memory layouts. - llvm::MapVector loadToInfo = - assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); - - if (loadToInfo.empty()) - return {}; - - // Calculate the stage distance between applicable loads. - int maxIndirectionLevel = -1; - for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { - if (loadToInfo.count(loadOp) == 0) - continue; - maxIndirectionLevel = std::max(maxIndirectionLevel, dist); - } - unsigned stagesBetweenLoads = - ceil(numStages - 2, maxIndirectionLevel + 1); - - tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); - // Put the root uses of the loads in the last stage. - for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { - if (loadToInfo.count(loadOp) == 0) + // Go through all loads in the loop, check to see if they are pipelined. + llvm::DenseSet loadsToPipeline; + for (auto &op : forOp.getBody()->without_terminator()) { + if (!isa(op) && !isa(op)) continue; - // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be - // always present in the opInfo - if (!isa(use)) { - schedule.insert(use, numStages - 1, rootUsersCluster); - rootUsers.insert(use); - } - } - - SmallVector loadsClusters; - for (int i = 0; i < maxIndirectionLevel + 1; i++) { - loadsClusters.push_back(schedule.clusters.newAtBack()); - } - // Assign stages to the loads. - for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { - if (loadToInfo.count(loadOp) == 0) + if (loadToInfo.count(&op)) + // TODO pawel: err, we'd need to verify that the distance is the same continue; - int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - schedule.insert(loadOp, stage, loadsClusters[indLevel]); - } - - // Distance from the load to the use. - for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { - if (loadToInfo.count(loadOp) == 0) + if (!op.hasAttr(mlir::triton::kLoopStageAttrName)) continue; - loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; - } - - return loadToInfo; -} -// Schedule the prologue and epilogue `if` ops in the loop, pushing them as -// close to the loop boundaries as possible. Return the cluster after the -// prologue (or the beginning of the loop if there is no prologue). -static tt::CoarseSchedule::Cluster -schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, - DenseSet &rootUsers, int numStages) { - tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); - - // Look for the IfOp that is in the backward slice any of the currently - // scheduled ops and put it at the beginning of the loop. - DenseMap ifsToStage; - // Go stage by stage. - for (int stage = 0; stage < numStages; stage++) { - for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { - if (stage_ != stage) + // Check stage for uses. If any direct use is in a different stage, treat it + // as a pipelined load. + bool isPipelined = false; + auto [sLoad, _cLoad] = tt::getStageCluster(&op); + auto directUsers = getDirectUserInBlock(&op); + LDBG("DirectUser for load " << op); + for (auto user : directUsers) { + LDBG(" - use: " << *user); + if (!user->hasAttr(mlir::triton::kLoopStageAttrName)) continue; - SetVector backwardSlice; - BackwardSliceOptions opt; - opt.omitBlockArguments = true; - getBackwardSlice((Operation *)op, &backwardSlice, opt); - - for (auto op : backwardSlice) { - if (auto ifOp = dyn_cast(op)) { - ifsToStage.insert({ifOp, stage}); - } - } - } - } - tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); - for (auto [ifOp, stage] : ifsToStage) { - schedule.insert(ifOp, stage, prologueCluster); - } - - // Look for the IfOp that is in the forward slice of the root users and put it - // at the end of the loop. - tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); - for (auto rootUser : rootUsers) { - SetVector forwardSlice; - getForwardSlice(rootUser, &forwardSlice); - - int stage = schedule[rootUser].first; - for (auto op : forwardSlice) { - scf::IfOp ifOp = dyn_cast(op); - if (ifOp == nullptr) { - // check if the op is in the body of an if op that's part of the loop - auto parentOp = op->getParentOp(); - if (parentOp != nullptr && - parentOp->getParentOp() == forOp.getOperation()) { - ifOp = dyn_cast(parentOp); - } - } - if (ifOp) { - schedule.insertIfAbsent(ifOp, stage, - epilogueCluster); // after prefetch extracts + auto [stage, _cluster] = tt::getStageCluster(user); + if (stage != sLoad) { + isPipelined = true; + break; } } - } - return afterPrologue; -} - -// Add dependencies of anchor ops to the coarse schedule. Schedule them to -// the same stage and ordering cluster as the anchor op. -static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, - int numStages) { - SmallVector> - opsInOrder = schedule.getOpsInOrder(forOp); - // Schedule dependencies stage by stage. - for (int stage = 0; stage < numStages; stage++) { - for (auto [op, stage_, cluster] : opsInOrder) { - if (stage_ != stage) - continue; - schedule.insertDepsOfOp(op, stage, cluster, false); - } - } -} + if (!isPipelined) + continue; -// Find dependencies with distance of 1. They will go to the next stage, -// but in the cluster before the current op. -static void scheduleDistanceOneDependencies(scf::ForOp forOp, - tt::CoarseSchedule &schedule, - int numStages) { - auto getNestedOperands = [](Operation *op) -> SmallVector { - SmallVector operands; - op->walk([&](Operation *nestedOp) { - for (Value operand : nestedOp->getOperands()) { - if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) - operands.push_back(operand); + // Try to set shared encoding etc for the pipelined load. + auto users = getTransitiveUserInBlock(&op, forOp); + LLVM_DEBUG({ + LDBG("TransitiveUser for load " << op); + for (const auto user : users) { + LDBG(" - use: " << *user); } }); - return operands; - }; - // Mapping from the cluster to the cluster before it. - DenseMap - dist1Cluster; - for (auto &op : forOp.getBody()->without_terminator()) { - if (schedule.count(&op) == 0) - continue; - auto [stage, cluster] = schedule[&op]; - // Can't schedule past the last stage. - if (stage == numStages - 1) - continue; - for (Value operand : getNestedOperands(&op)) { - if (auto arg = dyn_cast(operand)) { - if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { - auto yieldOp = op.getBlock()->getTerminator(); - Value v = yieldOp->getOperand(arg.getArgNumber() - 1); - Operation *defOp = v.getDefiningOp(); - if (defOp && schedule.count(defOp) == 0) { - if (isa(defOp)) { - // Exception: Schedule loads with a distance of 1 together - // with the current op. - schedule.insertIfAbsent(defOp, stage, cluster); - schedule.insertDepsOfOp(defOp, stage, cluster, true); - } else { - if (dist1Cluster.count(&cluster) == 0) { - dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); - } - schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); - schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], - true); - } - } + loadsToPipeline.insert(&op); + LoadInfo loadInfo; + for (auto use : users) { + if (use->hasTrait()) { + LDBG("set shared encoding with dot user: " << *use); + auto mmaLoadType = getMMALoadType(&op); + auto dot = dyn_cast(use); + auto warpGroupDot = dyn_cast(use); + + loadInfo.usedByDot = true; + loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; + loadInfo.isMMAv3Registers = + (mmaLoadType == MMALoadType::Registers) && warpGroupDot; + + if (loadInfo.isMMAv3Shared) { + loadInfo.sharedEncoding = + getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else if (isa(op)) { + loadInfo.sharedEncoding = + getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else if (loadInfo.isMMAv3Registers || dot) { + bool incompatible = false; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op.getResult(0), incompatible) + .value_or(nullptr); } } - } - } -} -static void -scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, - tt::CoarseSchedule::Cluster afterPrologue, - int numStages) { - // Assign the rest of the ops to the last stage. - // Take care of the ordering of the ops - uses cannot be scheduled to the - // cluster before the definition. - DenseMap opToCluster; - for (auto &op : forOp.getBody()->without_terminator()) { - if (schedule.count(&op) == 0) { - opToCluster[&op] = afterPrologue; - } - } - SmallVector queue; - for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { - // We really only care about the producers from the last stage. - // Others will be scheduled before these ops anyway. - if (stage == numStages - 1) { - queue.push_back(op); - } - } - while (!queue.empty()) { - Operation *op = queue.pop_back_val(); - for (auto user : op->getUsers()) { - if (opToCluster.count(user)) { - tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; - tt::CoarseSchedule::Cluster opCluster; - if (schedule.count(op)) - opCluster = schedule[op].second; - else - opCluster = opToCluster[op]; - if (*userCluster < *opCluster) { - opToCluster[user] = opCluster; - queue.push_back(user); - } + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding && !isa(use)) { + LDBG("try generic shared encoding"); + loadInfo.sharedEncoding = + getSharedEncoding(&op, /*isMMAV3=*/loadInfo.isMMAv3Shared) + .value_or(nullptr); + if (auto loadOp = dyn_cast(op)) + loadInfo.blockedEncoding = + getBlockedEncoding(loadOp, axisInfoAnalysis); } } + loadToInfo[&op] = loadInfo; } - for (auto [op, cluster] : opToCluster) { - schedule.insert(op, numStages - 1, cluster); - } + // Make sure all loads in loadsToPipeline are in loadToInfo. + for (auto *load : loadsToPipeline) + assert(loadToInfo.count(load) && + "pipelined loads should have sharedEncoding"); + + return loadToInfo; } // Create an allocation that can hold distance number of loadOp shapes. @@ -777,15 +523,15 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned distance) { OpBuilder builder(forOp); Attribute sharedMemorySpace = - triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); auto ty = cast(loadOp->getResultTypes()[0]); SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), distance); - Type memdescType = mlir::triton::MemDescType::get( - bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, - /*mutableMemory*/ true); - Value alloc = builder.create( - loadOp->getLoc(), memdescType, Value()); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + Value alloc = + builder.create(loadOp->getLoc(), memdescType, Value()); return alloc; } @@ -793,7 +539,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { OpBuilder builder(forOp); Attribute sharedMemorySpace = - triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); Location loc = forOp.getLoc(); auto context = forOp.getContext(); auto barrierCTALayout = @@ -801,14 +547,15 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); - Type barrierMemDescType = tt::MemDescType::get( + auto barrierMemDescType = ttg::MemDescType::get( {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, /*mutableMemory=*/true); - Type singleBarrierMemDescType = - tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, - sharedMemorySpace, /*mutableMemory=*/true); - Value barrierAlloc = builder.create( - loc, barrierMemDescType, Value()); + Type singleBarrierMemDescType = ttg::MemDescType::get( + {1}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/barrierMemDescType.getAllocShape()); + Value barrierAlloc = + builder.create(loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { Value idx = builder.create(loc, i, 32); Value barrierView = builder.create( @@ -824,6 +571,7 @@ struct AsyncLoad { Value alloc; Value barrier; Operation *waitOp = nullptr; + int firstUseStage, firstUseCluster; bool isTMALoad = false; }; @@ -831,8 +579,7 @@ struct AsyncLoad { // multiple loads is the schedule allows it. static void createTMABarrierAndWait( scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, - Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule, - SmallVector &barriers, + Value extractIdx, Value phase, int numBuffers, SmallVector &barriers, const llvm::MapVector &loadToInfo) { llvm::SmallDenseMap loadToAsyncLoad; for (AsyncLoad &asyncLoad : asyncLoads) { @@ -856,7 +603,7 @@ static void createTMABarrierAndWait( if (it != loadToInfo.end()) { // Special case for MMAv3 loads, we can ignore the alloc and only // consider uses of the alloc op since it will be removed. - if (it->second.loadIsMMAV3) { + if (it->second.isMMAv3Shared) { auto alloc = cast( (*loadInfo->loadOp->getUsers().begin())); if (alloc->getBlock() == loadBlock) { @@ -878,7 +625,9 @@ static void createTMABarrierAndWait( if (isa(nextOp)) { auto it = loadToAsyncLoad.find(nextOp); if (it != loadToAsyncLoad.end() && it->second->isTMALoad) { - addToGroup(it->second); + if (group.size() > 0 && + sameStageCluster(group[0]->loadOp, it->second->loadOp)) + addToGroup(it->second); } } nextOp = nextOp->getNextNode(); @@ -898,31 +647,34 @@ static void createTMABarrierAndWait( loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; } + auto [stage, cluster] = tt::getStageCluster(group[0]->loadOp); Value barrierAlloc = createBarrierAlloc(forOp, numBuffers); barriers.push_back(barrierAlloc); Location loc = forOp.getLoc(); - OpBuilder builder(forOp); + OpBuilderWithStage builder(forOp); Attribute sharedMemorySpace = - triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); - tt::MemDescType barrierTy = tt::MemDescType::get( - {1}, builder.getI64Type(), - cast(barrierAlloc.getType()).getEncoding(), - sharedMemorySpace, - /*mutableMemory=*/true); + ttg::SharedMemorySpaceAttr::get(builder.getContext()); + auto allocTy = cast(barrierAlloc.getType()); + ttg::MemDescType barrierTy = ttg::MemDescType::get( + {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); builder.setInsertionPoint(group[0]->loadOp); - Value barrier = builder.create( - loc, barrierTy, barrierAlloc, ArrayRef({insertIdx})); - Value pred = builder.create(loc, 1, 1); - Operation *expect = builder.create( - forOp.getLoc(), barrier, sizeInBytes, pred); - auto [stage, cluster] = schedule[asyncLoads[0].loadOp]; - schedule.insert(expect, stage, cluster); + Value barrier = builder.createWithStage( + loc, stage, cluster, barrierTy, barrierAlloc, + ArrayRef({insertIdx})); + Value pred = builder.createWithStage(loc, stage, + cluster, 1, 1); + Operation *expect = builder.createWithStage( + forOp.getLoc(), stage, cluster, barrier, sizeInBytes, pred); builder.setInsertionPointAfter(group.back()->loadOp); - Value barrierViewWait = builder.create( - loc, barrierTy, barrierAlloc, ArrayRef({extractIdx})); - Operation *wait = - builder.create(loc, barrierViewWait, phase); + Value barrierViewWait = builder.createWithStage( + loc, group[0]->firstUseStage, group[0]->firstUseCluster, barrierTy, + barrierAlloc, ArrayRef({extractIdx})); + Operation *wait = builder.createWithStage( + loc, group[0]->firstUseStage, group[0]->firstUseCluster, + barrierViewWait, phase); // Update the async loads info. for (AsyncLoad *asyncLoad : group) { asyncLoad->barrier = barrier; @@ -931,10 +683,35 @@ static void createTMABarrierAndWait( } } +// This is similar to CoarseSchedule.createFinalSchedule. +static std::vector> +getFinalSchedule(scf::ForOp &forOp, int numStages) { + auto [minClusterId, maxClusterId] = tt::getMinMaxCluster(forOp); + SmallVector, 8> orderClusters(maxClusterId - + minClusterId + 1); + for (auto &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName) || + !op.hasAttr(mlir::triton::kLoopClusterAttrName)) + continue; + + auto [stage, clusterId] = tt::getStageCluster(&op); + assert(stage < numStages && "Op with invalid stage!"); + orderClusters[clusterId - minClusterId].push_back(&op); + } + std::vector> fSchedule; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto op : orderClusters[i]) { + auto [stage, _] = tt::getStageCluster(op); + fSchedule.push_back({op, stage}); + } + } + return fSchedule; +} + // Convert load ops into their asyn version and apply multi-buffering based on // the required number of buffers. static SmallVector -createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, +createAsyncOps(scf::ForOp &forOp, llvm::MapVector &loadToInfo, SmallVector &barriers, int numStages) { // Calculate the number of buffers needed for each load. @@ -946,8 +723,9 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, auto &rhs) { return lhs.distToUse < rhs.distToUse; })->distToUse; - bool hasMMAV3 = - llvm::any_of(loadToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) { + return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; + }); if (hasMMAV3) { // For MMAv3, we need an extra buffer as this is assumed in the wgmma // pipelining post-processing. @@ -967,6 +745,10 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, hasTMALoad = true; asyncLoads.back().isTMALoad = true; } + auto *firstUse = getFirstUseOfPipelinedLoad(loadOp); + auto [firstUseStage, firstUseCluster] = tt::getStageCluster(firstUse); + asyncLoads.back().firstUseStage = firstUseStage; + asyncLoads.back().firstUseCluster = firstUseCluster; } IRRewriter builder(forOp.getContext()); @@ -1001,6 +783,7 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, phase = newForOp.getBody()->getArgument(newOperandIndex + 2); } + // FIXME: loads can be in different (stage, cluster) // Create two counters for the insert and extract indices to avoid creating // long liverange. builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); @@ -1018,21 +801,18 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, phase = builder.create(loc, cndExt, phase, nextPhase); } createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase, - numBuffers, schedule, barriers, loadToInfo); - - // Create a cluster for the prefetches. It may end up being empty, but this - // is OK. - tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + numBuffers, barriers, loadToInfo); + auto [_, maxClusterId] = tt::getMinMaxCluster(forOp); for (AsyncLoad &asyncLoad : asyncLoads) { if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, - schedule, prefetchCluster, loadToInfo, numStages); + loadToInfo, numStages, maxClusterId); } else { auto descLoad = cast(asyncLoad.loadOp); createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, - schedule, loadToInfo, numStages); + loadToInfo, numStages); } } SmallVector newYieldOperands = {insertIdx, extractIdx}; @@ -1041,22 +821,32 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, // Patch the yield with the updated counters. appendToForOpYield(forOp, newYieldOperands); + tt::CoarseSchedule coarseSchedule(numStages); + coarseSchedule.deSerialize(forOp); + scheduleDependencies(forOp, coarseSchedule); + coarseSchedule.serialize(forOp); + + // Make sure all ops have attributes. + for (Operation &op : forOp.getBody()->without_terminator()) { + assert(op.hasAttr(mlir::triton::kLoopStageAttrName) && + op.hasAttr(mlir::triton::kLoopClusterAttrName)); + } return allocs; } static void invalidateBarriers(OpBuilder &builder, SmallVector &barriers) { Attribute sharedMemorySpace = - triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); + ttg::SharedMemorySpaceAttr::get(builder.getContext()); for (Value barrier : barriers) { - int numBarriers = cast(barrier.getType()).getShape()[0]; + auto allocTy = cast(barrier.getType()); + int numBarriers = allocTy.getShape()[0]; for (int i = 0; i < numBarriers; i++) { Value idx = builder.create(barrier.getLoc(), i, 32); - tt::MemDescType barrierTy = tt::MemDescType::get( - {1}, builder.getI64Type(), - cast(barrier.getType()).getEncoding(), - sharedMemorySpace, - /*mutableMemory=*/true); + ttg::MemDescType barrierTy = ttg::MemDescType::get( + {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/allocTy.getShape()); Value barrierView = builder.create( barrier.getLoc(), barrierTy, barrier, idx); builder.create(barrier.getLoc(), barrierView); @@ -1066,59 +856,34 @@ static void invalidateBarriers(OpBuilder &builder, bool mlir::triton::preProcessLoopAndGetSchedule( scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { - // Schedule the loads and root ops (dot ops) in the loop. This will give us - // a scaffold for the final schedule. - DenseSet rootUsers; - tt::CoarseSchedule coarseSchedule(numStages); + + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + // Check which loads are good for pipelining, and assign them + // memory layouts. llvm::MapVector loadToInfo = - scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + assignMemoryLayouts(forOp, axisInfoAnalysis); if (loadToInfo.empty()) return false; - LLVM_DEBUG({ - LDBG("Coarse schedule loads only:"); - coarseSchedule.dump(); - }); - - tt::CoarseSchedule::Cluster afterPrologue = - schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with prologue and epilogue:"); - coarseSchedule.dump(); - }); + // Distance from the load to the use. + for (auto &[loadOp, info] : loadToInfo) { + auto *use = getFirstUseOfPipelinedLoad(loadOp); + auto [stage, _] = tt::getStageCluster(loadOp); + auto [stageUse, t_] = tt::getStageCluster(use); + loadToInfo[loadOp].distToUse = stageUse - stage; + } SmallVector barriers; // Convert the loads into async loads and create the allocs. SmallVector allocs = - createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); - - LLVM_DEBUG({ - LDBG("Coarse schedule with async loads:"); - coarseSchedule.dump(); - }); - - scheduleDependencies(forOp, coarseSchedule, numStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with dependencies:"); - coarseSchedule.dump(); - }); - - scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with dist 1:"); - coarseSchedule.dump(); - }); - - scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); - LLVM_DEBUG({ - LDBG("Final coarse schedule:"); - coarseSchedule.dump(); - }); + createAsyncOps(forOp, loadToInfo, barriers, numStages); + LDBG("after lowering: " << forOp->getParentOfType()); // Create the final schedule for the kernel loop. This will dictate the // stages and order of operations to the pipeline expander. std::vector> schedule = - coarseSchedule.createFinalSchedule(forOp); + getFinalSchedule(forOp, numStages); // Fill out the pipeline options. options.getScheduleFn = @@ -1132,6 +897,13 @@ bool mlir::triton::preProcessLoopAndGetSchedule( options.annotateFn = [](Operation *op, mlir::triton::PipeliningOption::PipelinerPart part, unsigned iteration) {}; + + // Clean up the attributes. + for (Operation &op : forOp.getBody()->without_terminator()) { + op.removeAttr(mlir::triton::kLoopStageAttrName); + op.removeAttr(mlir::triton::kLoopClusterAttrName); + } + // Insert a wait 0 after the loop OpBuilder builder(forOp); builder.setInsertionPointAfter(forOp); @@ -1319,7 +1091,7 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, for (ttng::WarpGroupDotOp dot : asyncDots) { for (Value operand : dot.getOperands()) { - if (isa(operand.getType())) { + if (isa(operand.getType())) { newOperands.insert(operand); } } @@ -1337,12 +1109,12 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, }; for (int i = 0; i < origNumOperands; i++) { Value operand = wait.getResult(i); - if (!isa(operand.getType())) + if (!isa(operand.getType())) operand.replaceAllUsesWith(newWait.getResult(i)); } for (int i = origNumOperands; i < newOperands.size(); i++) { Value operand = newWait.getOperand(i); - if (!isa(operand.getType())) + if (!isa(operand.getType())) operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); } wait->erase(); @@ -1362,6 +1134,15 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, // // 1. All operands that touch shared memory are multi-buffered, i.e. can't read // an incomplete value while it's being written asynchronously by a load. +// 1a. If operand A is in registers, these registers cannot be updated +// inside +// the loop. +// **Exception** if the operand is produced by a preceding WGMMA, +// then this op can be properly async. Either the f16 shortcut is +// possible and the WGMMA's can run back-to-back (see rule 3 below), or +// elementwise truncate is needed, in which case the preceding WGMMA is +// not async and a WarpGroupDotWait is inserted right after, which +// guarantees exclusive access to the operand registers. // // 2. If the dot is used by any op in the loop, it must be used under an `if`, // and will be synced with a `wait 0` at the beginning of the `if` block. @@ -1396,15 +1177,23 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, // Rule 1: All shmem operands are multi-buffered. auto checkOperand = [&](Value operand) { if (!isa( - cast(operand.getType()).getEncoding())) { - return true; + cast(operand.getType()).getEncoding())) { + // Rule 1a: Register operands must not be modified within the loop. + // First, check for chained WGMMA as an exception. + if (auto cvt = dyn_cast(operand.getDefiningOp())) { + return isa( + cvt.getSrc().getType().getEncoding()); + } + // And then, do a stricter-than-necessary check for now, that the operand + // is defined outside the loop. + return forOp.isDefinedOutsideOfLoop(operand); } // If it's a shmem operand, it must either be defined outside the loop, or // come from an MemDescSubview op. Only ConvertLayout and Trans ops are // allowed in between. Value transitiveOperand = operand; - while (isa_and_nonnull( + while (isa_and_nonnull( transitiveOperand.getDefiningOp()) || isa(transitiveOperand)) { auto blockArg = dyn_cast(transitiveOperand); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 1a3162f17b98..aab560770720 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -13,6 +13,20 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; +bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + }); +} + +bool mlir::triton::isOuterLoop(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getOperations(), [](Operation &op) { + return isa(op); + }); +} + // Combine the current mask with the given predicate. static Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, Value pred) { @@ -80,6 +94,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, storeOp.getMaskMutable().assign(mask); return op; } + if (auto atomicRMWOp = dyn_cast(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } assert("don't know how to predicate this op" && false); return op; @@ -136,7 +157,8 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, // TODO: can we use an early_inc iterator? for (OpOperand &use : oldUse->getUses()) { // Non-subview/trans ops will be replaced by `val`. - if (!isa(use.getOwner())) { + if (!isa( + use.getOwner())) { operandsToReplace.push_back(&use); continue; } @@ -146,17 +168,19 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, builder.setInsertionPoint(user); Value newVal; if (auto subview = dyn_cast(user)) { - triton::MemDescType oldType = subview.getType(); + triton::gpu::MemDescType oldType = subview.getType(); bool isMutable = - cast(val.getType()).getMutableMemory(); - Type newDstType = triton::MemDescType::get( + cast(val.getType()).getMutableMemory(); + Type newDstType = triton::gpu::MemDescType::get( oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), oldType.getMemorySpace(), isMutable); newVal = builder.create( subview.getLoc(), newDstType, val, subview.getOffsets()); - } else if (auto trans = dyn_cast(user)) { - newVal = builder.create(trans.getLoc(), val, - trans.getOrderAttr()); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + } else if (auto trans = dyn_cast(user)) { + newVal = builder.create(trans.getLoc(), val, + trans.getOrder()); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); } assert(newVal); replaceUsesAndPropagateType(builder, user, newVal); @@ -173,3 +197,40 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, for (Operation *op : opsToDelete) op->erase(); } + +std::pair mlir::triton::getStageCluster(Operation *op) { + auto stage = cast(op->getAttr(mlir::triton::kLoopStageAttrName)) + .getValue() + .getSExtValue(); + auto clusterId = + cast(op->getAttr(mlir::triton::kLoopClusterAttrName)) + .getValue() + .getSExtValue(); + return std::make_pair(stage, clusterId); +} + +void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) { + auto ctx = op->getContext(); + op->setAttr(mlir::triton::kLoopStageAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), stage)); + op->setAttr(mlir::triton::kLoopClusterAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), cluster)); +} + +std::pair mlir::triton::getMinMaxCluster(scf::ForOp &forOp) { + int minClusterId = -1, maxClusterId = -1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName) || + !op.hasAttr(mlir::triton::kLoopClusterAttrName)) + continue; + auto [_, cluster] = getStageCluster(&op); + if (maxClusterId < 0) { + minClusterId = cluster; + maxClusterId = cluster; + continue; + } + maxClusterId = cluster > maxClusterId ? cluster : maxClusterId; + minClusterId = cluster < minClusterId ? cluster : minClusterId; + } + return std::make_pair(minClusterId, maxClusterId); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index 1116b70a0262..bfb31a3e8d6a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -90,3 +90,43 @@ void tt::CoarseSchedule::dump() { } } } + +// Set based on CoarseSchedule. +void tt::CoarseSchedule::serialize(scf::ForOp &forOp) { + for (auto [op, stage, cluster] : getOpsInOrder(forOp)) { + tt::setStageCluster(op, stage, *cluster); + } +} + +// Create a CoarseSchedule based on forOp's . +void tt::CoarseSchedule::deSerialize(scf::ForOp &forOp) { + auto [minClusterId, maxClusterId] = tt::getMinMaxCluster(forOp); + + DenseMap clustersMap; + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName)) + continue; + auto [stage, clusterId] = tt::getStageCluster(&op); + insert(&op, stage, clustersMap[clusterId]); + } +} + +// TODO: Should this be moved somewhere else? +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + int numStages = schedule.numStages; + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 8766e82b9f15..3361087a9c7d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -37,22 +37,10 @@ namespace gpu { static bool preCondition(scf::ForOp forOp) { // Skip loop with distance > 1 for now. // TODO: relax the constraint in the expander. - if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), - [](Value operand) { - Operation *def = operand.getDefiningOp(); - return !def; - })) + if (loopHasDistGreaterThanOne(forOp)) return false; // Don't pipeline outer loops. - if (forOp - ->walk([&](Operation *op) { - if (forOp.getOperation() == op) - return WalkResult::advance(); - if (isa(op)) - return WalkResult::interrupt(); - return WalkResult::advance(); - }) - .wasInterrupted()) + if (isOuterLoop(forOp)) return false; return true; } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 7985d25b9097..b24ac95387c8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -42,8 +42,8 @@ static Value createAlloc(scf::ForOp &forOp, Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); Type memdescType = - tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, - sharedMemorySpace, /*mutableMemory*/ true); + ttg::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); Value alloc = builder.create(storeOp->getLoc(), memdescType, Value()); return alloc; @@ -63,8 +63,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, builder.create(loc, 0); builder.create(loc, storeOp.getSrc(), alloc); builder.create(loc, false); + Value tmaPtr = builder.create( + loc, storeOp.getDesc()); builder.create( - loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); + loc, tmaPtr, storeOp.getIndices(), alloc); storeOp->erase(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp new file mode 100644 index 000000000000..ae3f3a97f9d4 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp @@ -0,0 +1,43 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINEASSIGNLATENCIES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static const char *kLatencyAttrName = "tt.latency"; + +struct TestPipelineAssignLatencies + : public impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies> { + using impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies>::TritonGPUTestPipelineAssignLatenciesBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + DenseMap opLatencies = assignLatencies(m, numStages); + + for (auto [op, latency] : opLatencies) { + op->setAttr( + kLatencyAttrName, + IntegerAttr::get(IntegerType::get(m.getContext(), 32), latency)); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp new file mode 100644 index 000000000000..54956c7177ca --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp @@ -0,0 +1,54 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINESCHEDULELOOP +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static const char *kLatencyAttrName = "tt.latency"; + +struct TestPipelineScheduleLoop + : public impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop> { + using impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop>::TritonGPUTestPipelineScheduleLoopBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + DenseMap opLatencies; + + // Deserialize latencies from the IR. + m.walk([&](Operation *op) { + if (op->hasAttr(kLatencyAttrName)) { + int latency = + mlir::cast(op->getAttr(kLatencyAttrName)).getInt(); + op->removeAttr(kLatencyAttrName); + opLatencies[op] = latency; + } + }); + + SmallVector loops; + m.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + for (auto forOp : loops) { + scheduleLoop(forOp, opLatencies); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 2cbc00142b42..c11f2f8e5ee7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -15,12 +15,12 @@ // // %a: tensor<128x32xf16, #enc> // %a_tmp = tensor.subview %a[0, 0] [128, 16] -// %a_prefetch = triton_gpu.local_load %a_tmp +// %a_prefetch = ttg.local_load %a_tmp // scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) // { // %x = tt.dot %a_prefetch_arg, %b, %c // %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] -// %a_prefetch_next = triton_gpu.local_load %a_tmp_rem +// %a_prefetch_next = ttg.local_load %a_tmp_rem // ... // scf.yield %next_a, ..., %a_prefetch_next // } @@ -114,7 +114,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, std::optional offsetK, std::optional shapeK) { // opIdx: 0 => a, 1 => b - auto type = cast(v.getType()); + auto type = cast(v.getType()); SmallVector shape{type.getShape().begin(), type.getShape().end()}; SmallVector offset{0, 0}; Type elementType = type.getElementType(); @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, builder.create(v.getLoc(), off, 32)); Value newSmem = builder.create( v.getLoc(), - triton::MemDescType::get(shape, elementType, type.getEncoding(), - type.getMemorySpace()), + triton::gpu::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()), v, offsetsVal); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index b1e296c1bbe4..af756c6d83e9 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -58,7 +58,7 @@ class TritonGPUReduceDataDuplicationPass } auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); - auto tmpType = triton::MemDescType::get( + auto tmpType = triton::gpu::MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index cee1ae84ef59..9c05ae5bb46d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -282,6 +282,7 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, setEncoding(user->getResults(), info, changed, user); continue; } + // TODO(jeff): Propagate tt.gather indices layout to dst. } return changed; } @@ -709,6 +710,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { } return newOp; } + // TODO(jeff): Handle tt.gather once it supports layout propagation. llvm::report_fatal_error("unexpected op in rewrite"); return nullptr; } @@ -970,7 +972,9 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf + if (isa(targetType.getEncoding())) return; Value oldV = convertOp->getOperand(0); LDBG("check backward remat with source " << oldV << " encoding " @@ -1012,8 +1016,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa(targetType.getEncoding())) + if (mlir::isa( + targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { @@ -1024,7 +1031,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( if (auto fpToFpOp = dyn_cast(op)) { auto srcType = cast(fpToFpOp.getOperand().getType()); return getElementBitWidth(srcType) < - getElementBitWidth(fpToFpOp.getType()); + getElementBitWidth(cast(fpToFpOp.getType())); } return false; }; diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp index bff277c59314..540cf081c53a 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -100,7 +100,7 @@ class TritonGPUReorderInstructionsPass }); // Move transpositions just after their definition opToMove.clear(); - m.walk([&](triton::TransOp op) { + m.walk([&](triton::TransposeOpInterface op) { Operation *argOp = op.getSrc().getDefiningOp(); if (!argOp) return; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 4ef9d1cd1d11..ec3517ef139f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -19,6 +19,7 @@ #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +namespace ttg = mlir::triton::gpu; namespace mlir { using namespace triton; @@ -378,13 +379,13 @@ inferTransOpDstEncoding(Attribute srcEnc, ArrayRef order) { return std::nullopt; } -static std::optional inferDstEncoding(triton::TransOp op, - Attribute encoding) { +static std::optional +inferDstEncoding(triton::TransposeOpInterface op, Attribute encoding) { return inferTransOpDstEncoding(encoding, op.getOrder()); } -static std::optional inferSrcEncoding(triton::TransOp op, - Attribute encoding) { +static std::optional +inferSrcEncoding(triton::TransposeOpInterface op, Attribute encoding) { // We want to solve for srcEnc in // transpose(srcEnc, order) -> dstEnc. // Given the identity @@ -467,10 +468,12 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { return inferSrcEncoding(join, encoding); if (auto split = dyn_cast(op)) return inferSrcEncoding(split, encoding); - if (auto trans = dyn_cast(op)) + if (auto trans = dyn_cast(op)) return inferSrcEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferSrcEncoding(reshape, encoding); + // TODO(jeff): Handle progagating tt.gather indices -> dst layout. + // This requires updating the API to specify the exact operands and results. return std::nullopt; } @@ -494,10 +497,11 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { return inferDstEncoding(join, encoding); if (auto split = dyn_cast(op)) return inferDstEncoding(split, encoding); - if (auto trans = dyn_cast(op)) + if (auto trans = dyn_cast(op)) return inferDstEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferDstEncoding(reshape, encoding); + // TODO(jeff): Handle progagating tt.gather indices -> dst layout. return std::nullopt; } @@ -562,7 +566,8 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { } return isa(op); + triton::gpu::LocalAllocOp, triton::gpu::LocalLoadOp, + triton::gpu::LocalStoreOp>(op); } scf::ForOp replaceForOpWithNewSignature( @@ -930,6 +935,91 @@ int getNVIDIAComputeCapability(Operation *module) { return computeCapability; } +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are imcompatible shared +// encodings, set incompatible to true. +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { + ttg::SharedEncodingAttr attr; + incompatible = false; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()) + .getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, + bitWidth, /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (attr != nullptr && attr != tempAttr) { + incompatible = true; + return std::nullopt; + } + attr = tempAttr; + } + return attr; +} + +MMALoadType getMMALoadType(Operation *loadOp) { + if (!loadOp->hasOneUse()) + return MMALoadType::DoNotPipeline; + + if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { + auto sharedEnc = + cast(alloc.getType().getEncoding()); + + if (!sharedEnc.getHasLeadingOffset()) + return MMALoadType::DoNotPipeline; + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = cast(loadOp->getResultTypes()[0]); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder ? MMALoadType::SharedV3 + : MMALoadType::DoNotPipeline; + } else if (auto cvt = + dyn_cast(*loadOp->getUsers().begin())) { + auto resTy = dyn_cast(cvt->getResultTypes()[0]); + if (!resTy) { + return MMALoadType::DoNotPipeline; + } + + if (isa(resTy.getEncoding())) { + return MMALoadType::Registers; + } + + return MMALoadType::DoNotPipeline; + } else { + return MMALoadType::DoNotPipeline; + } +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and @@ -1067,4 +1157,40 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +std::optional +getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, + Attribute srcEnc, Attribute dstEnc, int elemBitWidth) { + StringAttr kBlock = StringAttr::get(ctx, ("block")); + int rank = shape.size(); + + std::optional regLayout = + triton::gpu::toLinearLayout(shape, srcEnc); + std::optional sharedLayout = + triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth); + if (!regLayout.has_value() || !sharedLayout.has_value()) { + return std::nullopt; + } + auto sharedOrder = triton::gpu::getOrder(dstEnc); + + // sharedLayout's in-dims are currently (offset, block). Reshape to + // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional + // shmem strides. (The offsetX's appear in minor-to-major order.) + auto sharedLegacy = cast(dstEnc); + SmallVector> multiDimSharedSize; + for (int i = 0; i < rank; i++) { + int dim = sharedOrder[i]; + int64_t size = std::max( + int64_t{1}, + shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]); + multiDimSharedSize.push_back( + {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size}); + } + multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)}); + sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize); + + // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, + // ..., offsetXN, block), where the offsetX's are in minor-to-major order. + return regLayout->invertAndCompose(*sharedLayout); +} + } // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 37c69eef8adb..942eb5423dba 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -42,8 +42,10 @@ mlir::LogicalResult WarpGroupDotOp::inferReturnTypes( inferredReturnTypes.push_back(accTy); // verify encodings - auto aEnc = cast(operands[0].getType()).getEncoding(); - auto bEnc = cast(operands[1].getType()).getEncoding(); + auto aEnc = + cast(operands[0].getType()).getEncoding(); + auto bEnc = + cast(operands[1].getType()).getEncoding(); auto retEnc = accTy.getEncoding(); if (aEnc) { assert(bEnc); @@ -62,10 +64,10 @@ void WarpGroupDotOp::getEffects( &effects) { auto &a = getAMutable(); auto &b = getBMutable(); - if (isa(a.get().getType())) + if (isa(a.get().getType())) effects.emplace_back(MemoryEffects::Read::get(), &a, mlir::triton::gpu::SharedMemory::get()); - if (isa(b.get().getType())) + if (isa(b.get().getType())) effects.emplace_back(MemoryEffects::Read::get(), &b, mlir::triton::gpu::SharedMemory::get()); } @@ -73,11 +75,12 @@ void WarpGroupDotOp::getEffects( bool WarpGroupDotOp::needsPartialAccumulator() { const auto &a = getA(); const auto &d = getD(); - auto aTensorTy = cast(a.getType()); - auto aElTy = cast(a.getType()).getElementType(); + auto aTensorTy = cast(a.getType()); + auto aElTy = cast(a.getType()).getElementType(); bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() || aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ(); - bool accFP32 = cast(d.getType()).getElementType().isF32(); + bool accFP32 = + cast(d.getType()).getElementType().isF32(); uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; } @@ -93,7 +96,8 @@ LogicalResult WarpGroupDotWaitOp::inferReturnTypes( return mlir::success(); } -static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) { +static LogicalResult +verifyBarrierType(Operation *op, mlir::triton::gpu::MemDescType barrierType) { if (!barrierType.getElementType().isInteger(64) || barrierType.getShape() != ArrayRef({1})) return op->emitOpError( @@ -160,6 +164,18 @@ void WaitBarrierOp::getEffects( mlir::triton::gpu::SharedMemory::get()); } +// -- TensorDescToTMAPtrOp -- +LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op, + PatternRewriter &rewriter) { + // tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr + if (auto reinterpret = + op.getDesc().getDefiningOp()) { + rewriter.replaceOp(op, reinterpret.getRawDesc()); + return success(); + } + return failure(); +} + // -- AsyncTMACopyGlobalToLocalOp -- LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { if (failed(verifyBarrierType(*this, getBarrier().getType()))) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 0938432c7e58..cb9ae9dd0f3c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -3,8 +3,10 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -58,8 +60,10 @@ class TMALoadLowering : public OpRewritePattern { Value pred = rewriter.create(loc, 1, 1); rewriter.create(loc, barrierAlloc, sizeInBytes, pred); + Value tmaPtr = rewriter.create( + loc, op.getDesc()); rewriter.create( - loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred); + loc, tmaPtr, op.getIndices(), barrierAlloc, alloc, pred); Value phase = rewriter.create(loc, 0, 32); rewriter.create(loc, barrierAlloc, phase); rewriter.create(loc, barrierAlloc); @@ -93,14 +97,114 @@ class TMAStoreLowering encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, op.getSrc()); rewriter.create(loc, false); + Value tmaPtr = rewriter.create( + loc, op.getDesc()); rewriter.create( - loc, op.getDescPtr(), op.getIndices(), alloc); + loc, tmaPtr, op.getIndices(), alloc); rewriter.create(loc, 0); rewriter.eraseOp(op); return success(); } }; +class TMACreateDescLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MakeTensorDescOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + constexpr auto kTmaNbytes = 128; + constexpr auto kTmaAlignment = 128; + auto alloc = rewriter.create( + loc, getPointerType(rewriter.getI8Type()), kTmaNbytes, kTmaAlignment); + auto mkI32Constant = [&](int32_t val) { + return rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(val)); + }; + + auto elemType = op.getBase().getType().getPointeeType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + + int32_t contig_dim_size = op.getTensorShape().back(); + int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; + if (contig_dim_size_in_bytes > 128) { + contig_dim_size = 128 / elemSize; + } + llvm::SmallVector boxDim; + boxDim.push_back(mkI32Constant(contig_dim_size)); + for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { + boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); + } + + int32_t swizzle_mode; + if (contig_dim_size_in_bytes >= 128) { + swizzle_mode = 3; + } else if (contig_dim_size_in_bytes == 64) { + swizzle_mode = 2; + } else if (contig_dim_size_in_bytes == 32) { + swizzle_mode = 1; + } else { + op->emitError() + << "contiguous box dimension must be at least 32 bytes but got " + << contig_dim_size_in_bytes; + return failure(); + } + + Value elemSizeVal = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize)); + Value globalStride = + rewriter.create(loc, op.getStrides()[0], elemSizeVal); + // TODO: Workaround for ptxas bug, remove when we update ptxas + Value four = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(4)); + globalStride = rewriter.create(loc, globalStride, four); + + int elemTypeEnum; + switch (elemSize) { + case 1: { + elemTypeEnum = 0; + break; + } + case 2: { + elemTypeEnum = 1; + break; + } + case 4: { + elemTypeEnum = 2; + break; + } + default: { + op->emitError() + << "Tensor descriptor element type must have size 1, 2, or 4 but got " + << elemSize; + return failure(); + } + } + + auto one = mkI32Constant(1); + rewriter.create( + loc, + /*desc_ptr=*/alloc.getResult(), + /*global_address=*/op.getBase(), + /*box_dim=*/boxDim, + /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, + /*global_stride=*/ValueRange{globalStride}, + /*element_strides=*/ValueRange{one, one}, + /*elem_type*/ rewriter.getI32IntegerAttr(elemTypeEnum), + /*interleave_layout*/ rewriter.getI32IntegerAttr(0), + /*swizzle_mode=*/rewriter.getI32IntegerAttr(swizzle_mode), + /*fill_mode=*/rewriter.getI32IntegerAttr(0)); + rewriter.create( + loc, alloc.getResult()); + auto newDesc = rewriter.create( + loc, op.getType(), alloc.getResult()); + rewriter.replaceOp(op, newDesc); + return success(); + } +}; + class TritonNvidiaGPUTMALoweringPass : public TritonNvidiaGPUTMALoweringPassBase< TritonNvidiaGPUTMALoweringPass> { @@ -110,7 +214,8 @@ class TritonNvidiaGPUTMALoweringPass ModuleOp m = getOperation(); mlir::RewritePatternSet patterns(context); - patterns.add(context); + patterns.add( + context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); } diff --git a/lib/Instrumentation/CMakeLists.txt b/lib/Instrumentation/CMakeLists.txt new file mode 100644 index 000000000000..cd437d53b307 --- /dev/null +++ b/lib/Instrumentation/CMakeLists.txt @@ -0,0 +1,40 @@ +set(GPU_INSTRUMENTATION_PASSES + PrintLoadStoreMemSpaces + ) + +set(PrintLoadStoreMemSpaces_SOURCES + PrintLoadStoreMemSpaces.cpp + ) + + +foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) + add_library( + ${plugin} + SHARED + ${${plugin}_SOURCES} + ) + + target_link_libraries( + ${plugin} + PRIVATE + LLVMCore + LLVMSupport + LLVMTransformUtils + "$<$:-undefined dynamic_lookup>" + ) + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + + # This is set to -fvisibility=hidden in the top level CMake file + # which causes the llvmGetPassPluginInfo symbol to be hidden and + # an "entry point not found" error. Reset it just for this target + target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti) +endforeach() diff --git a/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp b/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp new file mode 100644 index 000000000000..c243fc149d5e --- /dev/null +++ b/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp @@ -0,0 +1,101 @@ +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include + +using namespace llvm; + +namespace { + +struct LoadStoreMemSpace : public PassInfoMixin { + PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) { + bool modifiedCodeGen = runOnModule(module); + + return (modifiedCodeGen ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all()); + } + bool runOnModule(llvm::Module &module); + // isRequired being set to true keeps this pass from being skipped + // if it has the optnone LLVM attribute + static bool isRequired() { return true; } +}; + +} // end anonymous namespace + +std::map AddrSpaceMap = { + {0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}}; + +std::map LocationCounterSourceMap; + +std::string LoadOrStoreMap(const BasicBlock::iterator &I) { + if (LoadInst *LI = dyn_cast(I)) + return "LOAD"; + else if (StoreInst *SI = dyn_cast(I)) + return "STORE"; + else + throw std::runtime_error("Error: unknown operation type"); +} +template +void InstrumentationFunction(const BasicBlock::iterator &I, const Function &F, + const llvm::Module &M, uint32_t &LocationCounter) { + auto LSI = dyn_cast(I); + if (not LSI) + return; + Value *Op = LSI->getPointerOperand()->stripPointerCasts(); + uint32_t AddrSpace = cast(Op->getType())->getAddressSpace(); + DILocation *DL = dyn_cast(I)->getDebugLoc(); + + std::string SourceAndAddrSpaceInfo = + (F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) + + ":" + Twine(DL->getColumn())) + .str() + + " " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I); + + if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) == + LocationCounterSourceMap.end()) { + errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n"; + LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter; + LocationCounter++; + } +} + +bool LoadStoreMemSpace::runOnModule(Module &M) { + bool ModifiedCodeGen = false; + uint32_t LocationCounter = 0; + for (auto &F : M) { + if (F.isIntrinsic()) + continue; + StringRef functionName = F.getName(); + if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel || + functionName.contains("kernel")) { + for (Function::iterator BB = F.begin(); BB != F.end(); BB++) { + for (BasicBlock::iterator I = BB->begin(); I != BB->end(); I++) { + if (LoadInst *LI = dyn_cast(I)) { + InstrumentationFunction(I, F, M, LocationCounter); + } else if (StoreInst *SI = dyn_cast(I)) { + InstrumentationFunction(I, F, M, LocationCounter); + } + } + } + } + } + return ModifiedCodeGen; +} + +PassPluginLibraryInfo getPassPluginInfo() { + const auto callback = [](PassBuilder &PB) { + PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM, auto, auto) { + MPM.addPass(LoadStoreMemSpace()); + return true; + }); + }; + + return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING, + callback}; +}; + +extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() { + return getPassPluginInfo(); +} diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 460792439dde..0ab563908a60 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -112,30 +112,6 @@ std::unique_ptr getMatrix(const LinearLayout &layout) { return m; } -// Get a matrix for `layout` with its codomain expanded so it's injective, i.e. -// each input element maps to a unique output element. We do this by finding -// columns that are equal to 0 and adding a new row with a 1 in that column. -std::tuple, int /*numRows*/, int /*numCols*/> -getInjectiveMat(const LinearLayout &layout) { - int numRows = layout.getTotalOutDimSizeLog2(); - int numCols = layout.getTotalInDimSizeLog2(); - std::unique_ptr mat = getMatrix(layout); - - // Bits of mat or-reduced along the columns (so there's just one row). - uint64_t colBits = 0; - for (int r = 0; r < numRows; r++) { - colBits |= mat[r]; - } - auto expanded = std::unique_ptr(new uint64_t[numRows + numCols]); - std::memcpy(expanded.get(), mat.get(), numRows * sizeof(uint64_t)); - for (int c = 0; c < numCols; c++) { - if ((colBits & (1 << c)) == 0) { - expanded[numRows++] = (1 << c); - } - } - return std::make_tuple(std::move(expanded), numRows, numCols); -} - // Compute the rank of the matrix formed by taking the bases for the given // outDim as columns. In other words, finds the number of linearly-independent // bases for this output dimension. @@ -212,42 +188,6 @@ void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { "\nb: " + triton::join(bDims, ", ")); } } - -void eraseEmptyInOutDims(BasesT &bases, - llvm::MapVector &outDims) { - // Erase empty out-dims. - SmallVector emptyOutDims; - for (auto [i, outDim] : llvm::enumerate( - llvm::to_vector_of(llvm::make_first_range(outDims)))) { - if (outDims[outDim] == 1) { - emptyOutDims.push_back(i); - outDims.erase(outDim); - } - } - if (outDims.empty()) { - bases.clear(); - return; - } - - for (auto &[inDim, inDimBases] : bases) { - for (auto &basis : inDimBases) { - // Erase the basis elements corresponding to the empty out-dims. - for (int i : llvm::reverse(emptyOutDims)) { - basis.erase(basis.begin() + i); - } - } - } - - // Erase empty in-dims. - // TODO: This needs a test-case. - for (StringAttr inDim : - llvm::to_vector_of(llvm::make_first_range(bases))) { - if (bases[inDim].empty()) { - bases.erase(inDim); - } - } -} - } // anonymous namespace /*static*/ std::optional @@ -657,153 +597,62 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) { inner.isSurjective() && outer.isSurjective()); } -std::optional -LinearLayout::divideRight(const LinearLayout &divisor) const { - assertCommonDimsSameOrder(getOutDimNames(), divisor.getOutDimNames()); - assertCommonDimsSameOrder(getInDimNames(), divisor.getInDimNames()); - - // Strip off the top N bases for each input dimension of divisor. This - // gives a candidate quotient. Then check if quotient * divisor equals - // `this`. - BasesT newBases = bases; - for (StringAttr inDim : divisor.getInDimNames()) { - if (getInDimSizeLog2(inDim) < divisor.getInDimSizeLog2(inDim)) { - return std::nullopt; - } - auto &newInDimBases = newBases[inDim]; - newInDimBases.resize(newInDimBases.size() - - divisor.getInDimSizeLog2(inDim)); - } - - // Check if the size of the new out-dims are large enough. - // If yes, we can divide the out-dims. - // If no, we return nullopt to indicate that the division is not possible. - llvm::MapVector newOutDims = outDims; - for (const auto [outDimName, outDimSize] : divisor.outDims) { - if (newOutDims[outDimName] < outDimSize) { - return std::nullopt; +bool LinearLayout::isTrivialOver(ArrayRef dimNames) const { + for (StringAttr dim : dimNames) { + if (!llvm::is_contained(getInDimNames(), dim) && + !llvm::is_contained(getOutDimNames(), dim)) { + return false; } - newOutDims[outDimName] /= outDimSize; - } - - LDBG("Checking candidate_quotient * divisor == *this"); - LDBG("this:" << *this); - LDBG("divisor:" << divisor); - LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second.size()); - })); - LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second); - })); - std::optional candidateQuotient = LinearLayout::tryCreate( - std::move(newBases), std::move(newOutDims.takeVector()), - /*requireSurjective=*/false); - LDBG("candidate_quotient:" << candidateQuotient); - if (!candidateQuotient.has_value()) { - LDBG("candidate quotient failed invariant checks"); - return std::nullopt; } - LDBG("*candidate_quotient * divisor=" << *candidateQuotient * divisor); - if (*candidateQuotient * divisor != *this) { - LDBG("candidate quotient failed invariant checks"); - return std::nullopt; - } - - // Now that we have a candidate quotient, we need to eliminate any empty - // dimensions from the candidate quotient but still ensure that - // quotient * divisor == *this. - newBases = candidateQuotient->bases; - newOutDims = candidateQuotient->outDims; - // We only remove the trailing empty output dimensions from `quotient`. - // - // In the multiplication `quotient * divisor == result`, the output dimensions - // of `quotient` always come before those of `divisor` in `result`. Removing - // any non-trailing empty dimensions from `quotient` would change the - // order of the output dimensions in `result`. - // - // The following loop iterates through the output dimensions of `result` from - // right to left. During the iteration, the following conditions are checked: - // - // 1. If an output dimension exists only in `divisor` and not in `quotient`, - // the loop continues. - // 2. If an output dimension exists only in `quotient` and not in `divisor`, - // we stop the loop. - // 3. If an output dimension exists in both `quotient` and `divisor`, it may - // be removed, but only if it is a size-1 dimension and meets one of the - // following conditions: - // - The dimension immediately following it in `quotient` has already been - // removed. - // - It is the last dimension of `quotient`. - // Otherwise, removing this dimension could alter the structure of `result`. - // - // Consider the quotient l = o / r, where: - // out-dims(o) = ["out0", "out1", "out2", "out3"] - // out-dims(r) = ["out1", "out3"] - // - // Only "out1" is a size-1 dimension. If we remove "out1" from o, the - // resulting output dimensions would be: - // out-dims(l) = ["out0", "out2", "out3"] - // - // Performing the multiplication l * r results in: - // out-dims(l * r) = ["out0", "out2", "out3"] * ["out1", "out3"] = ["out0", - // "out2", "out3", "out1"] - // This outcome does not match the original out-dims(o). - // - // However, if we remove only "out3" from o, we get: - // out-dims(l) = ["out0", "out1", "out2"] - // - // Then, performing the multiplication l * r yields: - // out-dims(l * r) = ["out0", "out1", "out2"] * ["out1", "out3"] = ["out0", - // "out1", "out2", "out3"] - // This result matches the original out-dims(o). - llvm::SmallVector emptyOutDimIndices; - for (const auto [outDimName, outDimSize] : llvm::reverse(outDims)) { - if (newOutDims.contains(outDimName) && !divisor.hasOutDim(outDimName)) { - break; - } - if (newOutDims.contains(outDimName) && divisor.hasOutDim(outDimName) && - candidateQuotient->getOutDimSize(outDimName) == 1) { - auto lastOutDimName = newOutDims.rbegin()->first; - if (outDimName != lastOutDimName) { - break; + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); } - emptyOutDimIndices.push_back(getOutDimIndex(outDimName)); - newOutDims.erase(outDimName); } + return remainingDimNames; + }; + SmallVector remainingInDimNames = + getRemainingDimNames(getInDimNames()); + SmallVector remainingOutDimNames = + getRemainingDimNames(getOutDimNames()); + + // Think of this as a block-matrix multiplying a vector: + // [[A, B], * [v_1, + // [C, D]] v_2] + // where v_2 is the dimNames and v_1 is the remainingInDimNames + // We can quotient out dimNames iff they don't affect the remainingInDimNames + // in the result. In other words, we want to check that B is zero, and C is + // zero, and D is the identity + return squareSublayoutIsIdentity(dimNames) && + sublayoutIsZero(remainingInDimNames, dimNames) && + sublayoutIsZero(dimNames, remainingOutDimNames); +} + +std::optional +LinearLayout::quotient(ArrayRef dimNames) const { + if (!isTrivialOver(dimNames)) { + return std::nullopt; } - // Erase the basis elements corresponding to the empty out-dims. - for (auto &[inDim, inDimBases] : newBases) { - for (auto &basis : inDimBases) { - for (int i : emptyOutDimIndices) { - basis.erase(basis.begin() + i); + // This should probably be even less general, where we ask inDimNames == + // outDimNames + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); } } - } + return remainingDimNames; + }; - // Erase trailing empty in-dims. - for (auto inDimName : llvm::reverse(getInDimNames())) { - if (newBases[inDimName].empty() && divisor.hasInDim(inDimName)) { - newBases.erase(inDimName); - } else { - break; - } - } + SmallVector inDimNames = getRemainingDimNames(getInDimNames()); + SmallVector outDimNames = getRemainingDimNames(getOutDimNames()); - LDBG("Eliminated empty dims from candidate_quotient"); - LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second.size()); - })); - LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second); - })); - auto quotient = LinearLayout::tryCreate(std::move(newBases), - std::move(newOutDims).takeVector(), - /*requireSurjective=*/false); - LDBG("quotient:" << quotient); - assert(quotient.has_value()); - return quotient; + return sublayout(inDimNames, outDimNames); } LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, @@ -813,10 +662,10 @@ LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, SmallDenseSet inDimSet(inDimNames.begin(), inDimNames.end()); SmallDenseSet outDimSet(outDimNames.begin(), outDimNames.end()); - SmallDenseSet outDimIndicesToKeep; + SmallVector outDimIndicesToKeep; for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) { if (outDimSet.contains(outDim)) { - outDimIndicesToKeep.insert(i); + outDimIndicesToKeep.push_back(i); } } BasesT newBases; @@ -856,13 +705,22 @@ bool LinearLayout::sublayoutIsZero(ArrayRef inDimNames, return true; } -bool LinearLayout::sublayoutIsIdentity(ArrayRef inDimNames, - ArrayRef outDimNames) const { - LinearLayout sl = - sublayout(inDimNames, outDimNames).flattenIns().flattenOuts(); - if (sl.getNumInDims() == 0 || sl.getNumOutDims() == 0) { +bool LinearLayout::squareSublayoutIsIdentity( + ArrayRef dimNames) const { + // The empty layout is the identity + if (dimNames.size() == 0) { return true; } + // Check that the input-output sizes are the same + LinearLayout sl = sublayout(dimNames, dimNames); + for (StringAttr dim : dimNames) { + if (getInDimSize(dim) != getOutDimSize(dim)) { + return false; + } + } + // Once the inputs and output dimensions are the same, we can just check + // that the basis for the single remaining dimension is the identity. + sl = sl.flattenIns().flattenOuts(); int b = 0; const auto &inDimBases = sl.bases.begin()->second; for (auto basis : inDimBases) { @@ -922,118 +780,179 @@ LinearLayout LinearLayout::compose(const LinearLayout &outer) const { compositionIsSurjective); } -LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { - assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getOutDimNames()); - for (StringAttr outDim : getOutDimNames()) { - assert(getOutDimSize(outDim) <= outer.getOutDimSize(outDim)); +namespace { +std::unique_ptr concatMatrices(const LinearLayout &A, + const LinearLayout &B) { + // In plain words, "convert_layout does not change the shape of a tensor" + assert(A.getTotalOutDimSizeLog2() == B.getTotalOutDimSizeLog2() && + "Matrices must have the same number of output dimensions"); + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + + // rref expects the lower bits to be the lower indices of the matrix + auto concat = getMatrix(A); + auto BMat = getMatrix(B); + for (int r = 0; r < numRows; r++) { + concat[r] |= BMat[r] << numColsA; } - assert(outer.isSurjective()); + return concat; +} - // Make both `this` and `outer` injective. We need to do this on the - // `outer` layout because we can't invert a non-injective function. We - // choose to do so on the `this` layout as well. The rest of the comment - // explains why we make that choice. - // - // Recall from the header that C = A.invertAndCompose(B) just means that - // A(x) = B(C(x)). - // - // Sometimes we may have a choice of multiple values for a particular - // C(x). For example, if A(1) = B(0) = B(1) = 0, then C(1) can be either 0 - // or 1. - // - // We want to choose C such that C(x) != 0 where possible. For example, - // suppose we are transferring from registers to registers and we have the - // following layouts. - // - // A(thread=1, block=0) = 1 - // A(thread=2, block=0) = 2 - // A(thread=0, block=1) = 0 - // - // B(thread=1, block=0) = 2 - // B(thread=2, block=0) = 1 - // B(thread=0, block=1) = 0 - // - // Notice that A and B both have the same data in each of their two - // blocks. So if we want to transfer from A to B, we don't need to cross - // blocks, which is expensive. We want A.invertAndCompose(B) to reflect - // that choice. - // - // Let A' be A with the last line changed to "=4", and similarly for B'. - // When transferring from A' to B', we can't cross blocks even if we wanted - // to, because the two blocks now have different data. But also, any - // mapping of thread+block from A' to B' is also valid for mapping from A - // to B. - // - // Thus making A and B injective encodes our desire not to cross blocks, - // or more generally our desire that C(x) != 0 where possible. - auto [matThis, numRowsThis, numColsThis] = getInjectiveMat(*this); - auto [matOuter, numRowsOuter, numColsOuter] = getInjectiveMat( - outer.transposeOuts(llvm::to_vector(this->getOutDimNames()))); - - // Concatenate `matOuter` and `matThis` horizontally (i.e. `matThis` - // is to the right of `matOuter`). - int combinedNumRows = std::max(numRowsThis, numRowsOuter); - int combinedNumCols = numColsThis + numColsOuter; - assert(combinedNumCols <= 64 && "Can't handle huge layouts"); - - std::unique_ptr m(new uint64_t[combinedNumRows]()); - for (int r = 0; r < numRowsOuter; r++) { - m[r] = matOuter[r]; - } - for (int r = 0; r < numRowsThis; r++) { - m[r] |= matThis[r] << numColsOuter; - } - - // Perform Gaussian elimination on `m`. Because `outer` was modified to - // be bijective, the first half of the matrix should be the identity - // matrix. The remaining half are the bases for the combined - // transformation. - // - // `stride` is specified in number of 64-bit words per row, and we pack - // our matrix so that there's only one uint64_t per row. - f2reduce::inplace_rref_strided(m.get(), combinedNumRows, combinedNumCols, +LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { + // Solve the least square system AX = B for A = outer, B = *this + // and return the least square solution X of minimal norm + // A and B may not be surjective, but we assume that Im(B) \subset Im(A) + // Sketch of the algorithm: + // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + int numColsB = B.getTotalInDimSizeLog2(); + int numCols = numColsA + numColsB; + std::unique_ptr combinedMat = concatMatrices(A, B); + f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, /*stride=*/1); - // Check that the first half of the matrix is indeed the identity. - for (int r = 0; r < std::min(numRowsOuter, numColsOuter); r++) { - for (int c = 0; c < std::min(numColsOuter, numRowsOuter); c++) { - if (((m[r] >> c) & 1) != (r == c ? 1 : 0)) { - llvm::report_fatal_error("First half of the matrix was not the " - "identity, bug in invertAndCompose"); - } + // Compute the pivot columns + // Since A and B have the same image, each row will either have a pivot + // or will be all zeros + SmallVector pivotCols; + for (int r = 0; r < numRows; r++) { + auto row = combinedMat[r]; + if (row == 0) { + continue; } + int c = __builtin_ctzll(row); + assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); + assert(pivotCols.empty() || + pivotCols.back() < c && "Pivot columns are not in increasing order"); + pivotCols.push_back(c); + } + + // Extract A^{-1}B and complete the matrix using zeros + std::unique_ptr retMat(new uint64_t[numColsA]()); + int j = 0; + for (int r = 0; r < numColsA; r++) { + auto isPivot = j < pivotCols.size() && pivotCols[j] == r; + retMat[r] = isPivot ? combinedMat[j++] >> numColsA : 0; } // We need names for the in/out dim of the flattened layout we're going to // read off from `m`. These could be anything, doesn't matter. - StringAttr inDim1D = *getInDimNames().begin(); - StringAttr outDim1D = *getOutDimNames().begin(); + StringAttr inDim1D = *A.getInDimNames().begin(); + StringAttr outDim1D = *A.getOutDimNames().begin(); // Read off the new bases. These are for a flattened 1D -> 1D - // transformation from `this`'s in-dims to `outer`'s in-dims. - BasesT newBases; - auto &bs = newBases[inDim1D]; - for (int c = 0; c < numColsThis; c++) { + LinearLayout::BasesT retBases; + auto &bs = retBases[inDim1D]; + for (int c = 0; c < numColsB; c++) { int32_t basis = 0; - for (int r = 0; r < numRowsOuter; r++) { - basis |= (m[r] >> (numColsOuter + c) & 1) << r; + for (int r = 0; r < numColsA; r++) { + basis |= (retMat[r] >> c & 1) << r; } bs.push_back({basis}); } - LinearLayout flatComposed(std::move(newBases), - {{outDim1D, outer.getTotalInDimSize()}}, + LinearLayout retFlattened(std::move(retBases), + {{outDim1D, A.getTotalInDimSize()}}, /*requireSurjective=*/false); SmallVector> retInDims; SmallVector> retOutDims; - for (StringAttr dim : getInDimNames()) { - retInDims.push_back({dim, getInDimSize(dim)}); + for (StringAttr dim : B.getInDimNames()) { + retInDims.push_back({dim, B.getInDimSize(dim)}); } - for (StringAttr dim : outer.getInDimNames()) { - retOutDims.push_back({dim, outer.getInDimSize(dim)}); + for (StringAttr dim : A.getInDimNames()) { + retOutDims.push_back({dim, A.getInDimSize(dim)}); } - return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); + return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +} + +} // namespace + +LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { + // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` + // For this, we need to implement our LLVM lowerings by inverting the "outer" + // layout, and then iterating over the elements from the "this" layout and + // fetching the corresponding element from the "outer" layout. This exercises + // the broadcasting that we incentivise via choosing the minimum norm solution + // in lstsq. + + // The order of dims does not matter. We choose to transpose outer + auto outDims = llvm::to_vector(getOutDimNames()); + assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); + const auto &B = *this; + const auto A = outer.transposeOuts(outDims); + for (auto dim : outDims) { + assert(A.getOutDimSize(dim) == B.getOutDimSize(dim) && + "Convert layout does not change the shape of a tensor"); + } + + // We'll write A^{-1} to mean the inverse or the pseudo-inverse of A + // We are computing A^{-1}B so A must be surjective so that + // it has a left inverse. + assert(A.isSurjective()); + + // Broadcasting heuristic + // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` + // (broadcasting) on both layouts. We could map any warp to any warp in the + // conversion. Now, we want to map them as the identity map, to mark that + // nothing needs to be done there (`lstsq` would map all the warps to the + // zero warp, minimum norm solution). The heuristic here is as follows: + // - If a dimension is the same for both layouts, we want to map it as the + // identity + // Equivalently, we don't add it to the conversion + // - Otherwise, we just call lstsq (i.e. map all the equivalent elements + // to the same input element) to take advantage of broadcasting in shared + // memory and avoid saving repeated elements in shared memory + SmallVector identityDims; + for (auto dim : A.getInDimNames()) { + if (B.hasInDim(dim) && + A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { + identityDims.push_back(dim); + } + } + SmallVector ANonIdentityInDims; + SmallVector BNonIdentityInDims; + for (auto dim : A.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + ANonIdentityInDims.push_back(dim); + } + } + for (auto dim : B.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + BNonIdentityInDims.push_back(dim); + } + } + + auto AReduced = A.sublayout(ANonIdentityInDims, outDims); + auto BReduced = B.sublayout(BNonIdentityInDims, outDims); + + // If one is empty, the other must be empty as well + assert((AReduced == LinearLayout::empty()) == + (BReduced == LinearLayout::empty())); + bool isEmpty = AReduced == LinearLayout::empty(); + + auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); + + // TODO(Lezcano): We should return the reduced layout instead of re-adding the + // identity maps. With this, we'll be able to kill `minimalCvtLayout` + + // Add the identity maps for the dimensions that are the same for both layouts + for (auto dim : identityDims) { + ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); + } + + // Reshape the result + SmallVector> inDimsA; + SmallVector> inDimsB; + for (auto dim : A.getInDimNames()) { + inDimsA.push_back({dim, A.getInDimSize(dim)}); + } + for (auto dim : B.getInDimNames()) { + inDimsB.push_back({dim, B.getInDimSize(dim)}); + } + ret = ret.reshapeIns(inDimsB).reshapeOuts(inDimsA); + return ret; } llvm::MapVector @@ -1071,6 +990,30 @@ LinearLayout::getFreeVariableMasks() const { return ret; } +size_t hash_value(const LinearLayout &layout) { + size_t seed = 0; + + // Hash the bases + for (const auto &base : layout.getBases()) { + // Hash the input dimension name + seed = llvm::hash_combine(seed, base.first); + + // Hash the vectors in bases + for (const auto &vec : base.second) { + for (int32_t val : vec) { + seed = llvm::hash_combine(seed, val); + } + } + } + + // Hash the output dimensions and their sizes + for (const auto &outDim : layout.getOutDimNames()) { + seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); + } + // Don't hash the surjective flag as it's a cached property + return seed; +} + bool operator==(LinearLayout lhs, LinearLayout rhs) { if (!lhs.equalIgnoringOutDimSizes(rhs)) return false; diff --git a/prepare.sh b/prepare.sh new file mode 100755 index 000000000000..e9a269a795b2 --- /dev/null +++ b/prepare.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# saving path +HERE=$PWD + +echo "===================================== Configuring Checkout" +git submodule init +git submodule update + +echo "===================================== Setting Up Conda Env" +export CONDA_INSTALL_DIR=$HERE/../miniforge +export ENV_NAME=triton +if [ ! -d $CONDA_INSTALL_DIR ]; then + pushd ./../ + wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-$(uname -m).sh" + bash ./Miniforge3-Linux-$(uname -m).sh -b -p ${CONDA_INSTALL_DIR} + ${CONDA_INSTALL_DIR}/bin/conda create -y -n ${ENV_NAME} python=3.9 + source ${CONDA_INSTALL_DIR}/bin/activate ${ENV_NAME} + + echo "===================================== Install Dependencies" + pip install ninja cmake wheel pybind11 scipy numpy torch pytest lit pandas matplotlib + if [ $? != 0 ]; then + exit 1 + fi + + popd + conda deactivate +else + echo "Miniconda already installed, skipping." +fi + +echo "===================================== Building LLVM" +if [ ! -d $HERE/../llvm-project ]; then + pushd ./../ + git clone https://github.com/llvm/llvm-project.git + cd llvm-project + git checkout `cat ${HERE}/cmake/llvm-hash.txt` + mkdir -p build + pushd build + cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=True -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_USE_LINKER=lld -DLLVM_ENABLE_PROJECTS="mlir;llvm" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" ../llvm + ninja + popd + popd +else + echo "LLVM already built, skipping." +fi diff --git a/python/setup.py b/python/setup.py index 714668462f0e..5f907472770c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,7 +14,7 @@ from io import BytesIO from distutils.command.clean import clean from pathlib import Path -from typing import List, NamedTuple, Optional +from typing import List, Optional from setuptools import Extension, setup from setuptools.command.build_ext import build_ext @@ -148,13 +148,15 @@ def is_offline_build() -> bool: # --- third party packages ----- -class Package(NamedTuple): +@dataclass +class Package: package: str name: str url: str include_flag: str lib_flag: str syspath_var_name: str + sym_name: Optional[str] = None # json @@ -207,8 +209,10 @@ def get_llvm_package_info(): with open(llvm_hash_path, "r") as llvm_hash_file: rev = llvm_hash_file.read(8) name = f"llvm-{rev}-{system_suffix}" + # Create a stable symlink that doesn't include revision + sym_name = f"llvm-{system_suffix}" url = f"https://oaitriton.blob.core.windows.net/public/llvm-builds/{name}.tar.gz" - return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") + return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH", sym_name=sym_name) def open_url(url): @@ -233,6 +237,20 @@ def get_triton_cache_path(): return os.path.join(user_home, ".triton") +def update_symlink(link_path, source_path): + source_path = Path(source_path) + link_path = Path(link_path) + + if link_path.is_symlink(): + link_path.unlink() + elif link_path.exists(): + shutil.rmtree(link_path) + + print(f"creating symlink: {link_path} -> {source_path}", file=sys.stderr) + link_path.absolute().parent.mkdir(parents=True, exist_ok=True) # Ensure link's parent directory exists + link_path.symlink_to(source_path, target_is_directory=True) + + def get_thirdparty_packages(packages: list): triton_cache_path = get_triton_cache_path() thirdparty_cmake_args = [] @@ -269,6 +287,10 @@ def get_thirdparty_packages(packages: list): thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include") if p.lib_flag: thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib") + if p.sym_name is not None: + sym_link_path = os.path.join(package_root_dir, p.sym_name) + update_symlink(sym_link_path, package_dir) + return thirdparty_cmake_args @@ -379,7 +401,7 @@ def get_pybind11_cmake_args(self): pybind11_include_dir = os.path.join(pybind11_sys_path, "include") else: pybind11_include_dir = pybind11.get_include() - return [f"-DPYBIND11_INCLUDE_DIR={pybind11_include_dir}"] + return [f"-Dpybind11_INCLUDE_DIR='{pybind11_include_dir}'", f"-Dpybind11_DIR='{pybind11.get_cmake_dir()}'"] def get_proton_cmake_args(self): cmake_args = get_thirdparty_packages([get_json_package_info()]) @@ -417,7 +439,7 @@ def build_extension(self, ext): "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, - "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DPython3_INCLUDE_DIR=" + python_include_dir, "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) ] @@ -429,12 +451,10 @@ def build_extension(self, ext): cfg = get_build_type() build_args = ["--config", cfg] + cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"] if platform.system() == "Windows": cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - if sys.maxsize > 2**32: - cmake_args += ["-A", "x64"] else: - cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) build_args += ['-j' + max_jobs] @@ -462,15 +482,17 @@ def build_extension(self, ext): "-DCMAKE_CXX_FLAGS=-fsanitize=address", ] - if check_env_flag("TRITON_BUILD_WITH_CCACHE"): - cmake_args += [ - "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", - ] + # environment variables we will pass through to cmake + passthrough_args = [ + "TRITON_BUILD_PROTON", + "TRITON_BUILD_TUTORIALS", + "TRITON_BUILD_WITH_CCACHE", + "TRITON_PARALLEL_LINK_JOBS", + ] + cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ] if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON cmake_args += self.get_proton_cmake_args() - else: - cmake_args += ["-DTRITON_BUILD_PROTON=OFF"] if is_offline_build(): # unit test builds fetch googletests from GitHub @@ -499,8 +521,9 @@ def get_platform_dependent_src_path(subdir): if int(version_major) >= 12 and int(version_minor1) >= 5 else subdir)(*version.split('.'))) +exe_extension = sysconfig.get_config_var("EXE") download_and_copy( - name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", + name="ptxas", src_path=f"bin/ptxas{exe_extension}", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2" @@ -509,7 +532,7 @@ def get_platform_dependent_src_path(subdir): (*version.split('.')))) download_and_copy( name="cuobjdump", - src_path="bin/cuobjdump", + src_path=f"bin/cuobjdump{exe_extension}", dst_path="bin/cuobjdump", variable="TRITON_CUOBJDUMP_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"], @@ -518,7 +541,7 @@ def get_platform_dependent_src_path(subdir): ) download_and_copy( name="nvdisasm", - src_path="bin/nvdisasm", + src_path=f"bin/nvdisasm{exe_extension}", dst_path="bin/nvdisasm", variable="TRITON_NVDISASM_PATH", version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"], @@ -559,16 +582,12 @@ def get_platform_dependent_src_path(subdir): f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2") (*version.split('.')))) -backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] +backends = [*BackendInstaller.copy(["nvidia", "amd", "cpu"]), *BackendInstaller.copy_externals()] def add_link_to_backends(): for backend in backends: - if os.path.islink(backend.install_dir): - os.unlink(backend.install_dir) - if os.path.exists(backend.install_dir): - shutil.rmtree(backend.install_dir) - os.symlink(backend.backend_dir, backend.install_dir) + update_symlink(backend.install_dir, backend.backend_dir) if backend.language_dir: # Link the contents of each backend's `language` directory into @@ -577,21 +596,13 @@ def add_link_to_backends(): for x in os.listdir(backend.language_dir): src_dir = os.path.join(backend.language_dir, x) install_dir = os.path.join(extra_dir, x) - if os.path.islink(install_dir): - os.unlink(install_dir) - if os.path.exists(install_dir): - shutil.rmtree(install_dir) - os.symlink(src_dir, install_dir) + update_symlink(install_dir, src_dir) def add_link_to_proton(): proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton")) proton_install_dir = os.path.join(os.path.dirname(__file__), "triton", "profiler") - if os.path.islink(proton_install_dir): - os.unlink(proton_install_dir) - if os.path.exists(proton_install_dir): - shutil.rmtree(proton_install_dir) - os.symlink(proton_dir, proton_install_dir) + update_symlink(proton_install_dir, proton_dir) def add_links(): @@ -663,6 +674,7 @@ def get_packages(): "triton/compiler", "triton/language", "triton/language/extra", + "triton/language/extra/cpu", "triton/runtime", "triton/backends", "triton/tools", @@ -695,11 +707,12 @@ def get_git_commit_hash(length=8): setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), - version="3.0.0" + get_git_commit_hash() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), + version="3.2.0" + get_git_commit_hash() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), author="Philippe Tillet", author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", long_description="", + install_requires=["setuptools>=40.8.0"], packages=get_packages(), entry_points=get_entry_points(), package_data=package_data, @@ -739,10 +752,12 @@ def get_git_commit_hash(length=8): "autopep8", "flake8", "isort", - "numpy", + "numpy<2.0.0", "pytest", + "pytest-forked", + "pytest-xdist", "scipy>=1.7.1", - "llnl-hatchet", + # "llnl-hatchet", # TODO: Re-enable this, not available on macos-arm64 ], "tutorials": [ "matplotlib", diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc index 6ab7c6c75c70..747a0cc17191 100644 --- a/python/src/interpreter.cc +++ b/python/src/interpreter.cc @@ -1,29 +1,44 @@ +#include #include #include #include +#include #include #include +#include #include namespace py = pybind11; namespace { +struct npy_half { + uint16_t value; +}; + enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; +std::mutex atomic_op_guard; + +template +constexpr bool is_reinterpret_cast_to_atomic_safe = + std::is_trivially_copyable_v && + std::is_trivially_copyable_v> && + std::is_standard_layout_v && std::is_standard_layout_v> && + sizeof(T) == sizeof(std::atomic) && + alignof(T) == alignof(std::atomic); + enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; -std::map mem_semantic_map = { - {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, - {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, - {MemSemantic::RELEASE, __ATOMIC_RELEASE}, - {MemSemantic::RELAXED, __ATOMIC_RELAXED}, +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel}, + {MemSemantic::ACQUIRE, std::memory_order_acquire}, + {MemSemantic::RELEASE, std::memory_order_release}, + {MemSemantic::RELAXED, std::memory_order_relaxed}, }; -// Use compiler builtin atomics instead of std::atomic which requires -// each variable to be declared as atomic. -// Currently work for clang and gcc. -template T atomic_cmp(T *ptr, T val, int order) { +template +T atomic_cmp(T *ptr, T val, std::memory_order order) { auto cmp = [](T old, T val) { if constexpr (is_min) { return old > val; @@ -31,43 +46,256 @@ template T atomic_cmp(T *ptr, T val, int order) { return old < val; } }; - // First load - T old_val = __atomic_load_n(ptr, order); - while (cmp(old_val, val)) { - if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { - break; + + T old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_ptr = reinterpret_cast *>(ptr); + old_val = atomic_ptr->load(order); + while (cmp(old_val, val)) { + if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) { + break; + } + } + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *ptr; + if (cmp(old_val, val)) { + *ptr = val; } } return old_val; } -template T atomic_fadd(T *ptr, T val, int order) { - T old_val; - T new_val; - // First load - // Load ptr as if uint32_t or uint64_t and then memcpy to T - if constexpr (sizeof(T) == 4) { - uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); - std::memcpy(&old_val, &tmp, sizeof(T)); - } else if constexpr (sizeof(T) == 8) { - uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); - std::memcpy(&old_val, &tmp, sizeof(T)); +template T atomic_fadd(T *loc, T value, std::memory_order order) { + static_assert(std::is_floating_point::value, + "T must be a floating-point type"); + T old_value; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + T new_value; + std::atomic *atomic_loc = reinterpret_cast *>(loc); + old_value = atomic_loc->load(order); + do { + new_value = old_value + value; + } while ( + !atomic_loc->compare_exchange_weak(old_value, new_value, order, order)); } else { - throw std::invalid_argument("Unsupported data type"); + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = old_value + value; + } + + return old_value; +} + +/** Create a value of type `To` from the bits of `from`. + * + * similar to `std::bit_cast` but compatible with C++17, + * should perform similar to `*reinterpret_cast(&from)` + * or through punning without expecting any undefined behaviors. + * + * Note: taken from + * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32 + * with simplification. + */ +template +inline To BitCast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), + "both data types must have the same size"); + + static_assert(std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + "both data types must be trivially copyable"); + + To to; + memcpy(&to, &from, sizeof(from)); + return to; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14 +template +inline uint16_t FromFloatBits(uint32_t f) { + uint32_t f_exp, f_sig; + uint16_t h_sgn, h_exp, h_sig; + + h_sgn = (uint16_t)((f & 0x80000000u) >> 16); + f_exp = (f & 0x7f800000u); + + /* Exponent overflow/NaN converts to signed inf/NaN */ + if (f_exp >= 0x47800000u) { + if (f_exp == 0x7f800000u) { + /* Inf or NaN */ + f_sig = (f & 0x007fffffu); + if (f_sig != 0) { + /* NaN - propagate the flag in the significand... */ + uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13)); + /* ...but make sure it stays a NaN */ + if (ret == 0x7c00u) { + ret++; + } + return h_sgn + ret; + } else { + /* signed inf */ + return (uint16_t)(h_sgn + 0x7c00u); + } + } else { + if constexpr (gen_overflow) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error("overflow to signed inf"); + } + return (uint16_t)(h_sgn + 0x7c00u); + } } - while (true) { - new_val = old_val + val; - if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, - order)) { - break; + + /* Exponent underflow converts to a subnormal half or signed zero */ + if (f_exp <= 0x38000000u) { + /* + * Signed zeros, subnormal floats, and floats with small + * exponents all convert to signed zero half-floats. + */ + if (f_exp < 0x33000000u) { + if constexpr (gen_underflow) { + /* If f != 0, it underflowed to 0 */ + if ((f & 0x7fffffff) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + return h_sgn; + } + /* Make the subnormal significand */ + f_exp >>= 23; + f_sig = (0x00800000u + (f & 0x007fffffu)); + if constexpr (gen_underflow) { + /* If it's not exactly represented, it underflowed */ + if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } } + /* + * Usually the significand is shifted by 13. For subnormals an + * additional shift needs to occur. This shift is one for the largest + * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which + * offsets the new first bit. At most the shift can be 1+10 bits. + */ + f_sig >>= (113 - f_exp); + /* Handle rounding by adding 1 to the bit beyond half precision */ + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. However, the (113 - f_exp) + * shift can lose up to 11 bits, so the || checks them in the original. + * In all other cases, we can just add one. + */ + if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp from zero to one and h_sig will be zero. + * This is the correct result. + */ + return (uint16_t)(h_sgn + h_sig); + } + + /* Regular case with no overflow or underflow */ + h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13); + /* Handle rounding by adding 1 to the bit beyond half precision */ + f_sig = (f & 0x007fffffu); + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. In all other cases, we do. + */ + if ((f_sig & 0x00003fffu) != 0x00001000u) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp by one and h_sig will be zero. This is the + * correct result. h_exp may increment to 15, at greatest, in + * which case the result overflows to a signed inf. + */ + if constexpr (gen_overflow) { + h_sig += h_exp; + if (h_sig == 0x7c00u) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error(""); + } + return h_sgn + h_sig; + } else { + return h_sgn + h_exp + h_sig; } - return old_val; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269 +constexpr uint32_t ToFloatBits(uint16_t h) { + uint16_t h_exp = (h & 0x7c00u); + uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: { // 0 or subnormal + uint16_t h_sig = (h & 0x03ffu); + // Signed zero + if (h_sig == 0) { + return f_sgn; + } + // Subnormal + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + } + case 0x7c00u: // inf or NaN + // All-ones exponent and a copy of the significand + return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13); + default: // normalized + // Just need to adjust the exponent and shift + return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13); + } +} + +npy_half npy_float_to_half(float f) { + return {FromFloatBits(BitCast(f))}; +} + +float npy_half_to_float(npy_half h) { + return BitCast(ToFloatBits(h.value)); +} + +template <> +npy_half atomic_fadd(npy_half *loc, npy_half value, + std::memory_order order) { + npy_half old_value; + + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = npy_float_to_half(npy_half_to_float(old_value) + + npy_half_to_float(value)); + + return old_value; } class AtomicOp { public: - AtomicOp(const uint64_t *ptr, size_t numel, int order) + AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order) : ptr(ptr), numel(numel), order(order) {} void apply() { @@ -83,25 +311,26 @@ class AtomicOp { const uint64_t *ptr; size_t numel; - int order; + std::memory_order order; }; template class AtomicRMWOpBase : public AtomicOp { public: AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, - const bool *mask, size_t numel, int order) + const bool *mask, size_t numel, std::memory_order order) : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} protected: void applyAt(void *loc, size_t i) override final { if (mask[i]) { + DType *ptr = static_cast(loc); *(static_cast(ret) + i) = - applyAtMasked(static_cast(loc), - *(static_cast(val) + i), order); + applyAtMasked(ptr, *(static_cast(val) + i), order); } } - virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + virtual DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) = 0; const void *val; void *ret; @@ -121,8 +350,19 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_add(loc, value, order); + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc + value; + } + return old_val; } }; @@ -133,7 +373,8 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { return atomic_fadd(loc, value, order); } }; @@ -145,8 +386,19 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_and(loc, value, order); + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc & value; + } + return old_val; } }; @@ -157,8 +409,19 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_or(loc, value, order); + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc | value; + } + return old_val; } }; @@ -169,8 +432,19 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_xor(loc, value, order); + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc ^ value; + } + return old_val; } }; @@ -182,7 +456,8 @@ class AtomicRMWOp::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { return atomic_cmp(loc, value, order); } }; @@ -195,7 +470,8 @@ class AtomicRMWOp::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { return atomic_cmp(loc, value, order); } }; @@ -207,15 +483,48 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_exchange_n(loc, value, order); + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = atomic_loc->exchange(value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = value; + } + return old_val; } }; +template +void atomic_compare_exchange_strong(void *loc, void *expected, + const void *desired, size_t i, + std::memory_order order) { + T desired_val = *(static_cast(desired) + i); + T *expected_uint = static_cast(expected) + i; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); + atomic_loc->compare_exchange_strong(*expected_uint, desired_val, order, + order); + } else { + const std::lock_guard lock(atomic_op_guard); + T *atomic_loc = static_cast(loc); + if (*atomic_loc == *expected_uint) { + *atomic_loc = desired_val; + } else { + *expected_uint = *atomic_loc; + } + } +} + class AtomicCASOp : public AtomicOp { public: AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, - size_t itemsize, size_t numel, int order) + size_t itemsize, size_t numel, std::memory_order order) : AtomicOp(ptr, numel, order), expected(expected), desired(desired), itemsize(itemsize) {} @@ -224,31 +533,17 @@ class AtomicCASOp : public AtomicOp { // Atomic operations perform bitwise comparison, so it's safe to // use number of bytes (itemsize) to determine the type of pointers if (itemsize == 1) { - uint8_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + atomic_compare_exchange_strong(loc, expected, desired, i, order); } else if (itemsize == 2) { - uint16_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + atomic_compare_exchange_strong(loc, expected, desired, i, + order); } else if (itemsize == 4) { - uint32_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + atomic_compare_exchange_strong(loc, expected, desired, i, + order); } else if (itemsize == 8) { - uint64_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + atomic_compare_exchange_strong(loc, expected, desired, i, + order); } else { - // The ‘__atomic’ builtins can be used with any integral scalar or pointer - // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are - // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the - // architecture. - // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html throw std::invalid_argument("Invalid byte size"); } } @@ -274,7 +569,7 @@ template struct OpCreator { void *ret; const bool *mask; size_t numel; - int order; + std::memory_order order; std::unique_ptr &atomic_op; template void create() { @@ -285,10 +580,20 @@ template struct OpCreator { } }; +template <> template <> void OpCreator::create() { + if (!atomic_op && dtype.char_() == 'e') { // float16 + // workaround until https://github.com/pybind/pybind11/issues/4061 is + // implemented + atomic_op = std::make_unique>( + ptr, val, ret, mask, numel, order); + } +}; + template std::unique_ptr makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, - void *ret, const bool *mask, size_t numel, int order) { + void *ret, const bool *mask, size_t numel, + std::memory_order order) { // Iterate over all supported data types, make one that matches, and return std::unique_ptr atomic_op; OpCreator try_make_op{dtype, ptr, val, ret, @@ -366,7 +671,7 @@ void init_triton_interpreter(py::module &&m) { m.def("atomic_rmw", [](RMWOp rmw_op, py::array_t ptr, py::array val, py::array_t mask, MemSemantic sem) -> py::array { - int order = mem_semantic_map[sem]; + std::memory_order order = mem_semantic_map[sem]; int numel = ptr.size(); auto shape = std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); @@ -390,7 +695,7 @@ void init_triton_interpreter(py::module &&m) { switch (rmw_op) { MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) - MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double) MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) @@ -413,7 +718,7 @@ void init_triton_interpreter(py::module &&m) { m.def("atomic_cas", [](py::array_t ptr, py::array &cmp, py::array &val, MemSemantic sem) -> py::array { - int order = mem_semantic_map[sem]; + std::memory_order order = mem_semantic_map[sem]; int numel = ptr.size(); auto shape = std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); diff --git a/python/src/ir.cc b/python/src/ir.cc index 9945c6188294..3fb88359b105 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -6,7 +6,6 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/UB/IR/UBOps.h" @@ -24,6 +23,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/LocationSnapshot.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -46,6 +46,7 @@ class TritonOpBuilder { } OpBuilder &getBuilder() { return *builder; } + MLIRContext *getContext() { return builder->getContext(); } bool isLineInfoEnabled() { return lineInfoEnabled; } @@ -205,12 +206,13 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::enum_(m, "F8F6F4TY", py::module_local()) - .value("E4M3", F8F6F4Type::E4M3) - .value("E5M2", F8F6F4Type::E5M2) - .value("E2M3", F8F6F4Type::E2M3) - .value("E3M2", F8F6F4Type::E3M2) - .value("E2M1", F8F6F4Type::E2M1) + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) .export_values(); py::class_(m, "context", py::module_local()) @@ -231,10 +233,9 @@ void init_triton_ir(py::module &&m) { m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; registry.insert(); + math::MathDialect, arith::ArithDialect, scf::SCFDialect, + ::mlir::gpu::GPUDialect, cf::ControlFlowDialect, + LLVM::LLVMDialect, mlir::ub::UBDialect>(); mlir::LLVM::registerInlinerInterface(registry); registerBuiltinDialectTranslation(registry); registerLLVMDialectTranslation(registry); @@ -247,6 +248,16 @@ void init_triton_ir(py::module &&m) { .def("is_integer", [](Type &self, unsigned width) { return self.isInteger(width); }) .def("is_fp16", &Type::isF16) + .def("__eq__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty != nullptr) && (*other_ty == self); + }) + .def("__ne__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty == nullptr) || (*other_ty != self); + }) .def("__str__", [](Type &self) { std::string str; llvm::raw_string_ostream os(str); @@ -490,6 +501,16 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, FuncOp &funcOp) -> void { self.push_back(funcOp); }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (LLVM::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) .def("has_function", [](ModuleOp &self, std::string &funcName) -> bool { if (self.lookupSymbol(funcName)) @@ -500,6 +521,43 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, std::string &funcName) -> FuncOp { return self.lookupSymbol(funcName); }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) .def("get_int_attr", [](ModuleOp &self, std::string name) -> py::object { auto ret = self->getAttrOfType(name); @@ -1259,19 +1317,26 @@ void init_triton_ir(py::module &&m) { self.create(ptrs, val, mask, cacheModifier, evictionPolicy); }) + .def("create_reinterpret_tensor_descriptor", + [](TritonOpBuilder &self, Value desc_ptr, Type blockTy) -> Value { + auto ctx = self.getContext(); + auto resultTy = triton::TensorDescType::get( + ctx, cast(blockTy)); + return self.create(resultTy, desc_ptr); + }) .def("create_descriptor_load", - [](TritonOpBuilder &self, Value desc_ptr, - std::vector &indices, Type type, + [](TritonOpBuilder &self, Value desc, std::vector &indices, CacheModifier cacheModifier, EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getBlockType(); return self.create( - type, desc_ptr, indices, cacheModifier, evictionPolicy); + resTy, desc, indices, cacheModifier, evictionPolicy); }) .def("create_descriptor_store", - [](TritonOpBuilder &self, Value desc_ptr, Value value, + [](TritonOpBuilder &self, Value desc, Value value, std::vector &indices) -> void { - self.create(desc_ptr, value, - indices); + self.create(desc, value, indices); }) .def("create_tensormap_create", [](TritonOpBuilder &self, Value desc_ptr, Value global_address, @@ -1422,12 +1487,13 @@ void init_triton_ir(py::module &&m) { maxNumImpreciseAcc); }) .def("create_dot_scaled", - [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, - F8F6F4Type lhs_format, mlir::Value &rhs, - std::optional &rhs_scale, F8F6F4Type rhs_format, - mlir::Value &c) -> mlir::Value { + [](TritonOpBuilder &self, mlir::Value &lhs, + std::optional &lhs_scale, + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value { return self.create( - c.getType(), lhs, rhs, c, lhs_scale, + c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()), rhs_scale.value_or(Value()), lhs_format, rhs_format); }) .def("create_floor", @@ -1446,22 +1512,74 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_expm1", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_cos", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_cosh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_sin", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_sinh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tan", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_acos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_acosh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_asin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_asinh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_atan", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_atanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_log", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_log1p", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_log2", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_log10", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_erf", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1482,6 +1600,14 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_cbrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_trunc", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_reduce", [](TritonOpBuilder &self, std::vector operands, int axis) -> OpState { return self.create(operands, axis); }) @@ -1559,6 +1685,9 @@ void init_triton_ir(py::module &&m) { IntegerType::get(operand.getContext(), 32)), operand); }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) // Force GPU barrier .def("create_barrier", [](TritonOpBuilder &self) { self.create(); }) @@ -1576,6 +1705,14 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &ptr, std::vector &offsets) -> Value { return self.create(ptr.getType(), ptr, offsets); + }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, + std::vector &tensorShape) -> Value { + return self.create(base, shape, strides, + tensorShape); }); py::class_(m, "pass_manager", py::module_local()) @@ -1639,7 +1776,14 @@ void init_triton_ir(py::module &&m) { auto anchorName = self.getOpAnchorName(); auto passes = self.getPasses(); Operation *op = mod.getOperation(); + // Save a reproducer for the current pass manager invocation + // immediately. makeReproducer(anchorName, passes, op, reproducerPath); + // But if the pass manager crashes, attempt to generate a local + // reproducer instead. + mod.getContext()->disableMultithreading(); + self.enableCrashReproducerGeneration(reproducerPath, + /*genLocalReproducer=*/true); } if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { @@ -1663,6 +1807,8 @@ void init_triton_ir(py::module &&m) { }); ::llvm::DebugFlag = true; + // For release build setCurrentDebugTypes is a macro, so avoid + // namespace prefix using namespace llvm; setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); } @@ -1672,6 +1818,25 @@ void init_triton_ir(py::module &&m) { self.enableTiming(); } + // Run the pass manager under a source manager diagnostic handler, which + // enables emitted MLIR diagnostics to directly reference Python source + // code. This diagnostic handler will only filter for errors. + struct SourceMgrErrorDiagnosticHandler + : public SourceMgrDiagnosticHandler { + SourceMgrErrorDiagnosticHandler(MLIRContext *ctx) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this](Diagnostic &diag) { + if (diag.getSeverity() != DiagnosticSeverity::Error) + return failure(); + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; + }; + SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext()); + if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); }); diff --git a/python/src/llvm.cc b/python/src/llvm.cc index f9b98a2540a2..e4be9846bcc4 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -1,8 +1,10 @@ -#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -21,6 +23,7 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include @@ -40,19 +43,46 @@ struct BreakStructPhiNodesPass : PassInfoMixin { using namespace llvm; +std::string getDefaultTargerOrProcessTriple() { + // Return process triple iff the default target triple is empty. + std::string triple = llvm::sys::getDefaultTargetTriple(); + if (triple.empty()) { + // host + triple = llvm::sys::getProcessTriple(); + } + return triple; +} + std::unique_ptr createTargetMachine(llvm::Module *module, std::string proc, - bool enable_fp_fusion, const std::string &features) { + bool enable_fp_fusion, const std::string &features, + bool enable_fast_math = false) { + auto triple = getDefaultTargerOrProcessTriple(); + module->setTargetTriple(triple); std::string error; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } llvm::TargetOptions opt; bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); if (enable_fp_fusion) opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; + + if (enable_fast_math) { + opt.UnsafeFPMath = true; + opt.NoInfsFPMath = true; + opt.NoNaNsFPMath = true; + opt.NoTrappingFPMath = true; + opt.NoSignedZerosFPMath = true; + opt.ApproxFuncFPMath = true; + } else { + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + } + opt.TrapUnreachable = true; opt.MCOptions.AsmVerbose = true; opt.MCOptions.PreserveAsmComments = true; @@ -64,12 +94,10 @@ createTargetMachine(llvm::Module *module, std::string proc, return machine; } -std::string translateLLVMIRToASM(llvm::Module &module, - const std::string &triple, - const std::string &proc, - const std::string &features, - const std::vector &flags, - bool enable_fp_fusion, bool isObject) { +std::string translateLLVMIRToASM( + llvm::Module &module, const std::string &triple, const std::string &proc, + const std::string &features, const std::vector &flags, + bool enable_fp_fusion, bool isObject, bool enable_fast_math = false) { using namespace mlir; // options auto options = llvm::cl::getRegisteredOptions(); @@ -131,7 +159,8 @@ std::string translateLLVMIRToASM(llvm::Module &module, // create machine module.setTargetTriple(triple); - auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features, + enable_fast_math); // set data layout module.setDataLayout(machine->createDataLayout()); // emit machine code @@ -139,8 +168,6 @@ std::string translateLLVMIRToASM(llvm::Module &module, { llvm::raw_string_ostream stream(result); llvm::buffer_ostream pstream(stream); - for (llvm::Function &f : module.functions()) - f.addFnAttr(llvm::Attribute::AlwaysInline); llvm::legacy::PassManager pass; // emit auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile @@ -390,6 +417,76 @@ void init_triton_llvm(py::module &&m) { py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false); + m.def("set_host_target", [](llvm::Module *mod) { + auto triple = getDefaultTargerOrProcessTriple(); + mod->setTargetTriple(triple); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + std::unique_ptr machine{target->createTargetMachine( + mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, + llvm::Reloc::PIC_)}; + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "translate_to_host_asm", + [](std::string llvmIR, bool enable_fp_fusion, + bool enable_fast_math) -> py::object { + std::string res; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + auto triple = getDefaultTargerOrProcessTriple(); + res = translateLLVMIRToASM(*module, triple, + llvm::sys::getHostCPUName().str(), "", {}, + enable_fp_fusion, false, enable_fast_math); + } + return py::str(res); + }, + ret::take_ownership); + + m.def( + "translate_to_bc", + [](const std::string llvmIR) -> py::object { + py::gil_scoped_release allow_threads; + // create LLVM module + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + // Write bitcode to a buffer. + llvm::SmallVector buf; + llvm::BitcodeWriter writer(buf); + writer.writeModule(*module); + writer.writeStrtab(); + std::string bitcode(buf.begin(), buf.end()); + return py::bytes(bitcode); + }, + ret::take_ownership); + m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, @@ -470,6 +567,39 @@ void init_triton_llvm(py::module &&m) { } } }); + + m.def("get_cpu_tripple", []() { return llvm::sys::getProcessTriple(); }); + + m.def("get_cpu_name", []() { return llvm::sys::getHostCPUName().str(); }); + + m.def("get_cpu_features", []() { + auto features = llvm::sys::getHostCPUFeatures(); + + std::set res; + for (auto &f : features) { + if (f.second) + res.insert(f.first().str()); + } + + // Likely something went wrong with the LLVM feature detection. + if (!res.size()) { + std::string triple = llvm::sys::getProcessTriple(); + // e.g. arm64-apple-darwin24.1.0 + // ^^^^^ + std::size_t pos = triple.find('-'); + if (pos == std::string::npos) { + return res; + } + + std::string arch = triple.substr(0, pos); + if (arch == "aarch64" || arch == "arm64") { + // Safe because NEON is a mandatory feature for aarch64. + res.insert("neon"); // For math tests + } + } + + return res; + }); } void triton_stacktrace_signal_handler(void *) { diff --git a/python/src/passes.cc b/python/src/passes.cc index 98d8369d40aa..c4ff7bb47117 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -64,18 +64,27 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUReduceDataDuplication); ADD_PASS_WRAPPER_0("add_allocate_shared_memory", createAllocateSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_allocate_global_scratch_memory", + createTritonGPUGlobalScratchAllocationPass); ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", createTritonGPUCombineTensorSelectAndIf); ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling", + createTritonGPULoopScheduling, int); + ADD_PASS_WRAPPER_0("add_coalesce_async_copy", + createTritonGPUCoalesceAsyncCopy); } +void init_triton_passes_ttcpuir(py::module &&m) {} + void init_triton_passes_convert(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { @@ -88,6 +97,7 @@ void init_triton_passes(py::module &&m) { init_triton_passes_common(m.def_submodule("common")); init_triton_passes_convert(m.def_submodule("convert")); init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttcpuir(m.def_submodule("ttcpuir")); init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); init_triton_passes_llvmir(m.def_submodule("llvmir")); } diff --git a/python/src/passes.h b/python/src/passes.h index 46801d802c75..629fe362d8b2 100644 --- a/python/src/passes.h +++ b/python/src/passes.h @@ -34,7 +34,5 @@ }) #define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ - m.def(name, \ - [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ - pm.addPass(builder({val0, val1, val2, val3})); \ - }) + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); }) diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index 67c216b4bc08..544c745da49c 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -11,8 +11,22 @@ import triton import triton.language as tl +from triton._internal_testing import is_hip_mi300, is_cuda, is_hip + +input_dtypes = ["bfloat16", "float16", "float32", "float64"] +if is_cuda(): + input_dtypes += ["int8", "float8_e5m2"] + cc = torch.cuda.get_device_capability(0) + if cc >= (8, 9): + input_dtypes += ["float8_e4m3fn"] +elif is_hip_mi300(): + input_dtypes += [ + "int8", + "float8_e5m2", + # natively supported on mi300 (see CDNA3 ISA, section 7.2) + "float8_e4m3fnuz", + ] -input_dtypes = ["float16", "float32", "float64"] out_dtypes = ["float16", "float32"] @@ -63,28 +77,43 @@ def matmul_kernel(A, B, C, M, N, K, # tl.store(C, acc, mask=mask) -@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype", - [(M, K, N, w, x, o) # - for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] # +@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype", + [(M, K, N, BLOCK_K, BLOCK_M, w, x, o) # + for BLOCK_K in [16, 32] # + for BLOCK_M in [16, 64] # + for (M, K, N) in [(128, 128, 128), (768, 768, 1024)] # for w in input_dtypes for x in input_dtypes # for o in out_dtypes]) -def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): +def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype): if x_dtype == w_dtype: pytest.skip("skip the same input dtype") + if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]: + pytest.skip("skip due to bug on HIP path") device = torch.cuda.current_device() - x_dtype = getattr(torch, x_dtype) - w_dtype = getattr(torch, w_dtype) - a = torch.randn((M, K), device=device, dtype=x_dtype) - b = torch.randn((K, N), device=device, dtype=w_dtype) + x_dtype: torch.dtype = getattr(torch, x_dtype) + w_dtype: torch.dtype = getattr(torch, w_dtype) + + def init_tensor(dtype, shape): + if dtype == torch.int8: + return torch.randint(0, 2, shape, device=device, dtype=dtype) + elif dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2): + return torch.randn(shape, device=device, dtype=torch.float16).to(dtype) + else: + return torch.randn(shape, device=device, dtype=dtype) + + torch.manual_seed(42) + a = init_tensor(w_dtype, (M, K)) + b = init_tensor(x_dtype, (K, N)) + torch_dtype = getattr(torch, out_dtype) triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype)) out_triton = torch.empty((M, N), device=device, dtype=torch_dtype) # launch kernel - BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32 - grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1) + block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K + grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1) matmul_kernel[grid]( a, b, out_triton, M, N, K, # @@ -92,8 +121,8 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): b.stride(0), b.stride(1), # out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, # GROUP_M=8, # - BLOCK_M=BLOCK_M, # - BLOCK_N=BLOCK_N, # - BLOCK_K=BLOCK_K) + BLOCK_M=block_m, # + BLOCK_N=block_n, # + BLOCK_K=block_k) torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01) diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index 82298c41c710..b6143b17871f 100644 --- a/python/test/regression/test_functional_regressions.py +++ b/python/test/regression/test_functional_regressions.py @@ -239,3 +239,40 @@ def kernel(in_ptr, out_ptr): kernel[(1, )](data, res) ref = torch.flip(data[1:513], [0]) assert (res == ref).all() + + +@triton.jit +def _triton_cummax_helper_fn(arg0_0, arg0_1, arg1_0, arg1_1): + tmp0 = arg0_0 > arg1_0 + tmp1 = arg0_0 == arg1_0 + tmp2 = arg0_1 > arg1_1 + tmp3 = tmp1 & tmp2 + tmp4 = tmp0 | tmp3 + tmp5 = tl.where(tmp4, arg0_0, arg1_0) + tmp6 = tl.where(tmp4, arg0_1, arg1_1) + return tmp5, tmp6 + + +def test_inductor_cummax_bool(device): + + @triton.jit + def triton_(in_ptr0, out_ptr0, out_ptr1, XBLOCK: tl.constexpr): + offset = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + offset).to(tl.int1) + tmp1 = tmp0.to(tl.int1) + tmp3 = offset.to(tl.int64) + tmp5, tmp6, = tl.associative_scan(( + tmp1, + tmp3, + ), 0, _triton_cummax_helper_fn) + tl.store(out_ptr0 + offset, tmp5) + tl.store(out_ptr1 + offset, tmp6) + + a = torch.randn((64, ), device=device) > 0 + values = torch.empty((64, ), dtype=torch.bool, device=device) + indices = torch.empty((64, ), dtype=torch.int64, device=device) + ref = torch.cummax(a, dim=0) + + triton_[(1, )](a, values, indices, 64) + torch.testing.assert_close(ref.values, values) + torch.testing.assert_close(ref.indices, indices) diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py new file mode 100644 index 000000000000..1fd443db967a --- /dev/null +++ b/python/test/unit/cpu/test_math.py @@ -0,0 +1,163 @@ +import inspect +import os +import pytest +import torch + +import triton +import triton.language as tl +from triton._C.libtriton import llvm +from triton.language.extra import libdevice +from itertools import chain, product + + +def get_native_vector_size_in_bits(): + """ + Returns the native vector size of the CPU. + Assuming x86 always uses "auto dispatch" with 512-bit vectors for Sleef. + """ + cpu_features = llvm.get_cpu_features() + # TODO support for arm sve w/ VLA + if "neon" in cpu_features: + return 128 + return 512 + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] +lib_prefix = { + "libsleef": "Sleef", + "libmvec": "_ZGV", +} +arch = triton.runtime.driver.active.get_current_target().arch + +vec_sizes = [1, 2, 4, 8, 16, 32, 64, 128] +scalar_sizes = [1, 4, 16, 64] + + +def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False): + # Check generated code calls vector math function + # FP16 and BF16 are cast to FP32 for math ops + elem_size = 8 if dtype_str == "float64" else 4 + data_size = size * elem_size + + vec_size = get_native_vector_size_in_bits() / 8 # bytes + # 128-bit vector is the smallest supported by Sleef for both x86 and arm + smallest_vec_size = 128 / 8 # bytes + if data_size > vec_size: + num_vec_calls = data_size // vec_size + elif data_size >= smallest_vec_size: + num_vec_calls = 1 + else: + num_vec_calls = 1 if is_always_extern else 0 + assert meta.asm["asm"].count(lib_prefix[vec_lib]) == num_vec_calls + + +@pytest.mark.parametrize("vec_lib, size", + chain(product(["libsleef", "libmvec"], vec_sizes), product([None], scalar_sizes))) +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", ["cos", "exp", "exp2", "log", "log2", "sin"]) +def test_tensor_math_fn(vec_lib, dtype_str, math_fn, size, device): + if not is_cpu(): + pytest.skip("This test is CPU-specific") + if vec_lib == "libmvec" and arch != "x86_64": + pytest.skip("Vectorized libm calls are supported for x86 target only.") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(x, MATH_FN)() + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) + ref = getattr(src, math_fn)() + torch.testing.assert_close(ref, res) + + if vec_lib is not None: + check_num_vec_calls(meta, vec_lib, dtype_str, size) + + +@pytest.mark.parametrize("vec_lib, size", + chain(product(["libsleef", "libmvec"], vec_sizes), product([None], scalar_sizes))) +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", [ + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "ceil", "cos", "cosh", "erf", "exp", "exp2", "expm1", + "floor", "fmod", "isnan", "isinf", "log", "log1p", "log2", "log10", "pow", "rsqrt", "signbit", "sin", "sinh", + "sqrt", "tan", "tanh", "trunc" +]) +def test_libdevice_math_fn(vec_lib, dtype_str, math_fn, size, device): + if not is_cpu(): + pytest.skip("This test is CPU-specific") + if vec_lib == "libmvec" and arch != "x86_64": + pytest.skip("Vectorized libm calls are supported for x86 target only.") + if math_fn in {"ceil", "fmod", "pow"}: + if vec_lib != "libsleef": + pytest.skip("extern_elementwise only supports libsleef") + if dtype_str not in {"float32", "torch.float64"}: + pytest.skip(f"{math_fn} only supports fp32, fp64") + + @triton.jit + def unary_kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(libdevice, MATH_FN)(x) + tl.store(dst + idxs, y) + + @triton.jit + def binary_kernel(x_ptr, y_ptr, out_ptr, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + idxs) + y = tl.load(y_ptr + idxs) + result = getattr(libdevice, MATH_FN)(x, y) + tl.store(out_ptr + idxs, result) + + signature = inspect.signature(getattr(libdevice, math_fn)) + num_params = len(signature.parameters) + inputs = [torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) for _ in range(num_params)] + # Customize inputs + if math_fn == "acosh": + inputs[0] = inputs[0].abs() + 1 + if math_fn == "isnan" or math_fn == "isinf": + indices = torch.randint(low=0, high=size, size=(size // 2, ), device=device) + src = inputs[0] + for i in indices: + if math_fn == "isnan": + src[i] = float("nan") + else: + src[i] = float(("+" if i % 2 else "-") + "inf") + + # Generate reference output + if math_fn == "cbrt": + ref = inputs[0].pow(1 / 3) + else: + ref = getattr(inputs[0], math_fn)(*inputs[1:]) + + res = torch.empty(inputs[0].shape, dtype=ref.dtype, device=device) + kernel = unary_kernel if num_params == 1 else binary_kernel + meta = kernel[(1, )](*inputs, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) + torch.testing.assert_close(ref, res) + + if vec_lib is None: + return + + # These are not implemented via extern library calls + native_impls = { + "libmvec": {"expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt", "trunc"}, + "libsleef": {"isnan", "isinf", "rsqrt", "signbit"}, + } + # These are always implemented with extern library calls + always_extern = {"ceil", "fmod", "pow"} + if math_fn not in native_impls[vec_lib]: + check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=math_fn in always_extern) + else: + assert meta.asm["asm"].count(lib_prefix[vec_lib]) == 0 diff --git a/python/test/unit/cpu/test_opt.py b/python/test/unit/cpu/test_opt.py new file mode 100644 index 000000000000..32eed6fb7a99 --- /dev/null +++ b/python/test/unit/cpu/test_opt.py @@ -0,0 +1,82 @@ +import os +import pytest +import torch + +import triton +import triton.language as tl + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +def is_x86(): + return is_cpu() and \ + triton.runtime.driver.active.get_current_target().arch == "x86_64" + + +def test_scalar_pointer_arith(device): + + @triton.jit + def kernel(src, dst, BLOCK_SIZE: tl.constexpr): + offs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offs) + tl.store(dst + offs, x) + + src = torch.rand((128, ), dtype=torch.float32, device=device) + res = torch.empty_like(src) + meta = kernel[(1, )](src, res, BLOCK_SIZE=128) + assert (src == res).all() + + # Check TTCIR doesn't have pointer extraction from a tensor. + ttcir = meta.asm["ttcir"] + assert ttcir.count("extract") == 0 + + +@pytest.mark.parametrize("size", [32, 128, 135]) +@pytest.mark.parametrize("tile_size", [16]) +def test_optimize_tile_mask(size, tile_size, device): + + @triton.jit + def kernel(src, dst, size, TILE_SIZE: tl.constexpr): + for i in range(0, tl.cdiv(size, TILE_SIZE)): + offs = tl.arange(0, TILE_SIZE) + i * TILE_SIZE + mask = offs < size + x = tl.load(src + offs, mask=mask, other=0) + tl.store(dst + offs, x, mask=mask) + + src = torch.rand((size, ), dtype=torch.float32, device='cpu') + res = torch.empty_like(src) + meta = kernel[(1, )](src, res, size, TILE_SIZE=tile_size) + assert (src == res).all() + + # Check number of masked loads and stores. + tttcir = meta.asm["tttcir"] + masked_loads = tttcir.count("maskedload") + masked_stores = tttcir.count("maskedstore") + if size % tile_size == 0: + assert masked_loads == 0 + assert masked_stores == 0 + else: + assert masked_loads == 1 + assert masked_stores == 1 + + +# Regression test for compilation failure in masks optimization +def test_vec_cdiv(device): + + @triton.jit + def kernel(in_ptr, out_ptr): + offs = tl.arange(0, 16) + x = tl.load(in_ptr + offs) + res = (x + 15) // 16 + tl.store(out_ptr + offs, res) + + arg0 = torch.zeros((16, ), dtype=torch.int32) + arg1 = torch.empty_like(arg0) + kernel[(1, )](arg0, arg1) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 9695a5e47eb1..23065953d65b 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -4,7 +4,9 @@ import triton import triton.language as tl from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor) -from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma +from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg + +from typing import Optional def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): @@ -27,9 +29,11 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper): tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}) -@requires_tma @pytest.mark.parametrize("byval_tma", [True, False]) def test_experimetal_descriptor_load(byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + device = "cuda" SIZE = 128 @@ -80,11 +84,13 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) -@requires_tma @pytest.mark.parametrize("num_stages", [1, 4]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) @pytest.mark.parametrize("byval_tma", [True, False]) def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + device = "cuda" M, N, K = 8192, 8192, 1024 torch.manual_seed(42) @@ -230,3 +236,305 @@ def test_device_tensormap1d(dtype_str): # Check results are correct torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(inp_copy)) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_tensor_descriptor_load(dtype_str): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + assert desc.shape[0] == M + assert desc.shape[1] == N + assert desc.strides[0] == N + assert desc.strides[1] == 1 + assert desc.block_shape == [M_BLOCK, N_BLOCK] + block = desc.load([M_BLOCK, 2 * N_BLOCK]) + idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :] + tl.store(out_ptr + idx, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + M, N = 32, 128 + inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str) + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M_BLOCK, N_BLOCK)) + + kernel[(1, )](out, inp, M, N, M_BLOCK, N_BLOCK) + + expect = unwrap_tensor(inp)[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK] + torch.testing.assert_close(expect, unwrap_tensor(out)) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_tensor_descriptor_store(dtype_str): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + midx = moffset + tl.arange(0, M_BLOCK)[:, None] + nidx = noffset + tl.arange(0, N_BLOCK)[None, :] + idx = midx * N + nidx + + val = tl.load(a_ptr + idx) + + desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + assert desc.shape[0] == M + assert desc.shape[1] == N + assert desc.strides[0] == N + assert desc.strides[1] == 1 + assert desc.block_shape == [M_BLOCK, N_BLOCK] + desc.store([moffset, noffset], val) + + M, N = 32, 128 + inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str) + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 * (grid_m * grid_n) + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + + +@triton.jit(noinline=True) +def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + in_desc = tl._experimental_make_tensor_descriptor( + in_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + value = in_desc.load([moffset, noffset]) + out_desc.store([moffset, noffset], value.abs()) + + +@requires_tma +def test_tensor_descriptor_in_function(): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + tensor_descriptor_in_function_helper(out_ptr, a_ptr, M, N, M_BLOCK, N_BLOCK) + + M, N = 32, 128 + inp = torch.randn((M, N), device="cuda") + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 2 * 128 * (grid_m * grid_n) + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + expect = inp.abs() + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(expect, out) + + +@triton.jit +def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[K, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_k, offs_bn]) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_SIZE_K + accumulator = accumulator.to(a_desc.dtype) + c_desc.store([offs_am, offs_bn], accumulator) + + +@requires_tma +@pytest.mark.parametrize("num_stages", [1, 4]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) +def test_experimental_make_tensor_descriptor_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): + device = "cuda" + M, N, K = 8192, 8192, 1024 + torch.manual_seed(42) + A = torch.randn((M, K), dtype=torch.float16, device=device) + B = torch.randn((K, N), dtype=torch.float16, device=device) + C = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 3 * 128 * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + kernel = matmul_kernel_make_tensor_desciptor[grid]( + A, + B, + C, + M, + N, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_warps=8, + num_stages=num_stages, + ) + ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + if BLOCK_M >= 64 and BLOCK_N >= 64: + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + + +@triton.jit +def kernel_make_tensor_desciptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + # Test that descriptors work with + pid = tl.program_id(0) + moffset = MBLOCK * pid + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + for i in range(0, N, NBLOCK): + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + if i % (3 * NBLOCK) == 0: + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a = a_desc.load([moffset, i]) + a_desc.store([moffset, i], a + 10) + + n = 0 + while n < N: + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + if n % (3 * NBLOCK) == 0: + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a = a_desc.load([moffset, n]) + a_desc.store([moffset, n], a + 5) + + n += NBLOCK + + +@requires_tma +def test_experimental_make_tensor_descriptor_loop_carried(): + device = "cuda" + M, N = 8192, 8192 + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device) + MBLOCK, NBLOCK = 8, 128 + grid = (triton.cdiv(M, MBLOCK), ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + ref_out = A + 15 + kernel = kernel_make_tensor_desciptor_loop_carried[grid]( + A, + M, + N, + MBLOCK, + NBLOCK, + ) + torch.testing.assert_close(ref_out, A) + assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] diff --git a/python/test/unit/instrumentation/test_gpuhello.py b/python/test/unit/instrumentation/test_gpuhello.py index 413c3f642405..bdc6ca90742c 100644 --- a/python/test/unit/instrumentation/test_gpuhello.py +++ b/python/test/unit/instrumentation/test_gpuhello.py @@ -31,7 +31,6 @@ def kernel3(BLOCK_SIZE: tl.constexpr): def func(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) kernel1[grid](BLOCK_SIZE=1024) @@ -39,10 +38,10 @@ def func(x: torch.Tensor, y: torch.Tensor): kernel3[grid](BLOCK_SIZE=1024) -def test_op(capfd): +def test_op(capfd, device: str): size = 98432 - x = torch.rand(size, device='cuda') - y = torch.rand(size, device='cuda') + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) func(x, y) stdout, stderr = capfd.readouterr() if 'LLVM_PASS_PLUGIN_PATH' in os.environ: diff --git a/python/test/unit/language/conftest.py b/python/test/unit/language/conftest.py index 091f9ea41e7f..44615b8b883b 100644 --- a/python/test/unit/language/conftest.py +++ b/python/test/unit/language/conftest.py @@ -3,3 +3,4 @@ def pytest_configure(config): config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") + config.addinivalue_line("markers", "cpu: indicate whether test is supported on cpu") diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index 30a4745db1d9..a8fb54558328 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -35,10 +35,12 @@ def kernel_print(X, Y, BLOCK: tl.constexpr): @triton.jit -def kernel_device_print_scalar(SCALAR): +def kernel_device_print_scalars(SCALAR, INT, FLOAT): x = tl.load(SCALAR) # Triton should add a space after this prefix. print("x:", x) + print("int:", INT) + print("float:", FLOAT) @triton.jit @@ -91,19 +93,22 @@ def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): def test_print(func: str, data_type: str, device: str): - N = 128 # This value should match with test_print in test_subprocess.py. + # These values should match with test_print in test_subprocess.py. + N = 128 + SCALAR = 42 + # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple # threads printing duplicated messages due to broadcasting. Improve print op lowering logic # to filter out duplicated data range. - num_warps = N // get_current_target_warp_size() + num_warps = N // (get_current_target_warp_size() if device != "cpu" else 1) x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) y = torch.zeros((N, ), dtype=x.dtype, device=device) if func == "device_print": kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "device_print_scalar": - scalar = torch.tensor(42, dtype=x.dtype, device=device) - kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_scalars": + scalar = torch.tensor(SCALAR, dtype=x.dtype, device=device) + kernel_device_print_scalars[(1, )](scalar, SCALAR, 3.14, num_warps=num_warps) elif func == "device_print_negative": x = -x kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) @@ -129,16 +134,16 @@ def test_print(func: str, data_type: str, device: str): elif func == "device_print_pointer": kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) else: - assert f"Unknown kernel: {func}" - + assert False, f"Unknown kernel: {func}" if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ func != "print_multiple_args" and func != "device_print_multiple_args" and \ - func != "device_print_pointer" and func != "device_print_scalar": + func != "device_print_pointer" and func != "device_print_scalars": assert_close(y, x) # Wait until driver complete all the jobs for the device_print, especially test_subprocess # require this which captures stdout when child exits. - getattr(torch, device).synchronize() + if device != "cpu": + torch.cuda.synchronize() if __name__ == "__main__": diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 8e84a9f82a08..f9591b6505e2 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -3,33 +3,43 @@ import triton import triton.language as tl -from test_core import check_type_supported +from test_core import check_type_supported, is_cpu @triton.jit -def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr, + TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr): pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + if TEST_LOWER_BOUND: + offset = -N + elif TEST_UPPER_BOUND: + offset = N # We only copy half of the data to see if the padding works - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - if padding_option is None: + if PADDING_OPTION is None: a = tl.load(a_block_ptr, boundary_check=(0, )) else: - a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION) tl.store(b_block_ptr, a, boundary_check=(0, )) @pytest.mark.interpreter -@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # - (dtypes_str, n, padding) +@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ # + (dtypes_str, n, padding, boundary_check) # for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), ("float32", "float32"), ("bfloat16", "bfloat16")) for n in (64, 128, 256, 512, 1024) for padding in (None, "zero", "nan") # + for boundary_check in (None, "lower", "upper") ]) -def test_block_copy(dtypes_str, n, padding_option, device): +def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): + if is_cpu() and boundary_check == "lower": + pytest.skip("Lower boundary check is NYI for CPU") + src_dtype_str = dtypes_str[0] dst_dtype_str = dtypes_str[1] src_dtype = getattr(torch, src_dtype_str) @@ -45,13 +55,35 @@ def test_block_copy(dtypes_str, n, padding_option, device): b = torch.zeros((n, ), device=device, dtype=dst_dtype) grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) - block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option, + TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper") a.to(dst_dtype) - assert torch.all(a[0:n // 2] == b[0:n // 2]) - if padding_option == "zero": - assert torch.all(b[n // 2:n] == 0) - elif padding_option == "nan": - assert torch.all(torch.isnan(b[n // 2:n])) + if (boundary_check == "lower") or (boundary_check == "upper"): + assert torch.all(b == 0) + else: + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + elif padding_option == "nan": + assert torch.all(torch.isnan(b[n // 2:n])) + + +def test_block_copy2d(device): + + @triton.jit + def kernel(in_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, BLOCK_M: tl.constexpr): + block_offset = tl.program_id(0) * BLOCK_M + in_block_ptr = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(N, 1), offsets=(block_offset, 0), + block_shape=(BLOCK_M, N), order=(1, 0)) + out_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(N, 1), offsets=(block_offset, 0), + block_shape=(BLOCK_M, N), order=(1, 0)) + x = tl.load(in_block_ptr) + tl.store(out_block_ptr, x) + + inp = torch.randn((256, 16), device=device, dtype=torch.float32) + res = torch.empty_like(inp) + kernel[(16, )](inp, res, M=16, N=16, BLOCK_M=16) + assert (inp == res).all() @triton.jit @@ -90,6 +122,9 @@ def matmul_no_scf_with_advance_kernel( # ]) def test_block_ptr_matmul_no_scf(shape, num_warps, device): m, n, k = shape + if is_cpu(): + # FIXME: fix compilation time for bigger shapes on CPU + m = n = 16 a = torch.randn((m, k), device=device, dtype=torch.float16) b = torch.randn((k, n), device=device, dtype=torch.float16) c = torch.empty((m, n), device=device, dtype=torch.float32) @@ -103,5 +138,9 @@ def test_block_ptr_matmul_no_scf(shape, num_warps, device): stride_cm=c.stride(0), stride_cn=c.stride(1), # BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # num_warps=num_warps) - golden = torch.matmul(a, b) + if is_cpu(): + # torch.matmul not implemented for Half float (float16) cpu + golden = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + else: + golden = torch.matmul(a, b) torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 12c3997ec7c4..b9837f5a47e2 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -7,22 +7,11 @@ import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback +from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300 -def is_interpreter(): - return os.environ.get('TRITON_INTERPRET', '0') == '1' - - -def is_cuda(): - return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" - - -def is_hip(): - return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" - - -def is_on_mi300(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') +def is_cpu(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cpu" def test_err_undefined_variable(): @@ -367,10 +356,12 @@ def test_fp8_support(dtype): if cc >= (8, 9): supported_dtypes.append(tl.float8e4nv) elif is_hip(): - if is_on_mi300(): - supported_dtypes += [tl.float8e4b8, tl.float8e5b16] + if is_hip_mi300(): + supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16] elif is_interpreter(): supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15] + elif is_cpu(): + supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv] @triton.jit def dtype_kernel(dtype: tl.constexpr): diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py index 723a15fe847f..686c6fe0deb2 100644 --- a/python/test/unit/language/test_conversions.py +++ b/python/test/unit/language/test_conversions.py @@ -1,24 +1,14 @@ # fmt: off -import os import numpy as np import torch import pytest import triton import triton.language as tl -def is_interpreter(): - return os.environ.get('TRITON_INTERPRET', '0') == '1' +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_cpu -def is_cuda(): - return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" - -def is_hip(): - return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" - -def is_on_mi300(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') def matching_int(dtype): if dtype.primitive_bitwidth == 8: @@ -281,9 +271,15 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia ('float8e5b16', 'float16'), ]) def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + # On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300. + if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()): + pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") + if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9)) - or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip()) - or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))): + or (src_dtype in ('float8e4b15') and is_hip()) + or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or (is_hip() and not is_hip_mi300()))) + or (src_dtype in ('float8e4b8', 'float8e4b15') and is_cpu())): # If the dtype should error out in the given device, we assert that and return with pytest.raises(triton.CompilationError, match="not supported in this architecture"): launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) @@ -327,16 +323,22 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device): ('float16', 'float8e4b8', 'rtne', 0x5b80), ]) def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + if is_cpu() and dst_dtype not in ['float8e5', 'float8e4nv', 'float8e5b16']: + # TODO check if 'float8e4b15' downcast is fine for cpu if it will enable in this test + pytest.skip(f"Conversion from {src_dtype} to {dst_dtype} is not supported on CPU") if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") - if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or (is_cuda() and torch.cuda.get_device_capability(0) < (9, 0))): pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") - if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or (is_hip() and not is_hip_mi300())): pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + if dst_dtype == 'float8e4nv' and is_hip(): + pytest.skip(f"{dst_dtype} downcast not supported in HIP") + # dtype : (exponent_bits, mantissa_bits, exponent_bias) stuff = { 'float16': (5, 10, 15), diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3013bbf53177..9d6ea7997456 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5,7 +5,7 @@ from typing import Optional import math import textwrap -import tempfile +import pathlib import numpy as np import pytest @@ -23,12 +23,19 @@ int_dtypes, uint_dtypes, float_dtypes, + float_dtypes_with_bfloat16, dtypes, dtypes_with_bfloat16, is_cuda, is_interpreter, + is_hopper, is_hip, + is_hip_cdna, + is_hip_mi200, + is_hip_mi300, + is_xpu, get_arch, + is_cpu, torch_float8_dtypes, torch_dtypes, numpy_random, @@ -52,7 +59,7 @@ def promotion_numpy_2_0(): # num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] num_ctas_list = [1] -GPU_DIALECT = "triton_gpu" +GPU_DIALECT = "ttg" if is_interpreter(): THREADS_PER_WARP = 1 elif is_hip(): @@ -66,6 +73,11 @@ def _bitwidth(dtype: str) -> int: return int(re.search(r'(\d+)$', dtype).group(1)) +def _dtype(dtype: str) -> str: + # ex.: "int64" -> "int" + return re.match(r'([a-zA-Z]+)', dtype).group(0) + + def patch_kernel(template, to_replace): if is_interpreter(): local_namespace = {} @@ -101,6 +113,8 @@ def check_type_supported(dtype, device): if is_interpreter(): if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: pytest.skip("bfloat16 is not supported in the interpreter") + if dtype == 'float8e4b15' and is_cpu(): + pytest.skip("float8e4b15 not supported on CPU") class MfmaLayout: @@ -139,6 +153,17 @@ def __str__(self): return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -174,7 +199,12 @@ def is_layout_applicable(layout) -> bool: if layout in common_layouts: return True elif is_cuda(): - return isinstance(layout, MmaLayout) + mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout + if not isinstance(mma_layout, MmaLayout): + return False + if mma_layout.version[0] >= 3 and not is_hopper(): + return False + return True elif is_hip(): target_arch = triton.runtime.driver.active.get_current_target().arch if "gfx11" in target_arch: @@ -193,6 +223,7 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -207,6 +238,16 @@ def kernel(X, SIZE: tl.constexpr): kernel[(1, )](x, SIZE=SIZE, num_warps=4) +@pytest.mark.cpu +def test_empty_kernel_scalar_arg(device): + + @triton.jit + def kernel(x): + pass + + kernel[(1, )](2) + + # generic test functions def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): check_type_supported(dtype_x, device) # early return if dtype_x is not supported @@ -267,7 +308,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, - y_low=None, y_high=None, filter_y=None, test_broadcast=True, test_scalar=True): + x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, + test_scalar=True): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -312,7 +354,7 @@ def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): # inputs rs = RandomState(17) - x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) if filter_y: y[filter_y(y)] = 1 @@ -346,7 +388,7 @@ def do_test(x, y, kernel_fn): z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) err_msg = f"{expr}, {kernel_fn.__name__}" - np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) def get_scalar(x, dtype, low, high, filter): # If dtype is int, don't choose a huge number for the scalar @@ -380,30 +422,35 @@ def get_scalar(x, dtype, low, high, filter): do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) -def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: - # FIXME For large x, we are casting x to a floating point where it does not fit - # For small y, we are computing floor(div(float(x), y)) which may not fit - return (dtype_x, dtype_y) in [ - ('int32', 'bfloat16'), - ('int32', 'float16'), - ('int32', 'float32'), - ('int64', 'bfloat16'), - ('int64', 'float16'), - ('int64', 'float32'), - ('int64', 'float64'), - ('uint16', 'bfloat16'), - ('uint16', 'float16'), - ('uint16', 'float32'), - ('uint32', 'bfloat16'), - ('uint32', 'float16'), - ('uint32', 'float32'), - ('uint64', 'bfloat16'), - ('uint64', 'float16'), - ('uint64', 'float32'), - ('uint64', 'float64'), - ] +def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: + """ + Limit min/max values for integral types for mod values. Leads to + overflow/underflow when casting large integral types to floats. + """ + x_bitwidth = _bitwidth(dtype_x) + y_bitwidth = _bitwidth(dtype_y) + + # hard cap max value bit-width to 32 if 64 bit-width types + min_bitwidth = min(x_bitwidth, y_bitwidth, 32) + + # Limit max value bit-width to be one integral type less than the min bit-width + # For example: + # int64, float32 -> int16 + # uint16, float16 -> uint8 + x_dtype = _dtype(dtype_x) + max_bitwidth = max(min_bitwidth >> 1, 8) + dtype_max = x_dtype + str(max_bitwidth) + + max_info = np.iinfo(getattr(np, dtype_max)) + + # Still need to limit values here for uints + if max_bitwidth >= 16 and dtype_max in uint_dtypes: + return max_info.min, max_info.max // 4 + else: + return max_info.min, max_info.max +@pytest.mark.cpu def test_dtype_codegen(): for dtype in dtypes_with_bfloat16: full_name = f"triton.language.{dtype}" @@ -415,6 +462,7 @@ def test_dtype_codegen(): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -425,35 +473,35 @@ def test_dtype_codegen(): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): expr = f'x {op} y' - if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: - # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. - numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', - 'bfloat16'): - # Triton promotes 16-bit floating-point / and % to 32-bit because there - # are no native div or FRem operations on float16. Since we have to - # convert anyway, we may as well take the accuracy bump. - numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') else: numpy_expr = None - if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): - with pytest.raises(AssertionError, match="Not equal to tolerance"): - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) - elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + + if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) else: # skip when bfloat16, as NumPy's ref performs the computation in float32 # while Triton performs it in bfloat16 - # We also skip mod when it is ill-conditioned skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) - or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes - and _mod_operation_ill_conditioned(dtype_x, "float32"))) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) # can't divide by zero not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes # can't represent -int(max) @@ -462,13 +510,20 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) else: filter_y = None + + if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: + x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) + else: + x_low, x_high = None, None + _test_binary( dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, # fails with values where fmod(x, y) is roughly zero, but happens to # pass with the random values chosen for non-broadcast tests - test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test) + test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -495,6 +550,7 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -515,6 +571,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) +@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -551,6 +608,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -575,6 +633,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes @@ -597,6 +656,7 @@ def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -621,6 +681,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -655,6 +716,7 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- +@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -686,6 +748,7 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -701,6 +764,7 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -749,6 +813,7 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -812,6 +877,7 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -828,6 +894,7 @@ def _kernel(dst): # --------------- # test where # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -880,6 +947,7 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. assert (z == to_numpy(z_tri)).all() +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -924,6 +992,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -938,6 +1007,7 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -948,6 +1018,7 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -969,6 +1040,7 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -994,6 +1066,7 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1006,6 +1079,7 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1046,12 +1120,14 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) def test_abs_fp8(in_dtype, device): @@ -1063,6 +1139,8 @@ def test_abs_fp8(in_dtype, device): pytest.skip("float8e4b15 not supported on CUDA >= 9.0") if in_dtype == tl.float8e4nv and cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + elif is_cpu(): + pytest.skip('CPU not supports "fp8e4b15"') @triton.jit def abs_kernel(X, Z, SIZE: tl.constexpr): @@ -1091,6 +1169,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1111,6 +1190,9 @@ def kernel(): a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + a = tl.arange(0, 64).view(2, 4, 8) tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) @@ -1122,6 +1204,7 @@ def kernel(): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_transpose(dtype_x, device): @@ -1160,7 +1243,8 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" -# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.cpu +# TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] for d in ['int32', 'uint32', 'uint16']]) @@ -1229,6 +1313,7 @@ def tuples_fn(a, b): a * b +@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1321,6 +1406,7 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1352,6 +1438,7 @@ def kernel(X, Y, Z): # --------------- # test atomics # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_x_str, mode, sem", @@ -1432,6 +1519,7 @@ def kernel(X, Z): assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_atomic_rmw_predicate(num_ctas, device): @@ -1447,23 +1535,34 @@ def kernel(X): assert x.item() == 63 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str", [(shape, axis, num_ctas, dtype_x_str) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list - for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']]) + for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']]) def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): shape0, shape1 = shape # triton kernel @triton.jit - def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr): off0 = tl.arange(0, SHAPE0) off1 = tl.arange(0, SHAPE1) x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + + if DTYPE == tl.float16: + # sum can have bad numerics when accumulating in float16. + # if we're dealing with float16, do the sum in float32. + x = x.to(tl.float32) + z = tl.sum(x, axis=AXIS) + + if DTYPE == tl.float16: + z = z.to(DTYPE) + if AXIS == 1: old = tl.atomic_add(Z + off0, z) tl.store(OLD + off0, old) @@ -1477,17 +1576,28 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) # reference results - z_ref = z + np.sum(x, axis=axis, keepdims=False) + if x.dtype == np.float16: + # do the sum in float32 to reduce numerical variation + z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) + else: + z_ref = z + np.sum(x, axis=axis, keepdims=False) old_ref = np.copy(z) # triton result x_tri = to_triton(x, device=device) z_tri = to_triton(z, device=device) old_tri = to_triton(old, device=device) - kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas) + + def torch_to_triton_dtype(t): + if t == torch.float16: + return tl.float16 + return None + + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) np.testing.assert_equal(old_ref, to_numpy(old_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_tensor_atomic_rmw_block(num_ctas, device): @@ -1507,6 +1617,7 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): assert torch.min(x).item() == 0.0 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1548,6 +1659,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr): assert f"atom.global.{sem_str}" in h.asm["ptx"] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1574,6 +1686,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -1606,6 +1719,21 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + if is_cpu() and (dtype_x in torch_float8_dtypes or dtype_z in torch_float8_dtypes): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} is not supported on CPU.') + + # fptrunc fp32->fp16 is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/95274 + # TODO: remove the change after the bug is fixed. + if is_cpu() and dtype_x == "float32" and dtype_z == "float16": + size = 512 + + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove the change after the bug is fixed. + if is_cpu() and dtype_x == 'bfloat16' and size > 128: + size = 128 + torch.manual_seed(0) # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. if dtype_x.startswith('bfloat'): @@ -1673,6 +1801,7 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) @@ -1697,51 +1826,39 @@ def kernel(X, Y, Z, N: tl.constexpr): assert z.unique().size(0) == z.size(0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant(dtype_str, num_ctas, device): +def test_store_constant(num_ctas, dtype_str, constant_field, device): check_type_supported(dtype_str, device) - """Tests that boolean True is stored as 1""" @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - output = GENERATE_TEST_HERE + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False tl.store(output_ptr + offsets, output, mask=mask) - triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) - assert torch.all(output == ref) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) - -@pytest.mark.interpreter -@pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_store_constant_default_dtype(num_ctas, device): - """Tests that boolean True is stored as 1""" - - @triton.jit - def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - value = 1 - output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) - tl.store(output_ptr + offsets, output, mask=mask) - - block_size = 128 - ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) - output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) - kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) - - assert torch.all(output == ref) + if constant_field == "value": + assert torch.all(output == ref) + else: + assert torch.all(output == 0) +@pytest.mark.cpu def test_load_store_same_ptr(device): @triton.jit() @@ -1760,6 +1877,7 @@ def kernel(in_out_ptr): assert torch.all(x == 2) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['int32']) def test_umulhi(dtype_str, device): @@ -1797,6 +1915,7 @@ def umulhi32(a, b): np.testing.assert_equal(z_ref, to_numpy(z_tri)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join(device): @@ -1817,6 +1936,7 @@ def kernel(X, Y, Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_scalars(device): @@ -1836,6 +1956,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([42, 100], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_with_mma(device): @@ -1874,6 +1995,7 @@ def kernel(Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_interleave_scalars(device): @@ -1889,6 +2011,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([10, 20], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split(device): @@ -1911,6 +2034,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr): np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split_to_scalar(device): @@ -1970,6 +2094,7 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): return output +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) def test_convert_float16_to_float32(in_dtype, device): @@ -2003,6 +2128,7 @@ def deserialize_fp8(np_data, in_dtype): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_max_returns_zero(device): # Simple test with a tl.max call that returns 0. The interpreter had a bug @@ -2029,6 +2155,7 @@ def get_reduced_dtype(dtype_str, op): return dtype_str +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ 'min', @@ -2139,6 +2266,7 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -2281,17 +2409,24 @@ def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): check_type_supported(dtype_str, device) if dtype_str == 'bfloat16': - if op == 'cummax': + if is_cuda() and op == 'cummax': pytest.skip("bfloat16 compare not suppoted before sm90") if op == 'linear_recurrence': pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove the change after the bug is fixed. + if is_cpu() and dtype_str == 'bfloat16': + shape = (min(shape[0], 128), min(shape[1], 128)) + # triton kernel @triton.jit def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): @@ -2438,6 +2573,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) def test_histogram(M, N, device): @@ -2448,6 +2584,9 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): offset2 = tl.arange(0, N) x = tl.load(x_ptr + offset1) z = tl.histogram(x, N) + bias = tl.full([M, N], 1, dtype=tl.int32) + # check that histogram produces object compatible with broadcasting + biased = z + bias tl.store(z_ptr + offset2, z) torch.manual_seed(17) @@ -2461,6 +2600,7 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): assert (z_torch == z).all() +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['sum', 'max', 'min']) @pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) @@ -2506,7 +2646,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) - if not is_interpreter(): + if not is_interpreter() and not is_cpu(): assert h.asm['ttgir'].count( '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) @@ -2517,27 +2657,40 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_scan_layouts(M, N, src_layout, axis, device): +@pytest.mark.parametrize("add_overflow_check", [False, True]) +def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): + + overflow_check = """ + %17 = arith.extsi %arg2 : i32 to i64 + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.addi %17, %18 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %20 = arith.cmpi slt, %19, %i32.max : i64 + %21 = arith.cmpi sge, %19, %i32.min : i64 + %22 = arith.andi %20, %21 : i1 + tt.assert %22, "overflow detected" : i1 + """ ir = f""" #blocked = {src_layout} - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> - %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> - %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ ^bb0(%arg2: i32, %arg3: i32): - %16 = arith.addi %arg2, %arg3 : i32 + %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> @@ -2550,10 +2703,10 @@ def test_scan_layouts(M, N, src_layout, axis, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + rs = RandomState(17) x = rs.randint(-100, 100, (M, N)).astype('int32') @@ -2599,9 +2752,11 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) -@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), + ("float16", False)]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device, + tmp_path: pathlib.Path): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") @@ -2613,6 +2768,18 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce if isinstance(src_layout, MmaLayout) and src_layout.version == 3: src_layout[2] = 16 if dtype_str == "float16" else 8 + overflow_check = """ + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.extsi %arg4 : i32 to i64 + %20 = arith.addi %18, %19 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %21 = arith.cmpi slt, %20, %i32.max : i64 + %22 = arith.cmpi sge, %20, %i32.min : i64 + %23 = arith.andi %21, %22 : i1 + tt.assert %23, "overflow detected" : i1 + """ + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # @@ -2645,7 +2812,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce f""" %14 = "tt.reduce"(%13) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} tt.store %arg2, %14 : !tt.ptr<{ty}> @@ -2657,10 +2824,10 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> %15 = "tt.reduce"(%14) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) - %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %16 = ttg.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> tt.return @@ -2673,7 +2840,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce #blocked = {blocked} #src = {src_layout} #one_d_layout = {one_d_layout} - module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> @@ -2690,15 +2857,14 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> %13 = "tt.reduce"(%12) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) @@ -2728,7 +2894,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce @pytest.mark.parametrize("M", [32, 64, 128, 256]) @pytest.mark.parametrize("src_layout", layouts) -def test_store_op(M, src_layout, device): +def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): ir = f""" #src = {src_layout} @@ -2749,10 +2915,9 @@ def test_store_op(M, src_layout, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - store_kernel = triton.compile(f.name) + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, 1)).astype('float32') @@ -2779,12 +2944,12 @@ def test_store_op(M, src_layout, device): @pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) @pytest.mark.parametrize("src_dim", [0, 1]) @pytest.mark.parametrize("dst_dim", [0, 1]) -def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): ir = f""" #dst = {dst_layout} #src = {src_layout} - module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> @@ -2799,10 +2964,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, )).astype('int32') @@ -2840,7 +3004,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("op", ["sum", "max"]) @pytest.mark.parametrize("first_axis", [0, 1]) -def test_chain_reduce(M, N, src_layout, op, device, first_axis): +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): op_str = "" if op == "sum": @@ -2854,7 +3018,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): tt.reduce.return %14 : i32""" ir = f""" #src = {src_layout} - module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> @@ -2881,10 +3045,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') @@ -2903,6 +3066,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_generic_reduction(device): @@ -2934,6 +3098,7 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) # TODO: bfloat16 @@ -2993,6 +3158,7 @@ def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constex assert 'st.global.v4' in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) @@ -3016,6 +3182,7 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) @@ -3068,6 +3235,7 @@ def convert_fp8_to_fp32(x, device, dtype_str): assert "Unsupported float8 dtype" +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", @@ -3098,6 +3266,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty if is_interpreter(): if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") + elif is_cpu(): + # This test kernel runs in a single thread and can take a long time + # for bigger sizes with the current codegen on CPU. Limit input sizes + # by default to get more reasonable tests execution time. + if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1': + M = min(M, 32 if epilogue == "chain-dot" else 64) + N = min(N, 32 if epilogue == "chain-dot" else 64) + K = min(K, 16 if epilogue == "chain-dot" else 32) else: if is_cuda(): capability = torch.cuda.get_device_capability() @@ -3300,41 +3476,53 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", - [(M, N, K, col_a, col_b, type_a, type_b, 4) +@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack", + [(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, 4, mma, kpack) for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) for col_a, col_b in itertools.product([True, False], repeat=2) - for type_a in ["e2m1", "e4m3", "e5m2"] - for type_b in ["e4m3", "e5m2"]]) -def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): - if not is_cuda(): - pytest.skip("scaled_dot only supported on CUDA") - else: + for rhs_scale in [False, True] + for normal_type in ["e2m1", "e4m3", "e5m2"] + for mxfp_type in ["e4m3", "e5m2", "bf16"] + for mma in ([32, 16] if is_hip() else [16]) + for kpack in ([1, 2] if is_hip() else [1])]) +def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, num_warps, mma, kpack, device): + if is_cuda(): cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + if is_hip(): + if not is_hip_cdna(): + pytest.skip("scaled_dot only implemented for HIP CDNA") + if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300(): + pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300") + if mma == 16 and K == 64: + pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") @triton.jit - def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): - tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") - IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" - DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_a1 b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] - a = tl.load(a_ptr) b = tl.load(b_ptr) - a_scale = tl.load(scale_a_ptr) - c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + if a_scale is not None: + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + a_scale = tl.load(scale_a_ptr) + if b_scale is not None: + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + b_scale = tl.load(scale_b_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] tl.store(out_ptr, c.to(tl.bfloat16)) @@ -3386,17 +3574,17 @@ def mxfp_to_bf16_kernel( x_bf16 = x_f8.to(tl.bfloat16) else: # e2m1 - em0 = x & 0x70 - em1 = x & 0x7 - x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) - x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + em0 = x & 0x7 + em1 = x & 0x70 + x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4) + x1 = (em1.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << (8)) # Three cases: # 1) x is normal and non-zero: Correct bias - x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) - x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1) # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 - x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) - x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1) # 3) x is zero, do nothing x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) # Multiplication preserves infs and NaNs in x_bf16 @@ -3407,22 +3595,31 @@ def mxfp_to_bf16_kernel( offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) - def dot_scale_ref(x, scale, y, type_x, type_y): - e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] - type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] - - comp_dtype = torch.bfloat16 - - x = x.contiguous() - x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) - - N = x_upcast.numel() - BLOCK_SIZE = 512 - grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) - mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) - assert x_upcast.isfinite().all() - - y_upcast = y.view(type_fp8_y).to(comp_dtype) + def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): + + def upcast(v, scale, type, transposed): + comp_dtype = torch.bfloat16 + if scale is None: + type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type] + return v.view(type).to(comp_dtype) + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] + # Packing is always on the K dimension so we transpose before upcasting then transpose back. + if transposed: + v = v.mT.contiguous() + v = v.contiguous() + v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + N = v_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, + num_warps=num_warps) + assert v_upcast.isfinite().all() + if transposed: + v_upcast = v_upcast.mT + return v_upcast + + x_upcast = upcast(x, scale_x, type_x, False) + y_upcast = upcast(y, scale_y, type_y, True) class AccumulateInFp32: @@ -3434,28 +3631,39 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value with AccumulateInFp32(): - return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + return torch.matmul(x_upcast, y_upcast) torch.manual_seed(0) - def create_uint8(shape, col_major=False, max_val=255): + def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if ty == "bf16": + ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**15, 2**15 - 1) + else: + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: ret = ret.mT return ret - DIV_FACTOR = 2 if type_a == "e2m1" else 1 - x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) - y = create_uint8((K, N), col_major=col_b) + type_a = normal_type if not rhs_scale else mxfp_type + type_b = mxfp_type if not rhs_scale else normal_type + + DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 + x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a) + y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow - m_bytes = int(type_a[1]) - bias_type_a = 1 << (m_bytes - 1) - 1 - max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a - scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + # Max scale= 2**15 + scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) + scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15) + if rhs_scale: + scale_x = None + else: + scale_y = None def make_finite(x, dtype): # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and @@ -3470,31 +3678,39 @@ def make_finite(x, dtype): x = make_finite(x, type_a) y = make_finite(y, type_b) - + kernel_kwargs = {"num_warps": num_warps} + if is_hip(): + kernel_kwargs["kpack"] = kpack + kernel_kwargs["matrix_instr_nonkdim"] = mma z = x.new_empty((M, N), dtype=torch.bfloat16) - pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, - num_warps=num_warps) - - z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) - - # generous rtol as we are sampling the whole range of floats - torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, + **kernel_kwargs) + z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values + # to zero. Detailed info is at: + # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 2e-4 if is_hip_mi200() else 1e-5 + rtol = 2e-2 if is_hip_mi200() else 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - if (max(M, N) * K) // (num_warps * 32) >= 4: - assert 'ld.global.v4' in ptx - if M * N // (num_warps * 32) >= 4: - assert 'st.global.v4' in ptx - assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + if is_cuda(): + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) for B in [1, 2, 4, 8] for num_warps in [1, 2, 4, 8, 16] - for BLOCK_M, BLOCK_N in [(32, 32)] + for BLOCK_M, BLOCK_N in [(32, 32) if not is_cpu() else (4, 4)] for M, N, K in [(64, 64, 64), (32, 32, 32)] for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + @@ -3510,10 +3726,14 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + elif is_cpu(): + if out_dtype_str == "float16": + pytest.skip("Test is skipped due to float16 accuracy issue") + input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" else: input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" - if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": + if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32" and not is_cpu(): if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( torch.cuda.current_device())["max_shared_mem"] < 131072: pytest.skip( @@ -3613,6 +3833,7 @@ def kernel( np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) +@pytest.mark.cpu @pytest.mark.parametrize('in_dtype', ['float32']) def test_dot_mulbroadcasted(in_dtype, device): if is_cuda(): @@ -3662,11 +3883,12 @@ def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.co # value is in rowmajor. But MMAv3 requires its second operand is in colmajor # because transpose is not supported for MMAv3 with float32 input. if capability[0] >= 9: - assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + assert re.search(r"ttg.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None else: - assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) @pytest.mark.parametrize("shape", [(), (1, ), (128, )]) @@ -3706,6 +3928,7 @@ def kernel_dynamic(out, val, dtype: tl.constexpr): assert torch.all(out_dynamic == 2) +@pytest.mark.cpu @pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), ('float("-inf")', "f32"), ('float("nan")', "f32"), ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) @@ -3730,6 +3953,7 @@ def pass_const(a, b, choose_b): return a +@pytest.mark.cpu @pytest.mark.parametrize("choose_const", [True, False]) @pytest.mark.parametrize("constexpr", [True, False]) @pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) @@ -3797,6 +4021,7 @@ def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.co assert torch.all(input == output) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['float32', 'float16']) def test_dot_without_load(dtype_str, device): @@ -3812,7 +4037,11 @@ def _kernel(out): kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) - out_ref = torch.matmul(a, b) + if is_cpu() and dtype_str == "float16": + # torch.matmul not implemented for Half float (float16) cpu + out_ref = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=getattr(torch, dtype_str), device=device) + else: + out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) kernel[(1, )](out) assert torch.all(out == out_ref) @@ -3823,6 +4052,7 @@ def _kernel(out): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("start", [0, 1, 7, 16]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -3846,6 +4076,7 @@ def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) for dtype_str in torch_dtypes @@ -3884,6 +4115,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): torch.testing.assert_close(output, reference_out) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) @pytest.mark.parametrize("mask_val", [True, False]) @@ -3913,6 +4145,7 @@ def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.co # Testing masked loads with a copy to shared memory. # FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device): @@ -3925,11 +4158,11 @@ def test_masked_load_shared_memory(dtype, device): in1 = torch.rand((M, K), dtype=dtype, device=device) in2 = torch.rand((K, N), dtype=dtype, device=device) - out = torch.zeros((M, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=torch.float32, device=device) @triton.jit - def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, M: tl.constexpr, N: tl.constexpr, + K: tl.constexpr): M_offsets = tl.arange(0, M) N_offsets = tl.arange(0, N) @@ -3949,13 +4182,20 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) - pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), - out.numel(), M=M, N=N, K=K) + _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], M=M, N=N, K=K) + if is_cpu() and (dtype == torch.float16 or dtype == torch.bfloat16): + # torch.matmul not implemented for Half float (float16) cpu + reference_out = torch.tensor(np.matmul(to_numpy(in1), to_numpy(in2))).to(torch.float32) + # f32_in1 = convert_float_to_float32(in1) + # f32_in2 = convert_float_to_float32(in2) + # reference_out = torch.matmul(f32_in1, f32_in2) + else: + reference_out = torch.matmul(in1, in2).to(torch.float32) - reference_out = torch.matmul(in1, in2) torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) def test_load_cache_modifier(cache, device): @@ -3978,14 +4218,14 @@ def _kernel(dst, src, CACHE: tl.constexpr): amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] - flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line] if cache == '' or cache == '.ca': - assert cg_cache_modifier_str not in global_load_line[0] + assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0]) if cache == '.cg': assert cg_cache_modifier_str in global_load_line[0] if cache == '.cv': - assert cv_cache_modifier_str in flat_load_line[0] + assert cv_cache_modifier_str in global_load_line[0] if is_cuda(): ptx = pgm.asm['ptx'] @@ -4000,6 +4240,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert 'ld.global.cg' not in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("N", [16, 10, 11, 1024]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4027,6 +4268,7 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("has_hints", [False, True]) def test_vectorization_hints(has_hints, device): @@ -4054,6 +4296,7 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): assert "ld.global.v4.b32" not in ptx +@pytest.mark.cpu @pytest.mark.interpreter def test_assume(device): @@ -4080,6 +4323,7 @@ def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) def test_store_cache_modifier(cache, device): @@ -4144,6 +4388,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert 'st.global.wt' in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) def test_store_eviction_policy(eviction_policy, device): @@ -4182,6 +4427,7 @@ def _impl(value=10): return value +@pytest.mark.cpu @pytest.mark.interpreter def test_default(device): value = 5 @@ -4207,6 +4453,7 @@ def _kernel(ret0, ret1, value=3): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_noop(device): @@ -4234,6 +4481,7 @@ def kernel(x): kernel[(1, )](x) +@pytest.mark.cpu @pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) @@ -4257,6 +4505,7 @@ def kernel(VALUE, X): # -------------------- +@pytest.mark.cpu @pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: @@ -4278,6 +4527,7 @@ def kernel(VALUE, X): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -4315,6 +4565,7 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -4328,6 +4579,7 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4345,6 +4597,7 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): @@ -4372,6 +4625,7 @@ def generate_kernel(shape_x, shape_z): np.testing.assert_equal(z, to_numpy(z_tri)) +@pytest.mark.cpu def test_reshape_err(device): @triton.jit @@ -4385,6 +4639,7 @@ def kernel(): assert "reshape" in str(exc_info.value) +@pytest.mark.cpu def test_trans_reshape(device): @triton.jit @@ -4411,8 +4666,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) k = kernel[(1, )](input, actual, shape[0], shape[1]) - assert k.asm['ttgir'].count( - 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + if not is_cpu(): + assert k.asm['ttgir'].count( + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) @@ -4446,6 +4702,7 @@ def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): tl.store(ptr + offsets, vec, mask=mask) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("type", ["inline", "noinline"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4477,6 +4734,7 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("if_type", [ "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", @@ -4537,6 +4795,7 @@ def _kernel(dst): _kernel[(1, )](dst=dst, num_warps=4) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) def test_unary_math(func_str, device): @@ -4770,6 +5029,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) def test_for_iv(lo, hi, iv, device): @@ -4789,6 +5049,7 @@ def kernel(Out, lo, hi, iv: tl.constexpr): assert out[0] == sum(range(lo, hi, iv)) +@pytest.mark.cpu @pytest.mark.interpreter def test_if_else(device): @@ -4814,6 +5075,7 @@ def kernel(Cond, TrueVal, FalseVal, Out): assert to_numpy(out)[0] == false_val[0] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["dynamic", "static"]) def test_if_return(mode, device): @@ -4873,6 +5135,7 @@ def add_fn_static_cond(x, cond: tl.constexpr): return x + 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "call_type", @@ -4942,6 +5205,7 @@ def kernel(Out, call_type: tl.constexpr): assert to_numpy(out)[0] == 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("_cond1", [True, False]) @pytest.mark.parametrize("_cond2", [True, False]) @@ -4984,34 +5248,38 @@ def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): assert out[0] == targets[(_cond1, _cond2, _cond3)] +@pytest.mark.cpu @pytest.mark.interpreter def test_while(device): @triton.jit - def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): - init_i = tl.load(InitI) + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ, BLOCKSIZE: tl.constexpr): + range = tl.arange(0, BLOCKSIZE) + init_i = tl.load(InitI + range) curr_i = init_i j = 0 # Check that init_i is not updated by the loop while j < tl.load(Bound): curr_i = curr_i + (j == tl.load(CutOff)) j += 1 - tl.store(OutInitI, init_i) - tl.store(OutI, curr_i) + tl.store(OutInitI + range, init_i) + tl.store(OutI + range, curr_i) tl.store(OutJ, j) - out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) - out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) - init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) - out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + size = 16 + out_i = to_triton(np.zeros((size, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((size, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((size, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((size, ), 0, dtype=np.int32), device=device) bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) - kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) - assert out_init_i[0] == init_i[0] - assert out_i[0] == init_i[0] + 1 + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j, size) + np.testing.assert_equal(to_numpy(out_init_i), to_numpy(init_i)) + np.testing.assert_equal(to_numpy(out_i), to_numpy(init_i + 1)) assert out_j[0] == bound[0] +@pytest.mark.cpu @pytest.mark.interpreter def test_nested_while(device): @@ -5029,6 +5297,7 @@ def nested_while(data, countPtr): assert data[0] == 40 +@pytest.mark.cpu def test_constexpr_if_return(device): # Reproducer for #4883, return statement in an if with a constexpr causes # errors when combined with non-trivial control flow graphs @@ -5072,7 +5341,9 @@ def kernel(Out): a = torch.empty((), device=device, dtype=torch.int32) h = kernel[(1, )](a) assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] - assert "poison" in h.asm["llir"], h.asm["llir"] + # xpu uses llvm.store, which in this case is removed by the optimizer + if not is_xpu(): + assert "poison" in h.asm["llir"], h.asm["llir"] # ----------------------- @@ -5080,6 +5351,7 @@ def kernel(Out): # ----------------------- +@pytest.mark.cpu def test_num_threads(device): if is_hip(): pytest.skip("test_num_threads is not supported in HIP") @@ -5138,6 +5410,9 @@ def kernel(Out): # TODO: backend should be tested separately layouts = [ + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -5146,6 +5421,14 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), ] @@ -5177,12 +5460,16 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("interm_layout", intermediate_layouts) -@pytest.mark.parametrize("dst_layout", layouts) -def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): if str(src_layout) == str(dst_layout): pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") if is_hip(): try: scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) @@ -5199,35 +5486,37 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): layouts = f""" #src = {src_layout} #dst = {dst_layout} + #smem = #ttg.shared_memory """ if interm_layout is None else f""" #src = {src_layout} #interm = {interm_layout} #dst = {dst_layout} + #smem = #ttg.shared_memory """ conversion = f""" - %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> - %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xi32, #src> - %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> - %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xf16, #src> + %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem> + %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src> + %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem> + %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src> - %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ ir = layouts + f""" - module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> @@ -5245,15 +5534,104 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) +layouts_3d = [ + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0, + k_width=1), +] + +shared_layout_3d = [ + SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), +] + + +@pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) +@pytest.mark.parametrize("shared_layout", shared_layout_3d) +@pytest.mark.parametrize("dist_layout", layouts_3d) +def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist> + %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist> + %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist> + %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist> + %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist> + %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + mma_pairs = [ [ MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), @@ -5301,7 +5679,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("mma_pair", mma_pairs) -def test_convertmma2mma(M, N, mma_pair, dtype, device): +def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): if is_hip(): pytest.skip("test_mma2mma is not supported in HIP") @@ -5312,44 +5690,33 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device): pytest.skip("Skip testing MMAv3 on devices with CC < 9") num_warps = np.cumprod(src_layout.warps_per_cta)[-1] - # TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA - warps_per_cta = src_layout.warps_per_cta - interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1], - [1, 1], [0, 1]) def do_test(src_layout, dst_layout): layouts = f""" #src = {src_layout} #dst = {dst_layout} - #interm = {interm} - """ - - conversion = f""" - %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> - %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ ir = layouts + f""" - module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32}} {{ tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> - %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> - %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #interm> - """ + conversion + f""" - %15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm> - %16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm> - %17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr, #interm>, tensor<{M}x{N}xi32, #interm> - tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr, #interm> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> tt.return }} }} @@ -5358,10 +5725,10 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) + temp_file = tmp_path / "test_convertmma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) assert torch.equal(z, x) @@ -5370,6 +5737,7 @@ def do_test(src_layout, dst_layout): do_test(mma_pair[1], mma_pair[0]) +@pytest.mark.cpu @pytest.mark.interpreter def test_load_scalar_with_mask(device): @@ -5388,6 +5756,7 @@ def kernel(Input, Index, Out, N: int): # This test is used to test our own PTX codegen for float16 and int16 conversions # maybe delete it later after ptxas has been fixed +@pytest.mark.cpu @pytest.mark.parametrize("dtype_str", ['float16', 'int16']) def test_ptx_cast(dtype_str, device): @@ -5433,6 +5802,9 @@ def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.co def f8_to_f16(x, dtype): + if is_cpu(): + assert (False and "Works as expected only for GPU") + @triton.jit def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) @@ -5457,19 +5829,18 @@ def matmul_kernel( # stride_cm, stride_cn, # BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # low_precision_acc: tl.constexpr, # - num_pipeline_stages: tl.constexpr = 3 # + num_stages: tl.constexpr = 3 # ): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - pid_m = pid % num_pid_m - pid_n = pid // num_pid_m + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): a = tl.load(a_ptrs) b = tl.load(b_ptrs) accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) @@ -5481,9 +5852,11 @@ def matmul_kernel( # tl.store(c_ptrs, accumulator) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) -@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), + (64, 64, 64)] if not is_cpu() else [(32, 32, 128), (32, 32, 32)]) @pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): @@ -5504,16 +5877,23 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s num_warps = 8 a = to_triton(A, device=device, dst_type=in_type_str) b = to_triton(B, device=device, dst_type=in_type_str) - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, - num_pipeline_stages=num_stages) + num_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) - th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) - th_b = f8_to_f16(torch_b, in_type_str) - ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if is_cpu() and 'float8' in in_type_str: + in_dtype = getattr(tl, in_type_str) + th_a = convert_float_to_float32(torch_a, in_dtype) + th_b = convert_float_to_float32(torch_b, in_dtype) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + else: + th_a = f8_to_f16(torch_a, in_type_str) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) else: @@ -5527,6 +5907,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("enable_fp_fusion", [False, True]) @pytest.mark.parametrize("default_override", [False, True]) def test_enable_fp_fusion(enable_fp_fusion, default_override, device): @@ -5548,10 +5929,9 @@ def mul_add(data): else: h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) - if not is_cuda(): - return - found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None - assert found_fma == enable_fp_fusion + if is_cuda(): + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion # ----------------------- @@ -5559,6 +5939,7 @@ def mul_add(data): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("dtype", ['float16', 'float32']) @pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) @pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) @@ -5597,6 +5978,7 @@ def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['float16', 'float32']) def test_clamp(dtype, device): @@ -5633,6 +6015,7 @@ def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexp # Test for symmetric clamp(x, -limit, limit), as it may go through optimized # codegen in the backends +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['float16', 'float32']) def test_clamp_symmetric(dtype, device): @@ -5668,6 +6051,7 @@ def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_static_range(device): @@ -5688,20 +6072,28 @@ def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): assert (Out == Acc).all(), (Out, Acc) +@pytest.mark.cpu @pytest.mark.interpreter def test_tl_range(device): if is_hip(): pytest.skip("test_tl_range is not supported in HIP") M, N, K = 64, 64, 512 - BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + if is_cpu(): + block_m, block_n, block_k = 32, 32, 64 + else: + block_m, block_n, block_k = M, N, 64 + BLOCK_M, BLOCK_N, BLOCK_K = block_m, block_n, block_k a = torch.randn((M, K), device=device, dtype=torch.float16) b = torch.randn((K, N), device=device, dtype=torch.float16) c = torch.empty((M, N), dtype=torch.float32, device=device) - pgm = matmul_kernel[ - 1, - ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K, 0, num_pipeline_stages=5) - ref_out = torch.matmul(a, b).to(torch.float32) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + pgm = matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), + c.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, 0, num_stages=5) + if is_cpu(): + # torch.matmul not implemented for Half float (float16) cpu + ref_out = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=torch.float32, device=device) + else: + ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. # Thus we use a higher tolerance @@ -5726,8 +6118,8 @@ def maxnreg_noinline2(X): tl.store(X, 0) +@pytest.mark.interpreter def test_maxnreg(device): - assert not is_interpreter(), "this test won't work with the interpreter" if not is_cuda(): pytest.skip('maxnreg only works on CUDA') @@ -5741,16 +6133,18 @@ def kernel(X): X = torch.empty(1, dtype=torch.int32, device=device) k = kernel[(1, )](X, maxnreg=42) - # Ensure that .maxnreg is set on the kernel function (marked with .entry) - # and not on either of the noinline functions (marked with .func). - try: - assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) - assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) - except AssertionError: - print("Failing ptx:\n", k.asm["ptx"]) - raise + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise +@pytest.mark.cpu @pytest.mark.interpreter def test_temp_var_in_loop(device): @@ -5785,6 +6179,7 @@ def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): assert (acc == out).all() +@pytest.mark.cpu @pytest.mark.interpreter def test_num_programs(device): # Assuming that the kernel is launched with a grid of (11, 21, 31) @@ -5896,6 +6291,30 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): torch.testing.assert_close(Z, X.sum().to(torch.int32)) +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + def test_side_effectful_scan(device): if device != "cuda": pytest.skip() @@ -5914,3 +6333,73 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): Z = torch.zeros_like(X) sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) + + +# stress test slice layout usages in reductions. +@pytest.mark.parametrize("in_shape, perm, red_dims", [ + ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), + ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), +]) +def test_chained_reductions(in_shape, perm, red_dims, device): + + @triton.jit + def kernel(In, Out, # + dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, + perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, + perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): + idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) + idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) + vals = tl.load(In + idx) + vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) + r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) + st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) + tl.store(Out + st_idx, r) + + input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) + temp = torch.permute(input, perm).contiguous() + ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) + result = torch.empty_like(ref) + kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], + perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) + + assert torch.all(ref == result) + + +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([4, 4], [8, 4], 0), + ([128, 64], [256, 64], 0), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis): + + @triton.jit + def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], + src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + + return output + + src = torch.randn(src_shape, device='cuda') + indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) diff --git a/python/test/unit/language/test_libdevice.py b/python/test/unit/language/test_libdevice.py new file mode 100644 index 000000000000..2573aef5c5aa --- /dev/null +++ b/python/test/unit/language/test_libdevice.py @@ -0,0 +1,57 @@ +import pytest +import torch + +import triton +import triton.language as tl + +from triton.language.extra import libdevice +from triton.language.extra.libdevice import fast_dividef as my_fast_dividef + + +@pytest.mark.parametrize("dtype_str", ["float32", "float64"]) +@pytest.mark.parametrize( + "libdevice_fn, torch_special_fn", + [ + ("j0", "bessel_j0"), + ("j1", "bessel_j1"), + ("y0", "bessel_y0"), + ("y1", "bessel_y1"), + ("cyl_bessel_i0", "i0"), + ("cyl_bessel_i1", "i1"), + ], +) +def test_bessel(dtype_str, libdevice_fn, torch_special_fn, device): + SIZE = 128 + dtype = getattr(torch, dtype_str) + + x = torch.randn((SIZE, ), dtype=dtype, device=device) + y_exp = torch.empty((SIZE, ), dtype=dtype, device=device) + y_ref = getattr(torch.special, torch_special_fn)(x) + + @triton.jit + def kernel(in_p, out_p, fn: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(in_p + off) + res = getattr(libdevice, fn)(x) + tl.store(out_p + off, res) + + kernel[(1, )](x, y_exp, fn=libdevice_fn, SIZE=SIZE, num_warps=4, num_ctas=1) + + torch.testing.assert_close(y_ref, y_exp, equal_nan=True) + + +def test_libdevice_rename(device): + # mark the import as used by this test + _ = my_fast_dividef + + @triton.jit + def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.load(in_ptr + offsets) + tl.store(out_ptr + offsets, data) + + BLOCK_SIZE = 256 + inp = torch.randn(BLOCK_SIZE, device=device) + out = torch.empty_like(inp) + + triton_copy[(1, )](inp, out, BLOCK_SIZE) diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py index fa5f34290b49..6b279ab106d4 100644 --- a/python/test/unit/language/test_pipeliner.py +++ b/python/test/unit/language/test_pipeliner.py @@ -5,23 +5,9 @@ import triton import triton.language as tl import triton.tools.experimental_descriptor +from test_core import is_cpu - -def is_cuda(): - return triton.runtime.driver.active.get_current_target().backend == "cuda" - - -def is_hopper(): - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 - - -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" - - -def is_hip_mi200(): - target = triton.runtime.driver.active.get_current_target() - return target.backend == 'hip' and target.arch == 'gfx90a' +from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200 def check_capabilities(): @@ -175,17 +161,17 @@ def mxfp_to_bf16_kernel( x_bf16 = x_f8.to(tl.bfloat16) else: # e2m1 - em0 = x & 0x70 - em1 = x & 0x7 - x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) - x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + em0 = x & 0x7 + em1 = x & 0x70 + x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4) + x1 = (em1.to(tl.uint16) << (2)) | ((x & 0x80).to(tl.uint16) << (8)) # Three cases: # 1) x is normal and non-zero: Correct bias - x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) - x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1) # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 - x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) - x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1) # 3) x is zero, do nothing x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) # Multiplication preserves infs and NaNs in x_bf16 @@ -229,8 +215,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): @pytest.mark.parametrize("scale", [True, False]) def test_pipeline_matmul(scale, device): check_capabilities() - if scale and not is_cuda(): - pytest.skip("NYI: scale_dot just implemented in CUDA") + if scale and not (is_cuda() or is_hip_cdna()): + pytest.skip("NYI: scale_dot just implemented in CUDA/HIP") M, N, K = 512, 512, 128 BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 NUM_STAGES = 4 @@ -282,7 +268,11 @@ def test_pipeline_matmul(scale, device): if scale: ref_out = dot_scale_ref(a, scale_a, b, a_type, b_type) else: - ref_out = torch.matmul(a, b) + if is_cpu(): + ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16) + else: + ref_out = torch.matmul(a, b) + # Bigger tolerance for AMD MI200 devices. # MI200 devices use reduced precision fp16 and bf16 and flush input and # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices @@ -292,26 +282,26 @@ def test_pipeline_matmul(scale, device): if is_cuda(): ttgir = handler.asm["ttgir"] if use_tma: - assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found" + assert ttgir.count("ttng.async_tma_copy_global_to_local") != 0, "async tma copy not found" assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match" # a_tma, b_tma, output_tma, barriar - assert ttgir.count("triton_gpu.local_alloc") == 4, "alloc number not match" - assert ttgir.count("triton_nvidia_gpu.barrier_expect") != 0, "barrier_expect not found" - assert ttgir.count("triton_nvidia_gpu.wait_barrier") != 0, "wait_barrier not found" - assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + assert ttgir.count("ttg.local_alloc") == 4, "alloc number not match" + assert ttgir.count("ttng.barrier_expect") != 0, "barrier_expect not found" + assert ttgir.count("ttng.wait_barrier") != 0, "wait_barrier not found" + assert ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" else: # 1. check async - assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + assert ttgir.count("ttg.async_copy_global_to_local") != 0, "async copy not found" # 2. check number of stages assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" # 3. check alloc - assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + assert ttgir.count("ttg.local_alloc") == 2, "alloc number not match" # 4. check dot cc = torch.cuda.get_device_capability() if cc[0] >= 9: - ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" else: - ttgir.count("triton_gpu.dot") != 0, "dot not found" + ttgir.count("ttg.dot") != 0, "dot not found" def test_pipeline_vecadd(device): @@ -330,11 +320,11 @@ def test_pipeline_vecadd(device): if is_cuda(): ttgir = handler.asm["ttgir"] # 1. check async - assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + assert ttgir.count("ttg.async_copy_global_to_local") != 0, "async copy not found" # 2. check number of stages assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" # 3. check alloc - assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + assert ttgir.count("ttg.local_alloc") == 2, "alloc number not match" @pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3]) diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py index b3392d4750c4..2938773867cc 100644 --- a/python/test/unit/language/test_standard.py +++ b/python/test/unit/language/test_standard.py @@ -3,7 +3,7 @@ import torch import triton.language as tl -from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu # --------------- # test maximum/minimum ops @@ -26,7 +26,8 @@ def test_maximum_minium(dtype, op, device): @pytest.mark.interpreter -@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize( + "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]]) @pytest.mark.parametrize("descending", [False, True]) @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_sort(M, N, descending, dtype_str, device): @@ -54,7 +55,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr @pytest.mark.interpreter -@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize( + "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]]) @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_flip(M, N, dtype_str, device): diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 193895757d32..de7521446044 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -6,6 +6,8 @@ import pytest +import triton + dir_path = os.path.dirname(os.path.realpath(__file__)) print_path = os.path.join(dir_path, "print_helper.py") torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] @@ -15,12 +17,18 @@ def is_interpreter(): return os.environ.get('TRITON_INTERPRET', '0') == '1' +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + # TODO: Print with multiple operands +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("func_type, data_type", [(fn, data_type) - for fn in ["device_print", "device_print_scalar"] + for fn in ["device_print", "device_print_scalars"] for data_type in torch_types] + [ ("print", "int32"), ("static_print", "int32"), @@ -37,6 +45,9 @@ def is_interpreter(): ("device_print_uint", "uint32"), ]) def test_print(func_type: str, data_type: str, device: str): + if is_cpu() and (data_type == "float16" or func_type in ["device_print_pointer", "device_print_large"]): + pytest.skip("test_print for float16/pointer/large are not yet supported on CPU.") + proc = subprocess.run( [sys.executable, print_path, "test_print", func_type, data_type, device], capture_output=True, @@ -56,6 +67,11 @@ def test_print(func_type: str, data_type: str, device: str): # Constant for testing the printing of scalar values SCALAR_VAL = 42 + # TODO: Consider cases for signedness, overflow, and multiple pids (non-determinism). + if is_cpu(): + _check_cpu_print(proc.stdout.decode("UTF-8"), func_type, data_type, N, SCALAR_VAL) + return + # Format is # pid (, , ) idx (, , ...) (operand ) expected_lines = Counter() @@ -66,11 +82,15 @@ def test_print(func_type: str, data_type: str, device: str): if data_type.startswith("float"): line += ".000000" expected_lines[line] = 1 - elif func_type == "device_print_scalar": + elif func_type == "device_print_scalars": line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}" if data_type.startswith("float"): line += ".000000" expected_lines[line] = N + line = f"pid (0, 0, 0) idx () int: {SCALAR_VAL}" + expected_lines[line] = N + line = "pid (0, 0, 0) idx () float: 3.140000" + expected_lines[line] = N elif func_type == "device_print_negative": for i in range(N): line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}" @@ -102,8 +122,11 @@ def test_print(func_type: str, data_type: str, device: str): for i in range(N): expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + cpu_gpu_msg = "Both CPU and GPU backends are available. Using the GPU backend." actual_lines = Counter() for line in outs: + if line == cpu_gpu_msg: + continue # Trim the exact pointer address in the output--they can change per run. line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line actual_lines[line] += 1 @@ -115,3 +138,90 @@ def test_print(func_type: str, data_type: str, device: str): continue print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') assert all(delta == 0 for delta in diff.values()) + + +def _check_cpu_print(actual, func_type, data_type, N, SCALAR_VAL): + # An example of a tensor printing is like: + # (0, 0, 0) x: [ 0, 1, 2, 3, 4, 5, 6, 7, + # 8, 9, 10, 11, 12, 13, 14, 15, + # ... + # 120, 121, 122, 123, 124, 125, 126, 127] + PID_PREFIX = "(0, 0, 0)" + NEWLINE_WITH_PADDING = "\n" + " " * (len(PID_PREFIX + " x: [")) + if func_type in ("print", "device_print", "device_print_uint"): + expected = PID_PREFIX + " x: [" + for i in range(N): + offset = (1 << 31) if data_type == "uint32" else 0 + expected += f"{i + offset:3}" + if data_type.startswith("float"): + expected += ".0000" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "device_print_scalars": + expected = f"{PID_PREFIX} x: {SCALAR_VAL}" + if data_type.startswith("float"): + expected += ".000000" + expected += f"\n{PID_PREFIX} int: {SCALAR_VAL}" + expected += f"\n{PID_PREFIX} float: 3.140000" + elif func_type == "device_print_negative": + expected = PID_PREFIX + " x: [" + for i in range(N): + expected += f"{-i:4}" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "device_print_hex": + expected = PID_PREFIX + " x: [" + for i in range(N): + if data_type.endswith("8"): + expected += f"0x{i:02x}" + elif data_type.endswith("16"): + expected += f"0x{i:04x}" + elif data_type.endswith("32"): + expected += f"0x{i:08x}" + elif data_type.endswith("64"): + expected += f"0x{i:016x}" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "static_print": + expected = f" int32[constexpr[{N}]]" + elif func_type == "no_arg_print": + expected = f"{PID_PREFIX}: 0" + elif func_type == "print_no_arg": + expected = f"{PID_PREFIX} no arg" + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + expected = "" + for k in range(2): + expected += PID_PREFIX + ": [" + for i in range(N): + expected += f"{i:3}" if k == 0 else "1" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += "\n" + " " * (len(PID_PREFIX + ": [")) + else: + expected += " " + expected += "]" + if k == 0: + expected += "\n" + + # Ignore the trailing new line. + assert actual[:-1] == expected diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index 456ebf113792..149520f14b20 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -4,6 +4,8 @@ import triton.language as tl import pytest +from triton._internal_testing import is_cuda + def do_bench(kernel_call, quantiles): return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1) @@ -11,6 +13,10 @@ def do_bench(kernel_call, quantiles): @pytest.mark.parametrize('use_cuda_graph', [False, True]) def test_kwargs(use_cuda_graph: bool, device: str): + + if not is_cuda() and use_cuda_graph: + pytest.skip("Use cuda graph without cuda looks strange") + M, N = 1024, 16 src = torch.randn(M * N, device=device) dst = torch.empty(M * N, device=device) @@ -32,7 +38,8 @@ def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLO assert len(_kernel.cache) == 2 -def test_restore(device): +@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True]) +def test_restore(pass_kwargs_to_kernel, device): N = 1024 src = torch.zeros(N, device=device) @@ -46,7 +53,10 @@ def _kernel(src, N, BLOCK_SIZE: tl.constexpr): tl.store(src + offsets, x, mask=offsets < N) grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) - _kernel[grid](src, N) + if pass_kwargs_to_kernel: + _kernel[grid](src=src, N=N) + else: + _kernel[grid](src, N) triton.testing.assert_close(src, torch.ones_like(src)) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a45cb3f888ca..86faff6674d1 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,14 +1,14 @@ import importlib.util import itertools import shutil -import tempfile +import pathlib import pytest import torch import triton import triton.language as tl -from triton.runtime.jit import JITFunction +from triton.runtime.jit import JITFunction, get_device_key from triton._internal_testing import is_hip @@ -129,17 +129,15 @@ def test_combine_fn_change(): seen_keys.add(key) -def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: - f.write(('# extra line\n' * num_extra_lines) + code) - f.flush() - spec = importlib.util.spec_from_file_location("module.name", f.name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) +def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) return module -def test_changed_line_numbers_invalidate_cache(): +def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path): from textwrap import dedent code = dedent(""" import triton @@ -147,10 +145,12 @@ def test_changed_line_numbers_invalidate_cache(): def test_kernel(i): i = i + 1 """) - orig_mod = write_and_load_module(code, 0) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) orig_cache_key = orig_mod.test_kernel.cache_key - updated_mod = write_and_load_module(code, 1) + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) updated_cache_key = updated_mod.test_kernel.cache_key assert orig_cache_key != updated_cache_key @@ -194,12 +194,12 @@ def kernel(X, i: tl.int32): x = torch.empty(1, dtype=torch.int32, device=device) - device = getattr(torch, device).current_device() + device_key = get_device_key() kernel[(1, )](x, 1) kernel[(1, )](x, 8) kernel[(1, )](x, 16) kernel[(1, )](x, 17) - assert len(kernel.cache[device]) == 3 + assert len(kernel.cache[device_key]) == 3 GLOBAL_DEFAULT_ARG = 1 @@ -222,7 +222,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): kernel[(1, )](x) assert x == torch.ones_like(x) - device = getattr(torch, device).current_device() + device = get_device_key() assert len(kernel.cache[device]) == 1 @@ -415,7 +415,7 @@ def kernel_add(a, b, o, N: tl.constexpr): torch.randn(32, dtype=torch.float32, device=device), 32, ] - device = getattr(torch, device).current_device() + device = get_device_key() assert len(kernel_add.cache[device]) == 0 kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 @@ -431,6 +431,9 @@ def test_jit_debug(device) -> None: def kernel(tmp): tl.device_assert(tl.load(tmp) == 1, "tmp == 1") + if device == "cpu": + pytest.skip('Device Assert is not yet supported on CPU') + device = getattr(torch, device).current_device() tmp = torch.tensor([1], dtype=torch.int32, device=device) assert len(kernel.cache[device]) == 0 @@ -454,7 +457,7 @@ def test_jit_noinline(device) -> None: def kernel_add_device(a, b, o, N: tl.constexpr): add_fn(a, b, o, N) - device = getattr(torch, device).current_device() + device = get_device_key() assert len(kernel_add_device.cache[device]) == 0 kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 @@ -498,7 +501,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): tl.device_assert(idx < 32, "idx < 32") tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) - device = getattr(torch, device).current_device() + device = get_device_key() # get the serialized specialization data specialization_data = None @@ -572,7 +575,8 @@ def compiled_hook(*args, **kwargs): kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) assert specialization_data is not None and specialization_data_compiled == specialization_data assert is_warmup is True - assert key in kernel_add.cache[getattr(torch, device).current_device()] + device_key = get_device_key() + assert key in kernel_add.cache[device_key] @pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) diff --git a/python/test/unit/runtime/test_driver.py b/python/test/unit/runtime/test_driver.py index de00082f52c0..9bd51cc2b8cb 100644 --- a/python/test/unit/runtime/test_driver.py +++ b/python/test/unit/runtime/test_driver.py @@ -1,6 +1,9 @@ import sys +from concurrent.futures import ThreadPoolExecutor +import torch import triton +import triton.language as tl def test_is_lazy(): @@ -12,3 +15,27 @@ def test_is_lazy(): assert triton.runtime.driver.active._obj is None utils = triton.runtime.driver.active.utils # noqa: F841 assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) + + +def test_kernel_in_thread(device): + # Test calling in a new thread sets a valid device context + buf = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device) + + @triton.jit + def _kernel(P, BLOCK: tl.constexpr): + pid = tl.program_id(0).to(tl.int64) + offset = pid * BLOCK + tl.arange(0, BLOCK) + + p = tl.load(P + offset) + tl.store(P + offset, p) + + def call_triton(): + N = buf.numel() + grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]), ) + _kernel[grid](buf, BLOCK=1024) + getattr(torch, device).synchronize() + + call_triton() + with ThreadPoolExecutor(1) as pool: + future = pool.submit(call_triton) + future.result() diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 02777923303a..334d5d635f67 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -7,6 +7,7 @@ from triton.compiler import ASTSource target = triton.runtime.driver.active.get_current_target() +start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn' def compile_fn(attrs): @@ -27,8 +28,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: config = AttrsDescriptor.from_hints({i: 16 for i in range(4)}) - multiprocessing.set_start_method('fork') - proc = multiprocessing.Process(target=compile_fn, args=(config, )) + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_fn, args=(config, )) proc.start() proc.join() assert proc.exitcode == 0 @@ -49,8 +50,8 @@ def kernel_dot(Z): def test_compile_in_forked_subproc(fresh_triton_cache) -> None: config = AttrsDescriptor.from_hints({0: 16}) - assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_fn_dot, args=(config, )) + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_fn_dot, args=(config, )) proc.start() proc.join() assert proc.exitcode == 0 @@ -92,8 +93,8 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None: # stage 2.p shutil.rmtree(fresh_triton_cache) - assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, )) + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, )) # stage 3.c proc.start() diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index 05bf1fe4940e..8ea6212020ec 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -1,30 +1,49 @@ -import os import pytest import torch import triton.language as tl import triton -@pytest.mark.parametrize('cond, opt_flag, env_var', [ - (cond, opt_flag, env_var) for cond in [True, False] \ - for opt_flag in [True, False] \ - for env_var in [True, False]\ -]) + +@pytest.mark.parametrize('cond', [True, False]) +@pytest.mark.parametrize('opt_flag', [True, False, None]) +@pytest.mark.parametrize('env_var', [True, False]) +@pytest.mark.parametrize('jit_flag', [True, False]) @pytest.mark.forked -def test_device_assert(cond, opt_flag, env_var, device): - os.environ['TRITON_DEBUG'] = str(int(env_var)) +def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device): + monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) torch.zeros([1], dtype=torch.int32, device=device) - @triton.jit + @triton.jit(debug=jit_flag) def _kernel(COND: tl.constexpr): tl.device_assert(COND, 'test') - if not cond and (opt_flag or env_var): + is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) + + kwargs = {} + if opt_flag is not None: + kwargs["debug"] = opt_flag + + if not cond and is_debug: with pytest.raises(RuntimeError): - _kernel[(1, )](cond, debug=opt_flag) + _kernel[(1, )](cond, **kwargs) getattr(torch, device).synchronize() return - _kernel[(1, )](cond, debug=opt_flag) + _kernel[(1, )](cond, **kwargs) + getattr(torch, device).synchronize() + + +def test_device_assert_barrier(monkeypatch, device): + monkeypatch.setenv("TRITON_DEBUG", "1") + tensor = torch.zeros([16], dtype=torch.int32, device=device) + + @triton.jit + def _kernel(in_ptr0): + xindex = tl.arange(0, 8) + tmp0 = tl.load(in_ptr0 + xindex) + tl.device_assert(tmp0 < 1) + + _kernel[(1, )](tensor) getattr(torch, device).synchronize() diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 54072a829c83..6646d94f50a8 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -162,16 +162,6 @@ def kernel_pipe_error(in_ptr, out_ptr): if tl.max(val) > 0: k += 1 - with enable_remark_context(): - triton.compile( - triton.compiler.ASTSource( - fn=kernel_pipe_error, - signature={"in_ptr": "*fp32", "out_ptr": "*fp32"}, - constants={}, - ), - options={"cluster_dims": (1, 1, 1)}, - ) - - _, err = capfd.readouterr() - - assert "operation scheduled before its operands" in err, "expect swp op remark" + i = torch.empty(64 * 64, dtype=torch.float32).cuda() + o = torch.empty(64 * 64, dtype=torch.float32).cuda() + kernel_pipe_error[(1, )](i, o) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 935c495fb333..d80c79cf61cb 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -426,7 +426,7 @@ def test_compile_link_autotune_matmul(): def test_ttgir_to_ptx(): src = """ -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { tt.return } diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py new file mode 100644 index 000000000000..48ed90d8b217 --- /dev/null +++ b/python/test/unit/tools/test_irsource.py @@ -0,0 +1,92 @@ +import pathlib +import triton +from triton.compiler import IRSource, make_backend +from triton._C.libtriton import ir + +target = triton.runtime.driver.active.get_current_target() +backend = make_backend(target) + + +def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: + ''' + Tests that MLIR attributes are parsed correctly from input ttir/ttgir. + + Checks for the following: + 1. Name and type signature are parsed correctly + 2. _get_num_warps_from_ir_str() works + 3. tt.nv_tma_desc attribute is parsed correctly + ''' + + sample_ttgir = r""" +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32}, + %desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { + tt.return + } +} +""" + temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir" + temp_file.write_text(sample_ttgir) + context = ir.context() + src = IRSource(str(temp_file), context, backend) + + # check name and type signature + # should match ty_to_cpp(...) + assert src.signature == \ + {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ + 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} + assert src.name == "@matmul_kernel" + + # check num warps + assert src.parse_options()['num_warps'] == 8 + + sample_ttgir_vector_add = r""" + #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) + attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } + } + """ + temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir" + temp_file.write_text(sample_ttgir_vector_add) + context = ir.context() + src = IRSource(str(temp_file), context, backend) + + # now test compilation + triton.compile(str(temp_file), target=target) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 031c58fb16ac..1c7058567a40 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,5 +1,5 @@ """isort:skip_file""" -__version__ = '3.0.0' +__version__ = '3.2.0' # --------------------------------------- # Note: import order is significant here. @@ -20,6 +20,7 @@ from .runtime.jit import jit from .compiler import compile, CompilationError from .errors import TritonError +from .runtime._allocation import set_allocator from . import language from . import testing @@ -44,6 +45,7 @@ "OutOfResources", "reinterpret", "runtime", + "set_allocator", "TensorWrapper", "TritonError", "testing", diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index f8909f7c0587..6d797efcc7ef 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -4,6 +4,7 @@ import torch import triton import triton.language as tl +from triton.backends.nvidia.compiler import _path_to_binary import pytest from numpy.random import RandomState @@ -14,6 +15,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] integral_dtypes = int_dtypes + uint_dtypes float_dtypes = ['float16', 'float32', 'float64'] +float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16'] dtypes = integral_dtypes + float_dtypes dtypes_with_bfloat16 = dtypes + ['bfloat16'] torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] @@ -35,16 +37,44 @@ def is_cuda(): return False if target is None else target.backend == "cuda" +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + def is_hip(): target = get_current_target() return False if target is None else target.backend == "hip" +def is_hip_mi200(): + target = get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def is_hip_mi300(): + target = get_current_target() + return target.backend == 'hip' and target.arch in ('gfx940', 'gfx941', 'gfx942') + + +def is_hip_cdna(): + return is_hip_mi200() or is_hip_mi300() + + +def is_xpu(): + target = get_current_target() + return False if target is None else target.backend == "xpu" + + def get_arch(): target = get_current_target() return "" if target is None else str(target.arch) +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same @@ -116,8 +146,21 @@ def to_numpy(x): raise ValueError(f"Not a triton-compatible tensor: {x}") -def supports_tma(): - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 +def supports_tma(byval_only=False): + if not is_cuda(): + return False + _, cuda_version = _path_to_binary("ptxas") + min_cuda_version = (12, 0) if byval_only else (12, 3) + cuda_version_tuple = tuple(map(int, cuda_version.split("."))) + assert len(cuda_version_tuple) == 2, cuda_version_tuple + return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version + + +def tma_skip_msg(byval_only=False): + if byval_only: + return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)" + else: + return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)" -requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)") +requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg()) diff --git a/python/triton/_utils.py b/python/triton/_utils.py new file mode 100644 index 000000000000..ca60c8c3cbca --- /dev/null +++ b/python/triton/_utils.py @@ -0,0 +1,22 @@ +from typing import Tuple, List, Any + +# Poor man's PyTree + + +def list_list_flatten(x: List[List[Any]]) -> Tuple[List[int], List[Any]]: + spec = [] + flat = [] + for l in x: + spec.append(len(l)) + flat.extend(l) + return spec, flat + + +def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: + ret = [] + idx = 0 + for size in spec: + ret.append(flat[idx:idx + size]) + idx += size + assert idx == len(flat) + return ret diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..738ea2fef8bc 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -28,6 +28,7 @@ def _find_concrete_subclasses(module, base_class): @dataclass(frozen=True) class Backend: + name: str = "" compiler: BaseBackend = None driver: DriverBase = None @@ -42,7 +43,7 @@ def _discover_backends(): continue compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) driver = _load_module(name, os.path.join(root, name, 'driver.py')) - backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + backends[name] = Backend(name, _find_concrete_subclasses(compiler, BaseBackend), _find_concrete_subclasses(driver, DriverBase)) return backends diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index f2ba8eac807f..93ff051c5f39 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -2,10 +2,11 @@ import re import hashlib import subprocess +import sysconfig -from abc import ABCMeta, abstractmethod, abstractclassmethod -from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union, Set from types import ModuleType # Table that associates strings to AttrsDescriptor (sub)classes. @@ -22,7 +23,7 @@ def register_descriptor(cls): return cls -@register_descriptor +@dataclass class AttrsDescriptor: """ This class handles compile-time properties for specific function parameters. @@ -51,7 +52,10 @@ class AttrsDescriptor: `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant """ - __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + #__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + arg_properties: Dict = field(default_factory=dict) + property_values: Dict = field(default_factory=dict) + constant_properties: Set = field(default_factory=set) def __init__(self, params=None, values=None): """ @@ -210,6 +214,9 @@ def get_property_key(val, align): return "1" return "N" + def __repr__(self): + return f"AttrsDescriptor.from_dict({self.to_dict()!r})" + @dataclass(frozen=True) class GPUTarget(object): @@ -228,22 +235,23 @@ def __init__(self, target: GPUTarget) -> None: @staticmethod def _path_to_binary(binary: str): + binary += sysconfig.get_config_var("EXE") base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(base_dir, "third_party", "cuda", "bin", binary), ] - for p in paths: - bin = p.split(" ")[0] - if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + for path in paths: + if os.path.exists(path) and os.path.isfile(path): + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) if result is not None: version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: - return p, version.group(1) + return path, version.group(1) raise RuntimeError(f"Cannot find {binary}") - @abstractclassmethod + @classmethod + @abstractmethod def supports_target(target: GPUTarget): raise NotImplementedError diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index bbe8c047c6d1..a05efd7e0807 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,7 @@ -from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict from .errors import CompilationError -__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] +__all__ = [ + "compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", + "LazyDict" +] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 9ab3f4bc0c79..1c39d778ec0f 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1,7 +1,6 @@ import ast import inspect import re -import sys import warnings import os import textwrap @@ -15,6 +14,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +from triton._utils import list_list_flatten, list_list_unflatten def mangle_ty(ty): @@ -218,7 +218,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n module_name = getattr(v, "__module__", "") if module_name in module_map: - self.gscope[k] = getattr(module_map[module_name], k) + self.gscope[k] = getattr(module_map[module_name], v.__name__) else: self.gscope[k] = v @@ -584,21 +584,17 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): # update block arguments names = [] - ret_types = [] - ir_ret_types = [] # variables in livein whose value is updated in `if` for name in liveins: # check type for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: if name in defs: - assert defs[name].type == liveins[name].type, \ - f'initial value for `{name}` is of type {liveins[name].type}, '\ - f'but the {block_name} block redefines it as {defs[name].type}' + type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721 + assert type_equal and defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name]}, '\ + f'but the {block_name} block redefines it as {defs[name]}' if name in then_defs or name in else_defs: names.append(name) - ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) - ir_ret_types.append(then_defs[name].handle.get_type() if name in - then_defs else else_defs[name].handle.get_type()) # variable defined in then but not in else if name in then_defs and name not in else_defs: else_defs[name] = liveins[name] @@ -610,16 +606,17 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): for name in sorted(then_defs.keys() & else_defs.keys()): if name in names: continue - then_ty = then_defs[name].type - else_ty = else_defs[name].type - assert then_ty == else_ty, \ + then_val = then_defs[name] + then_ty = then_val.type + else_val = else_defs[name] + else_ty = else_val.type + type_equal = type(then_val) == type(else_val) # noqa: E721 + assert type_equal and then_ty == else_ty, \ f'Mismatched type for {name} between then block ({then_ty}) '\ f'and else block ({else_ty})' names.append(name) - ret_types.append(then_ty) - ir_ret_types.append(then_defs[name].handle.get_type()) - return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + return then_defs, else_defs, then_block, else_block, names def visit_if_top_level(self, cond, node): with enter_sub_region(self) as sr: @@ -630,27 +627,35 @@ def visit_if_top_level(self, cond, node): self.builder.set_insertion_point_to_end(ip_block) self.builder.create_cond_branch(cond.handle, then_block, else_block) # visit then and else blocks - then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + then_defs, else_defs, then_block, else_block, names = \ self.visit_then_else_blocks(node, liveins, then_block, else_block) # create basic-block after conditional endif_block = self.builder.create_block() # then terminator self.builder.set_insertion_point_to_end(then_block) assert not then_block.has_terminator(), f"{then_block}" - self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + then_handles = [then_defs[n]._flatten_ir() for n in names] + then_handles_spec, then_handles_flat = list_list_flatten(then_handles) + self.builder.create_branch(endif_block, then_handles_flat) # else terminator self.builder.set_insertion_point_to_end(else_block) assert not else_block.has_terminator(), f"{else_block}" - self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - for ty in ir_ret_types: + else_handles = [else_defs[n]._flatten_ir() for n in names] + _, else_handles_flat = list_list_flatten(else_handles) + self.builder.create_branch(endif_block, else_handles_flat) + for then_h, else_h in zip(then_handles_flat, else_handles_flat): + ty = then_h.get_type() + assert ty == else_h.get_type() endif_block.add_argument(ty) # change block self.builder.set_insertion_point_to_start(endif_block) # update value - for i, name in enumerate(names): - new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) - self.set_value(name, new_tensor) + res_handles_flat = [endif_block.arg(i) for i in range(len(then_handles_flat))] + res_handles = list_list_unflatten(then_handles_spec, res_handles_flat) + for name, handles in zip(names, res_handles): + new_value = then_defs[name]._unflatten_ir(handles) + self.set_value(name, new_value) # TODO: refactor def visit_if_scf(self, cond, node): @@ -659,26 +664,32 @@ def visit_if_scf(self, cond, node): ip, last_loc = self._get_insertion_point_and_loc() then_block = self.builder.create_block() else_block = self.builder.create_block() if node.orelse else None - then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + then_defs, else_defs, then_block, else_block, names = \ self.visit_then_else_blocks(node, liveins, then_block, else_block) # create if op + then_handles = [then_defs[n]._flatten_ir() for n in names] + then_handles_spec, then_handles_flat = list_list_flatten(then_handles) self._set_insertion_point_and_loc(ip, last_loc) - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + if_op = self.builder.create_if_op([h.get_type() for h in then_handles_flat], cond.handle, True) then_block.merge_block_before(if_op.get_then_block()) self.builder.set_insertion_point_to_end(if_op.get_then_block()) if len(names) > 0: - self.builder.create_yield_op([then_defs[n].handle for n in names]) + self.builder.create_yield_op(then_handles_flat) if not node.orelse: else_block = if_op.get_else_block() else: else_block.merge_block_before(if_op.get_else_block()) self.builder.set_insertion_point_to_end(if_op.get_else_block()) if len(names) > 0: - self.builder.create_yield_op([else_defs[n].handle for n in names]) + else_handles = [else_defs[n]._flatten_ir() for n in names] + _, else_handles_flat = list_list_flatten(else_handles) + self.builder.create_yield_op(else_handles_flat) # update values - for i, name in enumerate(names): - new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) - self.set_value(name, new_tensor) + res_handles_flat = [if_op.get_result(i) for i in range(len(then_handles_flat))] + res_handles = list_list_unflatten(then_handles_spec, res_handles_flat) + for name, handles in zip(names, res_handles): + new_value = then_defs[name]._unflatten_ir(handles) + self.set_value(name, new_value) def visit_If(self, node): cond = self.visit(node.test) @@ -804,7 +815,7 @@ def visit_UnaryOp(self, node): def _verify_loop_carried_variable(self, name, loop_val, live_val): assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop' - assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type' + assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type' assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ f'Loop-carried variable {name} has initial type {live_val.type} '\ f'but is re-assigned to {loop_val.type} in loop! '\ @@ -827,7 +838,6 @@ def visit_While(self, node): # collect loop-carried values names = [] - ret_types = [] init_args = [] for name in loop_defs: if name in liveins: @@ -838,32 +848,37 @@ def visit_While(self, node): # these are loop-carried values names.append(name) - ret_types.append(loop_val.type) init_args.append(live_val) + init_handles = [a._flatten_ir() for a in init_args] + init_handles_spec, init_handles_flat = list_list_flatten(init_handles) + init_tys_flat = [h.get_type() for h in init_handles_flat] self._set_insertion_point_and_loc(ip, last_loc) - while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], - [arg.handle for arg in init_args]) + while_op = self.builder.create_while_op(init_tys_flat, init_handles_flat) # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before(), - [ty.to_ir(self.builder) for ty in ret_types]) + before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys_flat) self.builder.set_insertion_point_to_start(before_block) - for i, name in enumerate(names): - self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) - self.local_defs[name] = self.lscope[name] + block_args_flat = [before_block.arg(i) for i in range(len(init_handles_flat))] + block_args = list_list_unflatten(init_handles_spec, block_args_flat) + for name, init_val, arg_handles in zip(names, init_args, block_args): + val = init_val._unflatten_ir(arg_handles) + self.lscope[name] = val + self.local_defs[name] = val cond = self.visit(node.test) self.builder.set_insertion_point_to_end(before_block) # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... - self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + self.builder.create_condition_op(cond.handle, block_args_flat) # merge the loop body - after_block = self.builder.create_block_with_parent(while_op.get_after(), - [ty.to_ir(self.builder) for ty in ret_types]) + after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys_flat) # generate loop body self.builder.set_insertion_point_to_start(after_block) - for i, name in enumerate(names): - self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) - self.local_defs[name] = self.lscope[name] + block_args_flat = [after_block.arg(i) for i in range(len(init_handles_flat))] + block_args = list_list_unflatten(init_handles_spec, block_args_flat) + for name, init_val, arg_handles in zip(names, init_args, block_args): + val = init_val._unflatten_ir(arg_handles) + self.lscope[name] = val + self.local_defs[name] = val self.scf_stack.append(node) self.visit_compound_statement(node.body) self.scf_stack.pop() @@ -871,12 +886,16 @@ def visit_While(self, node): yields = [] for name in loop_defs: if name in liveins: - yields.append(loop_defs[name]) - self.builder.create_yield_op([y.handle for y in yields]) + yields.append(loop_defs[name]._flatten_ir()) + + _, yields_flat = list_list_flatten(yields) + self.builder.create_yield_op(yields_flat) # WhileOp defines new values, update the symbol table (lscope, local_defs) - for i, name in enumerate(names): - new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + results_flat = [while_op.get_result(i) for i in range(len(init_handles_flat))] + results = list_list_unflatten(init_handles_spec, results_flat) + for name, init_val, result in zip(names, init_args, results): + new_def = init_val._unflatten_ir(result) self.lscope[name] = new_def self.local_defs[name] = new_def @@ -987,34 +1006,45 @@ def visit_For(self, node): # create ForOp self._set_insertion_point_and_loc(ip, last_loc) - for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + init_handles = [a._flatten_ir() for a in init_args] + init_handles_spec, init_handles_flat = list_list_flatten(init_handles) + for_op = self.builder.create_for_op(lb, ub, step, init_handles_flat) if num_stages is not None: for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) self.scf_stack.append(node) - self.builder.set_insertion_point_to_start(for_op.get_body(0)) + for_op_body = for_op.get_body(0) + self.builder.set_insertion_point_to_start(for_op_body) # reset local scope to not pick up local defs from the previous dry run. self.lscope = liveins.copy() self.local_defs = {} - for i, name in enumerate(names): - self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + block_args_flat = [for_op_body.arg(i + 1) for i in range(len(init_handles_flat))] + block_args = list_list_unflatten(init_handles_spec, block_args_flat) + for name, init_val, arg_handles in zip(names, init_args, block_args): + val = init_val._unflatten_ir(arg_handles) + self.set_value(name, val) self.visit_compound_statement(node.body) self.scf_stack.pop() yields = [] for name in self.local_defs: if name in liveins: - yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) + local = self.local_defs[name] + if isinstance(local, constexpr): + local = language.semantic.to_tensor(local, self.builder) + yields.append(local) # create YieldOp if len(yields) > 0: - self.builder.create_yield_op([y.handle for y in yields]) - for_op_region = for_op.get_body(0).get_parent() + yield_handles = [y._flatten_ir() for y in yields] + _, yield_handles_flat = list_list_flatten(yield_handles) + self.builder.create_yield_op(yield_handles_flat) + for_op_region = for_op_body.get_parent() assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" # update induction variable with actual value, and replace all uses - self.builder.set_insertion_point_to_start(for_op.get_body(0)) + self.builder.set_insertion_point_to_start(for_op_body) iv = for_op.get_induction_var() if negative_step: iv = self.builder.create_sub(ub, iv) @@ -1023,8 +1053,11 @@ def visit_For(self, node): self.set_value(node.target.id, language.core.tensor(iv, iv_type)) # update lscope & local_defs (ForOp defines new values) - for i, name in enumerate(names): - self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + result_handles_flat = [for_op.get_result(i) for i in range(len(init_handles_flat))] + result_handles = list_list_unflatten(init_handles_spec, result_handles_flat) + for name, init_val, arg_handles in zip(names, init_args, result_handles): + val = init_val._unflatten_ir(arg_handles) + self.set_value(name, val) for stmt in node.orelse: assert False, "Don't know what to do with else after for" @@ -1142,17 +1175,6 @@ def visit_BoolOp(self, node: ast.BoolOp): _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} - if sys.version_info < (3, 8): - - def visit_NameConstant(self, node): - return constexpr(node.value) - - def visit_Num(self, node): - return constexpr(node.n) - - def visit_Str(self, node): - return constexpr(ast.literal_eval(node)) - def visit_Attribute(self, node): lhs = self.visit(node.value) if _is_triton_tensor(lhs) and node.attr == "T": diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 8ca1f8b326d0..f70c46a9d406 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -15,6 +15,7 @@ import re import functools import os +import sysconfig # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace @@ -24,19 +25,13 @@ # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 -mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { - "ttir": mlir_prototype_pattern, - "ttgir": mlir_prototype_pattern, "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { - "ttir": mlir_arg_type_pattern, - "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } @@ -54,16 +49,6 @@ def convert_type_repr(x): return x -def _get_num_warps_from_ir_str(src: str): - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if - # e.g. someone has an instruction (not module) attribute named "num-warps". - num_warps_matches = re.findall(ttgir_num_warps_pattern, src) - assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - num_warps = int(num_warps_matches[0]) - return num_warps - - class ASTSource: def __init__(self, fn, signature, constants=None, attrs=None) -> None: @@ -106,28 +91,42 @@ def parse_options(self): class IRSource: - def __init__(self, path): + def __init__(self, path, context, backend): self.path = path path = Path(path) self.ext = path.suffix[1:] self.src = path.read_text() - match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) - self.name = match.group(1) - signature = match.group(2) - types = re.findall(arg_type_pattern[self.ext], signature) - self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + ir.load_dialects(context) + backend.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() def make_ir(self, options, codegen_fns, module_map, context): - module = ir.parse_mlir_module(self.path, context) - module.context = context - return module + self.module.context = context + return self.module def parse_options(self): if self.ext == "ttgir": - return {'num_warps': _get_num_warps_from_ir_str(self.src)} + num_warps = self.module.get_int_attr("ttg.num-warps") + assert num_warps is not None, "Unable to parse ttg.num-warps attribute" + return {'num_warps': num_warps} return dict() @@ -151,7 +150,8 @@ def triton_key(): # backend libtriton_hash = hashlib.sha256() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f: while True: chunk = f.read(1024**2) if not chunk: @@ -223,7 +223,9 @@ def compile(src, target=None, options=None): # create backend if ir_source: assert isinstance(src, str), "source must be either AST or a filepath" - src = IRSource(src) + context = ir.context() + src = IRSource(src, context, backend) + extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager @@ -264,9 +266,14 @@ def compile(src, target=None, options=None): # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. if ir_source: first_stage += 1 - context = ir.context() - ir.load_dialects(context) - backend.load_dialects(context) + + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() try: diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6502a5348f3e..0c8965fc520a 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -28,6 +28,9 @@ TRITON_MAX_TENSOR_NUMEL, _experimental_descriptor_load, _experimental_descriptor_store, + _experimental_make_tensor_descriptor, + _experimental_reinterpret_tensor_descriptor, + _experimental_tensor_descriptor, add, advance, arange, @@ -67,6 +70,7 @@ float8e5b16, full, function_type, + gather, histogram, inline_asm_elementwise, int1, @@ -126,6 +130,9 @@ "TRITON_MAX_TENSOR_NUMEL", "_experimental_descriptor_load", "_experimental_descriptor_store", + "_experimental_make_tensor_descriptor", + "_experimental_reinterpret_tensor_descriptor", + "_experimental_tensor_descriptor", "abs", "add", "advance", @@ -146,7 +153,6 @@ "block_type", "broadcast", "broadcast_to", - "builtin", "cat", "cast", "cdiv", @@ -183,6 +189,7 @@ "fma", "full", "function_type", + "gather", "histogram", "inline_asm_elementwise", "interleave", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e2c57b388bb0..c822ef812cd6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -29,7 +29,6 @@ def builtin(fn: T) -> T: @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: - print(kwargs) raise ValueError("Did you forget to add @triton.jit ? " "(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) @@ -610,6 +609,8 @@ def __init__(self, element_ty: dtype, shape: List): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. + assert (isinstance(shape, list)) + # shape can be empty ([]) when an input is a 0D tensor. self.shape = _unwrap_shape(shape) if not self.shape: @@ -715,6 +716,12 @@ class _value: def __init__(self, handle): self.handle = handle + def _flatten_ir(self): + raise NotImplementedError + + def _unflatten_ir(self, handles): + raise NotImplementedError + # ----------------------- # tensor @@ -756,6 +763,13 @@ def __init__(self, handle, type: dtype): self.dtype = type.scalar self.shape = [constexpr(s) for s in self.shape] + def _flatten_ir(self): + return [self.handle] + + def _unflatten_ir(self, handles): + assert len(handles) == 1 + return tensor(handles[0], self.type) + def __str__(self) -> str: # ex. "float32[16, 32]" return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' @@ -1084,6 +1098,9 @@ def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: ... + def gather(self, indices, axis) -> tensor: + ... + def histogram(self, num_bins) -> tensor: ... @@ -1130,6 +1147,83 @@ def flip(self, dim=None) -> tensor: ... +class _experimental_tensor_descriptor_base(_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle) + + self.type = type # Tensor type (block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + + def _flatten_ir(self): + return [self.handle] + + def _unflatten_ir(self, handles): + assert len(handles) == 1 + return _experimental_tensor_descriptor_base(handles[0], self.type) + + @property + def block_shape(self): + return self.type.shape + + def __str__(self) -> str: + # ex. "tensor_descriptor" + return f"tensor_descriptor<{self.type}>" + + @builtin + def load(self, offsets: List[tensor], _builder=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_load(self, offsets, "", "", _builder) + + @builtin + def store(self, offsets: List[tensor], value: tensor, _builder=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_store(self, value, offsets, _builder) + + +class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base): + """A descriptor representing a tensor in global memory. + """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, type) + # Global shape + self.shape = shape + self.strides = strides + + def _flatten_ir(self): + handles = [self.handle] + handles.extend(s.handle for s in self.shape) + handles.extend(s.handle for s in self.strides) + return handles + + def _unflatten_ir(self, handles): + ndim = len(self.shape) + assert len(handles) == 2 * ndim + 1 + handle = handles[0] + shape = [tensor(handle, s.type) for handle, s in zip(handles[1:1 + ndim], self.shape)] + strides = [tensor(handle, s.type) for handle, s in zip(handles[1 + ndim:], self.strides)] + return _experimental_tensor_descriptor(handle, shape, strides, self.type) + + def get_bool_env_var(var_name): v = os.getenv(var_name, "0") return v == "1" or v == "true" or v == "on" @@ -1290,6 +1384,7 @@ def trans(input: tensor, *dims, _builder=None): :py:func:`permute` is equivalent to this function, except it doesn't have the special case when no permutation is specified. """ + dims = _unwrap_iterable(dims) if not dims: dims = (1, 0) return semantic.permute(input, dims, _builder) @@ -1522,9 +1617,9 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i where the first dimension of each block represents the batch dimension. :param input: The first tensor to be multiplied. - :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} :param other: The second tensor to be multiplied. - :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} :param acc: The accumulator tensor. If not None, the result is added to this tensor. :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} :param input_precision: How to exercise the Tensor Cores for f32 x f32. If @@ -1555,15 +1650,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs and rhs use microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf :param lhs: The first tensor to be multiplied. - :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. :param lhs_scale: Scale factor for lhs tensor. - :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}. + :type lhs_format: str :param rhs: The second tensor to be multiplied. - :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. :param rhs_scale: Scale factor for rhs tensor. - :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}. + :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. """ out_dtype = _constexpr_to_value(out_dtype) @@ -1636,6 +1733,16 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c volatile, _builder) +@builtin +def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype, + _builder=None) -> _experimental_tensor_descriptor_base: + """ + Reinterpret a generic pointer as a TMA-backed tensor descriptor object. + """ + block_ty = block_type(_constexpr_to_value(dtype), block_shape) + return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder) + + @builtin def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): """ @@ -1644,8 +1751,8 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder= This loads a tensor of data based on the descriptor and offsets. """ - type = block_type(_constexpr_to_value(dtype), shape) - return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder) + return desc.load(offsets, _builder=_builder) @builtin @@ -1656,7 +1763,8 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): This stores a tensor of data based on the descriptor and offsets. """ - return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder) + return desc.store(offsets, value, _builder=_builder) @_tensor_member_fn @@ -1737,6 +1845,64 @@ def advance(base, offsets, _builder=None): return semantic.advance(base, offsets, _builder) +@builtin +def _experimental_make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> _experimental_tensor_descriptor: + """Make an experimental tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers represeting the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2d tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M / M_BLOCK, N / N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) + + # ----------------------- # Atomic Memory Operations # ----------------------- @@ -2184,6 +2350,23 @@ def histogram(input, num_bins, _builder=None, _generator=None): return semantic.histogram(input, num_bins, _builder) +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + + # ----------------------- # Compiler Hint Ops # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index be157c5b4609..60890ac596eb 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -6,7 +6,6 @@ from .._C.libtriton import ir from . import core as tl -from . import math T = TypeVar('T') @@ -88,11 +87,12 @@ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_i else: return tl.float16 # 4) return bf16 only if both operands are of bf16 - if a_ty.is_bf16() or b_ty.is_bf16(): + if a_ty.is_bf16() and b_ty.is_bf16(): if div_or_mod: return tl.float32 - if a_ty.is_bf16() and b_ty.is_bf16(): + else: return tl.bfloat16 + if a_ty.is_bf16() or b_ty.is_bf16(): return tl.float32 # 5) return fp16 if operands are different fp8 if a_ty.is_fp8() and b_ty.is_fp8(): @@ -333,10 +333,7 @@ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, bu other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): - # input - input.div(other, rounding_mode="floor") * other - floor = math.floor(fdiv(input, other, False, builder), _builder=builder) - ret = sub(input, mul(floor, other, True, builder), True, builder) - return ret + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) # % int elif scalar_ty.is_int(): if scalar_ty.int_signedness != other_scalar_ty.int_signedness: @@ -1141,18 +1138,24 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) -def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, +def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder): + handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder)) + return tl._experimental_tensor_descriptor_base(handle, block_ty) + + +def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), - _str_to_load_cache_modifier(cache_modifier), + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), _str_to_eviction_policy(eviction_policy)) - return tl.tensor(x, type) + return tl.tensor(x, desc.type) -def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: +def descriptor_store(desc: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) def tensormap_create( @@ -1256,7 +1259,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): val = cast(val, elt_ty, builder) # Build IR - if not mask: + if mask is None: return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") @@ -1311,7 +1314,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, if val is not None: val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) val = cast(val, ptr.type.scalar.element_ty, builder) - if not mask: + if mask is None: mask_ir = builder.get_int1(True) mask_ty = tl.int1 if ptr.type.is_block(): @@ -1527,40 +1530,58 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona ret_ty) -def _str_to_fp_type(float_format: Optional[str]): - if float_format == 'e4m3': - return ir.F8F6F4TY.E4M3 - if float_format == 'e5m2': - return ir.F8F6F4TY.E5M2 - if float_format == 'e2m3': - return ir.F8F6F4TY.E2M3 - if float_format == 'e3m2': - return ir.F8F6F4TY.E3M2 - if float_format == 'e2m1': - return ir.F8F6F4TY.E2M1 - raise ValueError(f"Invalid float format: {float_format}.") +def _str_to_fp_type(float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) -def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], - rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" - assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16"} + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None - assert rhs_scale_is_none, "NYI: rhs_scale not supported" + lhs_scale_is_none = isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None + assert rhs_scale_is_none != lhs_scale_is_none, "There should be exactly one operand with scale" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) M = lhs.type.shape[-2] K, N = rhs.type.shape[-2:] - PACKED = 2 if lhs_format == "e2m1" else 1 - assert K == PACKED * lhs.type.shape[ + PACKED_A = 2 if lhs_format == "e2m1" else 1 + PACKED_B = 2 if rhs_format == "e2m1" else 1 + assert K * PACKED_B == PACKED_A * lhs.type.shape[ -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" - assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}" B = lhs.type.shape[0] if lhs_rank == 3 else None ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) @@ -1571,8 +1592,9 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, acc_handle = acc.handle assert acc.type == ret_ty rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle return tl.tensor( - builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, rhs_format_enum, acc_handle), ret_ty) @@ -1655,6 +1677,30 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + + # ===----------------------------------------------------------------------=== # Histogram # ===----------------------------------------------------------------------=== @@ -1663,7 +1709,7 @@ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: assert len(input.shape) == 1, "histogram only supports 1D input" assert input.dtype.is_int(), "histogram only supports integer input" - return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins])) ## @@ -1712,10 +1758,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: if not builder.options.debug: return - cond_ty = cond.type - if not cond_ty.is_block(): - cond_ty = tl.block_type(cond_ty.scalar, (1, )) - cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg), tl.void) @@ -1798,3 +1840,34 @@ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: # Advanced block pointer type is the same as before return tl.tensor(builder.create_advance(base.handle, offsets), base.type) + + +def make_tensor_descriptor( + base: tl.tensor, + shape: List[tl.tensor], + strides: List[tl.tensor], + block_shape: List[tl.constexpr], + builder: ir.builder, +) -> tl._experimental_tensor_descriptor: + ndim = len(shape) + if ndim != 2: + raise ValueError("Only two dimensional tensor descriptors are supported at the moment") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + + if not (isinstance(strides[-1], tl.constexpr) and strides[-1].value == 1): + raise ValueError(f"Tensor descriptor last dim must tl.constexpr(1) but got {strides[-1]}") + + shape = [to_tensor(x, builder) for x in shape] + strides = [to_tensor(x, builder).to(tl.int64, _builder=builder) for x in strides] + + # Check whether `block_shape` is static + block_shape = tl._unwrap_shape(block_shape) + + assert isinstance(base.type, tl.pointer_type) + type = tl.block_type(base.type.element_ty, block_shape) + handle = builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape], [s.handle for s in strides], + block_shape) + return tl._experimental_tensor_descriptor(handle, shape, strides, type) diff --git a/python/triton/runtime/_allocation.py b/python/triton/runtime/_allocation.py new file mode 100644 index 000000000000..aa8a45488c87 --- /dev/null +++ b/python/triton/runtime/_allocation.py @@ -0,0 +1,32 @@ +from typing import Optional, Protocol + + +class Buffer(Protocol): + + def data_ptr(self) -> int: + ... + + +class Allocator(Protocol): + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + ... + + +class NullAllocator: + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " + + "Use triton.set_allocator to specify an allocator.") + + +_allocator: Allocator = NullAllocator() + + +def set_allocator(allocator: Allocator): + """ + The allocator function is called during kernel launch for kernels that + require additional global memory workspace. + """ + global _allocator + _allocator = allocator diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 5f846de17017..c9833c9482cc 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -7,7 +7,7 @@ from typing import Dict from .jit import KernelInterface -from .errors import OutOfResources +from .errors import OutOfResources, PTXASError from .driver import driver @@ -44,36 +44,40 @@ def __init__( self.arg_names = arg_names # Reset to zero or restore values - self.reset_idx = [] + self.reset_to_zero = [] if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - self.restore_idx = [] + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] if restore_value is not None: - self.restore_idx = [arg_names.index(k) for k in restore_value] + self.restore_value = list(restore_value) # Hook to reset or restore for required tensors - self.pre_hook = lambda args, reset_only=False: 0 - self.post_hook = lambda args, exception: 0 + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False if pre_hook: self.pre_hook = pre_hook - elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): - def _pre_hook(args, reset_only=False): - for i in self.reset_idx: - args[i].zero_() + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() if not reset_only: - self.restore_copies = [args[i].clone() for i in self.restore_idx] + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} self.pre_hook = _pre_hook if post_hook: self.post_hook = post_hook - elif len(self.restore_idx) > 0: + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: - def _post_hook(args, exception): - for i, j in enumerate(self.restore_idx): - args[j].copy_(self.restore_copies[i]) - self.restore_copies = [] + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} self.post_hook = _post_hook @@ -90,6 +94,10 @@ def _post_hook(args, exception): while not inspect.isfunction(self.base_fn): self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + # If we got explicitly called via the old interface, raise a warning # and proceed with the old behavior. if warmup is not None or rep is not None or use_cuda_graph: @@ -123,6 +131,10 @@ def _post_hook(args, exception): def _bench(self, *args, config, **meta): from ..compiler.errors import CompileTimeAssertionFailure + verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1" + if verbose: + print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") + # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner conflicts = meta.keys() & config.kwargs.keys() @@ -136,7 +148,7 @@ def _bench(self, *args, config, **meta): def kernel_call(): if config.pre_hook: config.pre_hook(full_nargs) - self.pre_hook(args) + self.pre_hook(full_nargs) try: self.fn.run( *args, @@ -144,16 +156,18 @@ def kernel_call(): ) except Exception as e: try: - self.post_hook(args, exception=e) + self.post_hook(full_nargs, exception=e) finally: # Throw exception raised by `self.fn.run` raise - self.post_hook(args, exception=None) + self.post_hook(full_nargs, exception=None) try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure): + except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: + if verbose: + print(f"Autotuning failed with {e}") return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): @@ -176,7 +190,8 @@ def run(self, *args, **kwargs): bench_end = time.time() self.bench_time = bench_end - bench_start self.cache[key] = builtins.min(timings, key=timings.get) - self.pre_hook(args, reset_only=True) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) self.configs_timings = timings config = self.cache[key] else: @@ -186,7 +201,8 @@ def run(self, *args, **kwargs): print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") if config.pre_hook is not None: - config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) ret = self.fn.run( *args, **kwargs, @@ -249,11 +265,12 @@ class Config: function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_threads=0, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.num_threads = num_threads self.maxnreg = maxnreg self.pre_hook = pre_hook @@ -265,6 +282,7 @@ def all_kwargs(self): ("num_warps", self.num_warps), ("num_ctas", self.num_ctas), ("num_stages", self.num_stages), + ("num_threads", self.num_threads), ("maxnreg", self.maxnreg), ) if v is not None } @@ -277,6 +295,7 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"num_threads: {self.num_threads}") res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) @@ -322,12 +341,12 @@ def kernel(x_ptr, x_size, **META): :type restore_value: list[str] :param pre_hook: a function that will be called before the kernel is called. This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. - 'args': a list of arguments passed to the kernel. + 'kwargs': a dict of all arguments passed to the kernel. 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. :type pre_hook: lambda args, reset_only :param post_hook: a function that will be called after the kernel is called. This overrides the default post_hook used for 'restore_value'. - 'args': a list of arguments passed to the kernel. + 'kwargs': a dict of all arguments passed to the kernel. 'exception': the exception raised by the kernel in case of a compilation or runtime error. :type post_hook: lambda args, exception :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 20da2bc25790..58a52fb8f82f 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -1,5 +1,6 @@ import contextlib import sys +import platform import io import sysconfig import os @@ -18,8 +19,19 @@ def quiet(): sys.stdout, sys.stderr = old_stdout, old_stderr +def _is_apple_clang(): + if platform.system() != "Darwin": + return False + res = subprocess.run(["clang", "--version"], capture_output=True, text=True) + if res.returncode != 0: + return False + return "Apple clang" in res.stdout + + def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') + system = platform.system() + machine = platform.machine() so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible cc = os.environ.get("CC") @@ -44,9 +56,43 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + + libraries += ["gcc"] + # Use dynamic lookup to load Python library on Mac + if system == "Darwin": + cc_cmd += ["-undefined", "dynamic_lookup"] + # Don't use libgcc on clang + macos + if "clang" in cc: + libraries.remove("gcc") cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + for dir in library_dirs: + cc_cmd.extend(["-Wl,-rpath", dir]) + # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. + if src.endswith(".cpp") or src.endswith(".cc"): + cc_cmd += ["-std=c++17"] + if not os.environ.get("TRITON_DISABLE_OPENMP", None): + libomp_path = os.environ.get("TRITON_LOCAL_LIBOMP_PATH", None) + if _is_apple_clang(): + if libomp_path: + cc_cmd += ["-Xclang"] + cc_cmd += ["-fopenmp"] + cc_cmd += [f"-I{libomp_path}/include"] + cc_cmd += [f"-L{libomp_path}/lib"] + cc_cmd += ["-lomp"] + else: + print("Warning: TRITON_LOCAL_LIBOMP_PATH is not set for Apple clang. OpenMP is disabled.") + else: + cc_cmd += ["-fopenmp"] + if libomp_path: + print("Info: Ignoring TRITON_LOCAL_LIBOMP_PATH for non-Apple clang compiler") + if src.endswith(".s"): + # This is required to properly parse .file directives + cc_cmd += ["-g"] + if system == "Linux" and machine in ("aarch64", "arm64"): + # On Arm backend, some CPU (neoverse-v2) needs to be specified through -mcpu + cc_cmd += ["-mcpu=native"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 82b2fea37e9b..62895508b019 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -256,9 +256,9 @@ def put_group(self, filename: str, group: Dict[str, str]): __cache_cls_nme = "DEFAULT" -def _base64(key): +def _base32(key): # Assume key is a hex string. - return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") def get_cache_manager(key) -> CacheManager: @@ -274,15 +274,15 @@ def get_cache_manager(key) -> CacheManager: __cache_cls = getattr(module, clz_nme) __cache_cls_nme = user_cache_manager - return __cache_cls(_base64(key)) + return __cache_cls(_base32(key)) def get_override_manager(key) -> CacheManager: - return __cache_cls(_base64(key), override=True) + return __cache_cls(_base32(key), override=True) def get_dump_manager(key) -> CacheManager: - return __cache_cls(_base64(key), dump=True) + return __cache_cls(_base32(key), dump=True) def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): @@ -292,4 +292,4 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): for kw in kwargs: key = f"{key}-{kwargs.get(kw)}" key = hashlib.sha256(key.encode("utf-8")).hexdigest() - return _base64(key) + return _base32(key) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c3b97a764145..ed3c16978bd2 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -1,9 +1,19 @@ +import os + from ..backends import backends from ..backends import DriverBase def _create_driver(): + if os.getenv("TRITON_CPU_BACKEND", "0") == "1": + if "cpu" not in backends: + raise RuntimeError("TRITON_CPU_BACKEND is set, but CPU backend is unavailable.") + return backends["cpu"].driver() + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) >= 2 and backends["cpu"].driver.is_active(): + print("Both CPU and GPU backends are available. Using the GPU backend.") + actives.remove(backends["cpu"].driver) if len(actives) != 1: raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") return actives[0]() @@ -56,5 +66,22 @@ def set_active(self, driver: DriverBase): def reset_active(self): self.active = self.default + def set_active_to_cpu(self): + if "cpu" not in backends: + raise RuntimeError("CPU backend is unavailable") + self.active = backends["cpu"].driver() + + def set_active_to_gpu(self): + active_gpus = [(name, backend.driver) + for name, backend in backends.items() + if backend.driver.is_active() and name != "cpu"] + if len(active_gpus) != 1: + raise RuntimeError(f"{len(active_gpus)} active GPU drivers ({active_gpus}). There should only be one GPU.") + self.active = active_gpus[0][1]() + return active_gpus[0][0] + + def get_active_gpus(self): + return [name for name, backend in backends.items() if backend.driver.is_active() and name != "cpu"] + driver = DriverConfig() diff --git a/python/triton/runtime/errors.py b/python/triton/runtime/errors.py index 4dce9176709a..1a8046430eca 100644 --- a/python/triton/runtime/errors.py +++ b/python/triton/runtime/errors.py @@ -24,3 +24,13 @@ def __str__(self) -> str: def __reduce__(self): # this is necessary to make CompilationError picklable return (type(self), (self.required, self.limit, self.name)) + + +class PTXASError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"PTXAS error: {error_message}" diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 0aeaff73a4ea..3b94f55ea3c0 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -21,7 +21,7 @@ def __init__(self, data, dtype): ''' data: numpy array dtype: triton type, either pointer_type or scalar_type. - we don't store block_type here because the shape information is already availale in the data field + we don't store block_type here because the shape information is already available in the data field attr: a dictionary of attributes ''' self.data = data @@ -46,27 +46,26 @@ def set_attr(self, key, value): class BlockPointerHandle: - def __init__(self, base, shape, strides, offsets, tensor_shape, order): + def __init__(self, base, shape, strides, offsets, block_shape, order): self.base = base self.shape = shape self.strides = strides self.offsets = offsets - self.tensor_shape = tensor_shape + self.block_shape = block_shape self.order = order def materialize_pointers(self, boundary_check): dtype_tt = self.base.get_element_ty() n_bytes = dtype_tt.primitive_bitwidth // 8 - tensor_shape = self.tensor_shape - ptrs = np.broadcast_to(self.base.data, self.tensor_shape) - masks = np.ones(self.tensor_shape, dtype=bool) - for dim in range(len(tensor_shape)): - bcast_dims = [1] * len(tensor_shape) - bcast_dims[dim] = tensor_shape[dim] - off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) if dim in boundary_check: - masks = np.logical_and(masks, off < self.shape[dim].data) + masks = masks & (off < self.shape[dim].data) & (off >= 0) ptrs = TensorHandle(ptrs, self.base.dtype.scalar) return ptrs, masks @@ -419,7 +418,7 @@ def binary_op(self, lhs, rhs, op): create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) - create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) @@ -655,17 +654,17 @@ def create_barrier(self): # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter pass - def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order): # Create new offsets to avoid modifying the original new_offsets = [offset.clone() for offset in offsets] - return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order) def create_advance(self, ptr, offsets): if len(ptr.offsets) != len(offsets): raise ValueError("len(ptr.offsets) != len(offsets)") # Create new offsets to avoid modifying the original new_offsets = [offset.clone() for offset in ptr.offsets] - ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order) for i in range(len(offsets)): ret.offsets[i].data += offsets[i].data return ret @@ -728,7 +727,7 @@ def check_tensor(self, input): def to_tensor(self, ret, dtype): if hasattr(ret, "shape") and ret.shape: - ret_type = tl.block_type(dtype, ret.shape) + ret_type = tl.block_type(dtype, list(ret.shape)) else: ret = np.array([ret]).astype(_get_np_dtype(dtype)) ret_type = dtype @@ -1034,9 +1033,6 @@ def _implicit_cvt(arg): interpreter_builder = InterpreterBuilder() -# These keywords are not supported by the interpreter -RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] - class GridExecutor: @@ -1077,10 +1073,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) def __call__(self, *args_dev, **kwargs): - # removes reserved keywords from kwargs - kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} if kwargs.pop("warmup", False): return + # Removes not used reserved keywords from kwargs + # Triton doesn't support keyword-only, variable positional or variable keyword arguments + # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) + argspec = inspect.getfullargspec(self.fn) + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # copy arguments to the host args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) # remaps core language functions to interpreted ones diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40bb29..1db10dee082f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -439,6 +439,12 @@ def create_function_from_signature(sig, kparams, backend): type_canonicalisation_dict[v] = v +def get_device_key(): + target = driver.active.get_current_target() + device = driver.active.get_current_device() + return f"{target.backend}:{device}" + + class JITFunction(KernelInterface[T]): # Hook for inspecting compiled functions and modules cache_hook = None @@ -561,7 +567,7 @@ def create_binder(self, backend): ] def run(self, *args, grid, warmup, **kwargs): - kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" # parse options from ..compiler import make_backend @@ -580,8 +586,9 @@ def run(self, *args, grid, warmup, **kwargs): bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) # compute cache key + device_key = get_device_key() key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) - kernel = self.cache[device].get(key, None) + kernel = self.cache[device_key].get(key, None) if kernel is None: # Kernel is not cached; we have to compile. @@ -625,7 +632,7 @@ def run(self, *args, grid, warmup, **kwargs): target=target, options=options.__dict__, ) - self.cache[device][key] = kernel + self.cache[device_key][key] = kernel self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. @@ -698,6 +705,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel = None + self.debug = debug self.noinline = noinline # TODO(jlebar): Remove uses of these fields outside this file, then @@ -732,7 +740,7 @@ def preload(self, specialization_data): from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl - device = driver.active.get_current_device() + device_key = get_device_key() deserialized_obj = json.loads(specialization_data) if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( @@ -749,7 +757,7 @@ def preload(self, specialization_data): } key = deserialized_obj['key'] kernel = compile(src, None, options) - self.cache[device][key] = kernel + self.cache[device_key][key] = kernel return kernel # we do not parse `src` in the constructor because @@ -917,6 +925,9 @@ def clone(self): def to(self, device): return TensorWrapper(self.base.to(device), self.dtype) + def new_empty(self, sizes): + return TensorWrapper(self.base.new_empty(sizes), self.dtype) + def reinterpret(tensor, dtype): if isinstance(tensor, TensorWrapper): diff --git a/python/triton/testing.py b/python/triton/testing.py index 71cb8ab1eaaa..00adf3ebb9aa 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -92,7 +92,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", + measure_time_with_hooks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -119,7 +120,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m cache = runtime.driver.active.get_empty_cache_for_benchmark() - # Estimate the runtime of the function start_event = di.Event(enable_timing=True) end_event = di.Event(enable_timing=True) start_event.record() @@ -130,6 +130,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 + # For CPU we can use entry and exit hooks to measure execution time + # more precisely. + if measure_time_with_hooks: + di.enable_hook_timing() + # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) n_repeat = max(1, int(rep / estimate_ms)) @@ -154,6 +159,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m end_event[i].record() # Record clocks di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) diff --git a/python/triton/tools/compile.c b/python/triton/tools/compile.c index 971bf61912a7..24b369354503 100644 --- a/python/triton/tools/compile.c +++ b/python/triton/tools/compile.c @@ -60,6 +60,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{ unsigned int gX = {gridX}; unsigned int gY = {gridY}; unsigned int gZ = {gridZ}; + CUdeviceptr global_scratch = 0; void *args[{num_args}] = {{ {arg_pointers} }}; // TODO: shared memory if(gX * gY * gZ > 0) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 443341fa0d47..6adf7794cc44 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -113,6 +113,9 @@ def constexpr(s): src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) + if ccinfo.metadata.global_scratch_size > 0: + raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented") + arg_names = [] arg_types = [] arg_names_not_1 = [] @@ -138,8 +141,8 @@ def constexpr(s): "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), - "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]), - "num_args": len(arg_names_not_1), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]), + "num_args": len(arg_names_not_1) + 1, "kernel_docstring": doc_string, "shared": ccinfo.metadata.shared, "num_warps": args.num_warps, diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index e0220a45ce04..93cf90ae7a1b 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,6 +23,12 @@ import triton import triton.language as tl +GPU_BLOCK_SIZE = 1024 +CPU_BLOCK_SIZE = 4096 +# Single Thread Threshold +CPU_ST_THRESHOLD = 65536 +USE_GPU = False + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -52,15 +58,68 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) +@triton.jit +def add_kernel_tiled(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + TILE_SIZE: tl.constexpr, # Number of elements each iteration should process. + # NOTE `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)): + offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +@triton.autotune( + configs=[ + # For small vectors it might be faster to use a single thread instead + # of paying OMP threading overhead, so add a single-threaded option. + # Other options use all available threads. + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 4096}, num_threads=1), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 4096}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 8192}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 16384}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 32768}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 65536}, num_threads=0), + ], + key=['n_elements'], +) +@triton.jit +def add_kernel_tiled_autotuned(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + TILE_SIZE: tl.constexpr, # Number of elements each iteration should process. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)): + offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + # %% # Let's also declare a helper function to (1) allocate the `z` tensor # and (2) enqueue the above kernel with appropriate grid/block sizes: -def add(x: torch.Tensor, y: torch.Tensor): - # We need to preallocate the output. - output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda +def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): + if output is None: + # We need to preallocate the output. + output = torch.empty_like(x) n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -70,25 +129,85 @@ def add(x: torch.Tensor, y: torch.Tensor): # - Each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. # - Don't forget to pass meta-parameters as keywords arguments. - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if device == 'cpu' else GPU_BLOCK_SIZE) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output +def add_tiled(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE, TILE_SIZE=16) + return output + + +def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # TODO: try to choose the best block size using autotuner + BLOCK_SIZE = triton.next_power_of_2(n_elements) + if BLOCK_SIZE > CPU_ST_THRESHOLD: + BLOCK_SIZE = CPU_BLOCK_SIZE + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE, TILE_SIZE=16) + return output + + +def add_tiled_autotuned(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel_tiled_autotuned[grid](x, y, output, n_elements) + return output + + # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: - torch.manual_seed(0) size = 98432 -x = torch.rand(size, device='cuda') -y = torch.rand(size, device='cuda') -output_torch = x + y -output_triton = add(x, y) -print(output_torch) -print(output_triton) -print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + +triton.runtime.driver.set_active_to_cpu() +x = torch.rand(size, device='cpu') +y = torch.rand(size, device='cpu') +output_torch_cpu = torch.add(x, y) +output_triton_cpu = add(x, y, None, device='cpu') +print(output_torch_cpu) +print(output_triton_cpu) +print(f'The maximum difference between torch-cpu and triton-cpu is ' + f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') +output_triton_cpu = add_tiled(x, y, None) +print(f'The maximum difference between torch-cpu-tiled and triton-cpu is ' + f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') + +LINE_VALS = [ + 'triton-cpu', 'triton-cpu-hooks', 'triton-cpu-tiled', 'triton-cpu-tiled-hooks', 'triton-cpu-tiled-tuned-hooks', + 'triton-cpu-tiled-autotuned-hooks', 'torch-cpu' +] +LINE_NAMES = [ + 'TritonCPU', 'TritonCPU (hooks)', 'TritonCPUTiled', 'TritonCPUTiled (hooks)', 'TritonCPUTiled (tuned, hooks)', + 'TritonCPUTiled (autotuned, hooks)', 'TorchCPU' +] +LINE_STYLES = [('blue', '--'), ('blue', '-.'), ('red', '-'), ('red', '--'), ('red', '-.'), ('red', ':'), ('green', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + x = x.to('cuda') + y = y.to('cuda') + output_torch_gpu = x + y + output_triton_gpu = add(x, y, None, device='cuda') + print(output_torch_gpu) + print(output_triton_gpu) + print(f'The maximum difference between torch-gpu and triton-gpu is ' + f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('yellow', '-'), ('red', '-')] # %% # Seems like we're good to go! @@ -108,21 +227,52 @@ def add(x: torch.Tensor, y: torch.Tensor): x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. x_log=True, # x axis is logarithmic. line_arg='provider', # Argument name whose value corresponds to a different line in the plot. - line_vals=['triton', 'torch'], # Possible values for `line_arg`. - line_names=['Triton', 'Torch'], # Label name for the lines. - styles=[('blue', '-'), ('green', '-')], # Line styles. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. ylabel='GB/s', # Label name for the y-axis. - plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})', args={}, # Values for function arguments not in `x_names` and `y_name`. )) def benchmark(size, provider): - x = torch.rand(size, device='cuda', dtype=torch.float32) - y = torch.rand(size, device='cuda', dtype=torch.float32) + + device = 'cpu' if 'cpu' in provider else 'cuda' + x = torch.rand(size, device=device, dtype=torch.float32) + y = torch.rand(size, device=device, dtype=torch.float32) + + if device == 'cpu': + triton.runtime.driver.set_active_to_cpu() + else: + triton.runtime.driver.set_active_to_gpu() + output = torch.empty_like(x) + quantiles = [0.5, 0.2, 0.8] - if provider == 'torch': + if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) - if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) + elif provider == 'torch-cpu': + # Note that we preallocate the output buffer here to only measure the kernel performance + # without a large chunk of memory allocation. + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles) + elif provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles) + elif provider == 'triton-cpu-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles, + measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles) + elif provider == 'triton-cpu-tiled-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, + measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled-tuned-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output), + quantiles=quantiles, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled-autotuned-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_autotuned(x, y, output), quantiles=quantiles, + measure_time_with_hooks=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py new file mode 100644 index 000000000000..e93ed9d37b50 --- /dev/null +++ b/python/tutorials/02-fused-softmax-cpu.py @@ -0,0 +1,236 @@ +""" +Fused Softmax +============= + +In this tutorial, you will write a fused softmax operation that is significantly faster +than PyTorch's native op for a particular class of matrices: those whose rows can fit in +the GPU's SRAM. + +In doing so, you will learn about: + +* The benefits of kernel fusion for bandwidth-bound operations. + +* Reduction operators in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. +# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch + +import triton +import triton.language as tl + +USE_GPU = False + + +@torch.jit.script +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` +# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. +# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads +# X once and does all the necessary computations on-chip. +# Doing so would require reading and writing back only :math:`MN` bytes, so we could +# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). +# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically +# but, as we will see later, it is still far from ideal. + +# %% +# Compute Kernel +# -------------- +# +# Our softmax kernel works as follows: each program loads a row of the input matrix X, +# normalizes it and writes back the result to the output Y. +# +# Note that one important limitation of Triton is that each block must have a +# power-of-two number of elements, so we need to internally "pad" each row and guard the +# memory operations properly if we want to handle any possible input shapes: + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): + # The rows of the softmax are independent, so we parallelize across those + row_idx = tl.program_id(0) + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. + + +def softmax(x, y=None, num_threads=0): + n_rows, n_cols = x.shape + # The block size is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + # Allocate output + if y is None: + y = torch.empty_like(x) + # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of + # the input matrix + softmax_kernel[(n_rows, )]( + y, + x, + x.stride(0), + y.stride(0), + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + num_threads=num_threads, + ) + return y + + +# %% +# Unit Test +# --------- + +# %% +# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. +# This will allow us to verify that our padding mechanism works. + +triton.runtime.driver.set_active_to_cpu() + +torch.manual_seed(0) +x = torch.randn(1823, 781, device='cpu') +y_triton_cpu = softmax(x) +y_torch_cpu = torch.softmax(x, axis=1) +assert torch.allclose(y_triton_cpu, y_torch_cpu), (y_triton_cpu, y_torch_cpu) + +LINE_VALS = [ + 'triton-cpu-single', + 'triton-cpu', + 'torch-cpu-compile', + 'torch-cpu-jit', + 'torch-cpu-native', +] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (compile)', 'TorchCPU (jit)', 'TorchCPU (native)'] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-'), ('green', '--'), ('green', '-.')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + x = x.to('cuda') + y_triton_gpu = softmax(x) + y_torch_gpu = torch.softmax(x, axis=1) + assert torch.allclose(y_triton_gpu, y_torch_gpu), (y_triton_gpu, y_torch_gpu) + LINE_VALS += ['triton-gpu', 'torch-gpu-native', 'torch-gpu-jit'] + LINE_NAMES += ['TritonGPU', 'TorchGPU (native)', 'TorchGPU (jit)'] + LINE_STYLES += [('yellow', '-'), ('red', '-'), ('red', '--')] + +# %% +# As expected, the results are identical. + +# %% +# Benchmark +# --------- +# +# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. +# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], # argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 52, 2)], # different possible values for `x_name` + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel="GB/s", # label name for the y-axis + plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. + args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` + )) +def benchmark(M, N, provider): + + # Currently compilation time is very long. Let's show the progress. + print(f"Running {provider} with {M} x {N}...") + + device = 'cpu' if 'cpu' in provider else 'cuda' + x = torch.randn(M, N, device=device, dtype=torch.float32) + + if device == 'cpu': + y = torch.empty_like(x) + triton.runtime.driver.set_active_to_cpu() + else: + y = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-cpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) + if provider == 'torch-cpu-jit': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) + if provider == 'torch-cpu-compile': + compiled = torch.compile(naive_softmax) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles) + if provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y, num_threads=1), quantiles=quantiles) + if provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles) + if provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles) + if provider == 'torch-gpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) + if provider == 'torch-gpu-jit': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) + gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +benchmark.run(show_plots=True, print_data=True) + +# %% +# In the above plot, we can see that: +# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index d08afb1e59d2..c98042559770 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -139,35 +139,32 @@ def softmax(x): y = torch.empty_like(x) # pre-compile kernel to get register usage and compute thread occupancy. - kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) - if kernel is None: - kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, - num_stages=num_stages, num_warps=num_warps, grid=(1, )) - kernel._init_handles() - n_regs = kernel.n_regs - size_smem = kernel.metadata.shared - if is_hip(): - # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. - # However, this is not always the case. In most cases all registers can be used as regular purpose registers. - # ISA SECTION (3.6.4 for CDNA3) - # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used - # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total - # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is - # not required to be equal numbers of both types. - if is_cdna(): - NUM_GPRS = NUM_REGS * 2 - - # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. - # When we divide this number with WARP_SIZE we get maximum number of waves that can - # execute on a CU (multi-processor) in parallel. - MAX_NUM_THREADS = properties["max_threads_per_sm"] - max_num_waves = MAX_NUM_THREADS // WARP_SIZE - occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps - else: - occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) - occupancy = min(occupancy, SIZE_SMEM // size_smem) - num_programs = NUM_SM * occupancy - kernels[BLOCK_SIZE] = (kernel, num_programs) + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy num_programs = min(num_programs, n_rows) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py new file mode 100644 index 000000000000..a41b4823b72d --- /dev/null +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -0,0 +1,537 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP32 matrix multiplication kernel. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = triton.program_id(0); +# grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M; +# grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N; +# pid_m = pid / grid_n; +# pid_n = pid % grid_n; +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before +# switching to the next column: +# +# .. code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + (pid % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. +# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch +import os + +import triton +import triton.language as tl + +# It depends on CPU cache sizes. +BLOCK_SIZE_M = 64 +BLOCK_SIZE_N = 64 +BLOCK_SIZE_K = { "32": 32, + "64": 64, + "512": 512 }[os.getenv("BLOCK_SIZE_K", "64")] +GROUP_SIZE_M = 4 +USE_GPU = False +USE_BLOCK_POINTERS = os.getenv("USE_BLOCK_POINTERS", "0") != "0" +DATA_TYPE = { "f32": torch.float32, + "bf16": torch.bfloat16, + "bf8": torch.float8_e5m2 }[os.getenv("DATATYPE", "f32")] +K_DIM_PADDING = os.getenv("K_DIM_PADDING", "0") != "0" +DYNAMIC_K_BLOCK = os.getenv("DYNAMIC_K_BLOCK", "0") != "0" +CACHE_PADDING = os.getenv("CACHE_PADDING", "0") != "0" +PREPROCESS_EXTERNAL = os.getenv("PREPROCESS_EXTERNAL", "0") != "0" +XSMM_PAD = os.getenv("XSMM_PAD", "0") != "0" +PAD_B_ONLY = os.getenv("PAD_B_ONLY", "0") != "0" + +xsmm_py = None +if XSMM_PAD: + import xsmm_py + +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, # arg0 + b_ptr, # arg1 + c_ptr, # arg2 + # Matrix dimensions + M, # arg3 + N, # arg4 + K, # arg5 + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, # arg6 + stride_ak, # arg7 + stride_bk, # arg8 + stride_bn, # arg9 + stride_cm, # arg11 + stride_cn, # arg12 + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + USE_BLOCK_POINTERS: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if USE_BLOCK_POINTERS: + block_offset_m = pid_m * BLOCK_SIZE_M + block_offset_n = pid_n * BLOCK_SIZE_N + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + if USE_BLOCK_POINTERS: + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), + order=(1, 0) + ) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(1, 0) + ) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to matrix C's type after the loop, if C has lower precision type (for example, float16 and bfloat16). + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + + if USE_BLOCK_POINTERS: + # TODO: Currently masked load is not supported yet. + a = tl.load(a_tile_ptr, boundary_check=(0, 1)) + b = tl.load(b_tile_ptr, boundary_check=(0, 1)) + else: + # TODO: Currently masked load is not supported yet. + # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + # Advance the ptrs to the next K block. + if USE_BLOCK_POINTERS: + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) + else: + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Convert the accumulator to the output matrix C's type if needed. + c = accumulator.to(c_ptr.type.element_ty) + + if USE_BLOCK_POINTERS: + # TODO: masking + c_block_ptr = tl.make_block_ptr( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + offsets=(block_offset_m, block_offset_n), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0) + ) + tl.store(c_block_ptr, c, boundary_check=(0, 1)) + else: + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_tile_ptr, c) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. +a_scratch = torch.empty((), dtype=DATA_TYPE) +b_scratch = torch.empty((), dtype=DATA_TYPE) +def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + k_block = BLOCK_SIZE_K + + if DYNAMIC_K_BLOCK: + # Currently, the maximum dynamic block size is capped somewhat arbitrarily. + # Ideally, tradeoffs between amount of padding, block size, and associated costs + # should be considered. + k_block = min(triton.next_power_of_2(K), 1024) + + if XSMM_PAD: + k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K + col_pad = 32 if CACHE_PADDING else 0 + a_scratch.resize_(M, K + k_dim_pad + col_pad) + b_scratch.resize_(K + k_dim_pad, N + col_pad) + if not PAD_B_ONLY or k_dim_pad != 0: + xsmm_py.fastZeroPad2D(a, a_scratch) + a = a_scratch + xsmm_py.fastZeroPad2D(b, b_scratch) + b = b_scratch + K = K + k_dim_pad + else: + if K_DIM_PADDING or DYNAMIC_K_BLOCK: + k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K + if k_dim_pad != 0: + a = torch.nn.functional.pad(a, (0, k_dim_pad, 0, 0), mode='constant', value=0) + b = torch.nn.functional.pad(b, (0, 0, 0, k_dim_pad), mode='constant', value=0) + K = a.shape[1] + + # TODO: Check if padding is needed at all. + # Currently, cache padding is most useful together with dynamic K blocking + # to ensure that stride is non-power-of-two to improve cache behavior. + if CACHE_PADDING: + if not PAD_B_ONLY: + a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0) + b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0) + + #TODO: Currently masked load is not supported yet. + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( + K % k_block == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + if c is None: + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + else: + assert c.shape == (M, N), "Incompatible dimensions" + + return a, b, c, M, N, K, k_block + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int, k_block: int, num_threads=0): + if not PREPROCESS_EXTERNAL: + a, b, c, M, N, K, k_block = matmul_preprocess_input(a, b, c) + + # 1D launch kernel where each block gets its own program. + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=k_block, # + GROUP_SIZE_M=GROUP_SIZE_M, # + USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, # + num_threads=num_threads + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +a = torch.randn((512, 512), device='cpu').type(DATA_TYPE) +b = torch.randn((512, 512), device='cpu').type(DATA_TYPE) +torch_output = torch.matmul(a, b) +c = None +m_dim = None +n_dim = None +k_dim = None +k_block = None +if PREPROCESS_EXTERNAL: + a, b, c, m_dim, n_dim, k_dim, k_block = matmul_preprocess_input(a, b, c) +triton_output = matmul(a, b, c, m_dim, n_dim, k_dim, k_block) +print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}") +print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output.type(torch.float64), torch_output.type(torch.float64), atol=1e-2, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +elif DATA_TYPE in {torch.bfloat16, torch.float8_e5m2} and torch.allclose(triton_output.type(torch.float64), torch_output.type(torch.float64), atol=2e-0, rtol=rtol): + print("⚠️ TritonCPU and TorchCPU rounding errors, the maximum difference is " + f'{torch.max(torch.abs(triton_output.type(torch.float64) - torch_output.type(torch.float64)))}') +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output.type(torch.float64) - torch_output.type(torch.float64)))}') + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + +BENCHMARK_BACKEND = os.getenv("BENCHMARK_BACKEND") +if BENCHMARK_BACKEND: + # if BENCHMARK_BACKEND is provided we run the benchmark on just one config on just one backend + assert(BENCHMARK_BACKEND in {"triton-cpu", "triton-xsmm", "torch-cpu-native", "torch-cpu-compile"}) + LINE_VALS = [BENCHMARK_BACKEND] + LINE_NAMES = [BENCHMARK_BACKEND] + LINE_STYLES = [('blue', '-')] +else: + # if BENCHMARK_BACKEND is not provided stick with the default of running multiple backends + LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile'] + LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)'] + LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')] + if DATA_TYPE == torch.float8_e5m2: + LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu'] + LINE_NAMES = ['TritonCPU 1', 'TritonCPU'] + LINE_STYLES = [('blue', '--'), ('blue', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + a = a.to('cuda') + b = b.to('cuda') + triton_output = matmul(a, b, None) + torch_output = torch.matmul(a, b) + print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}") + print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('yellow', '-'), ('red', '-')] + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- +# +# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch. +# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops. +# for different problem sizes. + +STR_TYPE = str(DATA_TYPE).rsplit('.')[-1] + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 25)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'matmul-performance-{STR_TYPE} (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(M, N, K, provider): + device = 'cpu' + a = torch.randn((M, K), device=device).type(DATA_TYPE) + b = torch.randn((K, N), device=device).type(DATA_TYPE) + + if device == 'cpu': + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + triton.runtime.driver.set_active_to_cpu() + else: + c = None + triton.runtime.driver.set_active_to_gpu() + + triton_a = a + triton_b = b + triton_c = c + m_dim = M + n_dim = N + k_dim = K + k_block = BLOCK_SIZE_K + if PREPROCESS_EXTERNAL: + triton_a, triton_b, triton_c, m_dim, n_dim, k_dim, k_block = matmul_preprocess_input(a, b, c) + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-cpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), rep=1000, quantiles=quantiles) + elif provider == 'torch-cpu-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), rep=1000, quantiles=quantiles) + elif provider in {'triton-cpu', 'triton-xsmm'}: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, triton_c, m_dim, n_dim, k_dim, k_block), rep=1000, quantiles=quantiles) + else: + assert(False and "unknown provider") + perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/python/tutorials/03-matrix-multiplication-cpu.sh b/python/tutorials/03-matrix-multiplication-cpu.sh new file mode 100755 index 000000000000..85313adbcd6a --- /dev/null +++ b/python/tutorials/03-matrix-multiplication-cpu.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +config=$1 +shift +numthreads=$1 +shift + +block_pointers_via_raising=0 + +export BENCHMARK_BACKEND="triton-xsmm" + +while [[ $# -gt 0 ]]; do + case $1 in + --raise-block-pointers) + shift + block_pointers_via_raising=1 + ;; + --external-pad) + shift + export PREPROCESS_EXTERNAL=1 + ;; + --datatype) + shift + export DATATYPE=$1 + shift + ;; + --backend) + shift + export BENCHMARK_BACKEND=$1 + shift + ;; + *) + echo "ERROR: unknown argument: $1" + exit 1 + ;; + esac +done + + +if [ "$config" = "baseline" ]; then + if [ "$BENCHMARK_BACKEND" != "torch-cpu-native" ] && [ "$BENCHMARK_BACKEND" != "torch-cpu-compile" ]; then + echo "ERROR: baseline config but backend is not torch-cpu-native or torch-cpu-compile"; exit 1 + fi + if [ "$DATATYPE" == "bf8" ]; then + echo "ERROR: torch-cpu-native and torch-cpu-compile are too slow on bf8 (~1GFLOPS)"; exit 1 + fi +elif [ "$config" = "baseline-scalar" ]; then + export PREPROCESS_EXTERNAL=1 # elide a no-effect py func call during benchmark + if [ "$BENCHMARK_BACKEND" != "triton-cpu" ]; then + echo "ERROR: baseline-scalar config but backend is not triton-cpu"; exit 1 + fi +elif [ "$config" = "baseline-block" ]; then + export PREPROCESS_EXTERNAL=1 # elide a no-effect py func call during benchmark + if [ "$BENCHMARK_BACKEND" != "triton-cpu" ]; then + echo "ERROR: baseline-block config but backend is not triton-cpu"; exit 1 + fi + export USE_BLOCK_POINTERS=1 +elif [ "$config" = "xsmm-scalar" ]; then + export PREPROCESS_EXTERNAL=1 # elide a no-effect py func call during benchmark + if [ "$BENCHMARK_BACKEND" != "triton-xsmm" ]; then + echo "ERROR: xsmm config but backend is not triton-xsmm"; exit 1 + fi + export TRITON_CPU_TRITON_XSMM=1 +elif [ "$config" = "xsmm-block" ]; then + export PREPROCESS_EXTERNAL=1 # elide a no-effect py func call during benchmark + if [ "$BENCHMARK_BACKEND" != "triton-xsmm" ]; then + echo "ERROR: xsmm config but backend is not triton-xsmm"; exit 1 + fi + if [ $block_pointers_via_raising = 1 ]; then + export TRITON_CPU_RAISE_BLOCK_POINTER=1 + else + export USE_BLOCK_POINTERS=1 + fi + export TRITON_CPU_TRITON_XSMM=1 +elif [ "$config" = "xsmm-pad-k" ]; then + if [ "$BENCHMARK_BACKEND" != "triton-xsmm" ]; then + echo "ERROR: xsmm config but backend is not triton-xsmm"; exit 1 + fi + if [ $block_pointers_via_raising = 1 ]; then + export TRITON_CPU_RAISE_BLOCK_POINTER=1 + else + export USE_BLOCK_POINTERS=1 + fi + export XSMM_PAD=1 + export K_DIM_PADDING=1 + export CACHE_PADDING=1 + export BLOCK_SIZE_K=512 + export TRITON_CPU_TRITON_XSMM=1 +elif [ "$config" = "xsmm-loop-collapse-pad-b" ]; then + if [ "$BENCHMARK_BACKEND" != "triton-xsmm" ]; then + echo "ERROR: xsmm config but backend is not triton-xsmm"; exit 1 + fi + if [ $block_pointers_via_raising = 1 ]; then + export TRITON_CPU_RAISE_BLOCK_POINTER=1 + else + export USE_BLOCK_POINTERS=1 + fi + export XSMM_PAD=1 + export PAD_B_ONLY=1 + export BLOCK_SIZE_K=32 + export CACHE_PADDING=1 + export TRITON_CPU_LOOP_BRGEMM_XSMM=1 +elif [ "$config" = "xsmm-external-pad" ]; then + echo "NOT A TRUE CONFIG; try --external-pad on another config" + exit 1 +else + echo "ERROR: unrecognized config: $config" + exit 1 +fi + +# Uses the libxsmm built in the repo +export XSMM_LIB_DIR=$SCRIPT_DIR/../triton/_C/ +export LD_LIBRARY_PATH=$XSMM_LIB_DIR:$LD_LIBRARY_PATH +export LD_PRELOAD=/lib64/libomp.so:$LD_PRELOAD +if [ -e "$numthreads" ]; then + echo "ERROR: must specify numthreads as 2nd arg"; exit 1 +fi +export TRITON_CPU_MAX_THREADS=${numthreads} +export OMP_NUM_THREADS=${numthreads} + +# Thread affinity changes with hyper-threading +THREADS_PER_CORE=$(lscpu | grep --color=never "Thread.*core" | tee - | grep -o "[0-9]\+") +SKIP=$((THREADS_PER_CORE-1)) # 0 for no HT, 1 for 2, 3 for 4, etc. +export KMP_AFFINITY=granularity=fine,compact,$SKIP,0 + +python $SCRIPT_DIR/03-matrix-multiplication-cpu.py diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index a8bfc46a1630..389c859b02f0 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -70,11 +70,12 @@ def dropout(x, x_keep, p): return output +device = triton.runtime.driver.active.get_current_target().backend # Input tensor -x = torch.randn(size=(10, )).cuda() +x = torch.randn(size=(10, ), device=device) # Dropout mask p = 0.5 -x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda() +x_keep = (torch.rand(size=(10, ), device=device) > p).to(torch.int32) # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ @@ -138,7 +139,7 @@ def seeded_dropout(x, p, seed): return output -x = torch.randn(size=(10, )).cuda() +x = torch.randn(size=(10, ), device=device) # Compare this to the baseline - dropout mask is never instantiated! output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index a234153a047e..6726ae72609f 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -288,6 +288,9 @@ def backward(ctx, dy): layer_norm = LayerNorm.apply +device = triton.runtime.driver.active.get_current_target().backend +# Torch doesn't support operations in float16 on CPU so use float32 instead +dtype = torch.float32 if device == 'cpu' else torch.float16 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): @@ -326,7 +329,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', - args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + args={'M': 4096, 'dtype': dtype, 'mode': 'backward'}, )) def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): # create data @@ -364,8 +367,8 @@ def y_fwd(): return gbps(ms), gbps(max_ms), gbps(min_ms) -test_layer_norm(1151, 8192, torch.float16) -bench_layer_norm.run(save_path='.', print_data=True) +test_layer_norm(1151, 8192, dtype, device=device) +bench_layer_norm.run(save_path='.', print_data=True, device=device) # %% # References diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index bf5f0acf9609..f6dbc97bc093 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -47,12 +47,13 @@ def asin_kernel( # ----------------------------------------- # We can use the default libdevice library path encoded in `triton/language/math.py` +device = triton.runtime.driver.active.get_current_target().backend + torch.manual_seed(0) size = 98432 -x = torch.rand(size, device='cuda') -output_triton = torch.zeros(size, device='cuda') +x = torch.rand(size, device=device) +output_triton = torch.zeros(size, device=device) output_torch = torch.asin(x) -assert x.is_cuda and output_triton.is_cuda n_elements = output_torch.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 1464d489bc1c..49a8bb32c4f8 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -20,13 +20,13 @@ """ import argparse -import time import torch import triton import triton.language as tl import triton.tools.experimental_descriptor import triton.profiler as proton +from contextlib import contextmanager if torch.cuda.is_available(): from triton._C.libtriton import nvidia @@ -48,6 +48,8 @@ def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + if "tiles_per_update" in args: + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -541,7 +543,24 @@ def torch_matmul(a, b): return c -def bench(K, dtype, tiles_per_update, reps=10): +@contextmanager +def proton_context(): + proton.activate(0) + try: + yield + finally: + proton.deactivate(0) + + +def bench_fn(reps, warmup_reps, fn, *args): + for _ in range(warmup_reps): + fn(*args) + with proton_context(): + for _ in range(reps): + fn(*args) + + +def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): M = 8192 N = 8192 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) @@ -549,33 +568,15 @@ def bench(K, dtype, tiles_per_update, reps=10): b = b.T.contiguous() - proton.activate(0) - if cublas is not None: - for _ in range(reps): - cublas_matmul(a, b) - time.sleep(0.01) + bench_fn(reps, warmup_reps, cublas_matmul, a, b) if dtype == torch.float16: - for _ in range(reps): - torch_matmul(a, b) - time.sleep(0.01) - for _ in range(reps): - matmul(a, b.T) - time.sleep(0.01) - for _ in range(reps): - matmul_persistent(a, b.T) - time.sleep(0.01) + bench_fn(reps, warmup_reps, torch_matmul, a, b) + bench_fn(reps, warmup_reps, matmul, a, b.T) + bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): - for _ in range(reps): - matmul_tma_persistent(a, b) - time.sleep(0.01) - with proton.scope( - f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"): - for _ in range(reps): - matmul_device_tma_persistent(a, b, tiles_per_update) - time.sleep(0.01) - - proton.deactivate(0) + bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) + bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update) def validate(M, N, K, dtype, tiles_per_update): diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul-fp32.py new file mode 100644 index 000000000000..8f0f0ebce41a --- /dev/null +++ b/python/tutorials/cpu-blocked-matmul-fp32.py @@ -0,0 +1,373 @@ +""" +Matrix Multiplication +===================== +In this tutorial, matmul on CPU with different input layouts is tested. + +This tutorial is optimized for AMX-enabled CPUs. + +""" + +# %% +# Kernels +# ------- + +import torch + +import triton +import triton.language as tl +import os + +DTYPE = os.getenv("DTYPE", "float32") +# Choose block size depending on dtype. We have more register +# capacity for bfloat16/float16 compared to float32. +BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32 +BLOCK_SIZE_N = 32 +BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32 +GROUP_SIZE_M = 8 + + +# This kernel is used for blocked encoding of input tensors for matmul. +# +# Blocked encoding is used to transform 2D tensor [M, N] into 4D tensor +# [M / BLOCK_SIZE_M, N / BLOCK_SIZE_N, BLOCK_SIZE_M, BLOCK_SIZE_N]. +# This makes following access to blocks in matmul more efficient because +# each block is placed into a contiguous memory fragment and is likely +# to fit a single memory page. +# +# If TRANSPOSED_B is set to True then head dimensions of the RHS +# tensor are transposed. It provides contiguos placement for a column +# of blocks. +# +# If TRANSPOSED_BLOCK_A is set to True then tail dimensions of the LHS +# tensor are transposed. Transposed LHS block better matches FMA lowering +# used by Triton CPU backend which processes RHS block row-by-row and LHS +# block column-by-column. +@triton.jit +def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, + TRANSPOSED_B: tl.constexpr): + tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) + tl.static_assert(BLOCKED_B or not TRANSPOSED_B) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + in_block_m = first_pid_m + (pid % group_size_m) + in_block_n = (pid % num_pid_in_group) // group_size_m + + if BLOCKED_A: + a_out_block_m = in_block_m + A_OUT_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M + A_OUT_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K + A_OUT_BLOCKS_M = M // BLOCK_SIZE_M + A_OUT_BLOCKS_K = K // BLOCK_SIZE_K + A_OUT_STRIDE_M: tl.constexpr = A_OUT_BLOCK_SIZE_K + A_OUT_STRIDE_BLOCK_M = BLOCK_SIZE_M * K + A_OUT_STRIDE_BLOCK_K: tl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_K + for in_block_k in tl.range(in_block_n, A_OUT_BLOCKS_K, N // BLOCK_SIZE_N): + a_out_block_k = in_block_k + a_in_ptr = tl.make_block_ptr(base=in_a, shape=(M, K), strides=(K, 1), + offsets=(in_block_m * BLOCK_SIZE_M, in_block_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)) + a_out_ptr = tl.make_block_ptr( + base=out_a, shape=(A_OUT_BLOCKS_M, A_OUT_BLOCKS_K, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K), + strides=(A_OUT_STRIDE_BLOCK_M, A_OUT_STRIDE_BLOCK_K, A_OUT_STRIDE_M, 1), + offsets=(a_out_block_m, a_out_block_k, 0, 0), + block_shape=(1, 1, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K), order=(3, 2, 1, 0)) + val = tl.load(a_in_ptr) + if TRANSPOSED_BLOCK_A: + val = val.T + val = tl.reshape(val, (1, 1, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K)) + tl.store(a_out_ptr, val) + + if BLOCKED_B: + B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K + B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N + B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N + B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) + B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N + for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): + b_out_block_k = in_block_n if TRANSPOSED_B else in_block_k + b_out_block_n = in_block_k if TRANSPOSED_B else in_block_n + b_in_ptr = tl.make_block_ptr(base=in_b, shape=(K, N), strides=(N, 1), + offsets=(in_block_k * BLOCK_SIZE_K, in_block_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) + b_out_ptr = tl.make_block_ptr(base=out_b, + shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, B_OUT_STRIDE_K, 1), + offsets=(b_out_block_k, b_out_block_n, 0, 0), + block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), order=(3, 2, 1, 0)) + val = tl.load(b_in_ptr) + val = tl.reshape(val, (1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N)) + tl.store(b_out_ptr, val) + + +# Matmul kernel that computes a single output block [BLOCK_SIZE_M, BLOCK_SIZE_N]. LHS can be in the +# rowmajor, blocked, or blocked transposed encoding. RHS can be in rowmajor, blocked, or transposed +# blocked encoding. +# +# To cover all input layouts, we use 4D block pointers that address a single input block +# [1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N], we choose strides for these block pointers +# appropriately to keep navigation bentween blocks similar for all input encodings. +# +# E.g. for rowmajor LHS we use BLOCK_SIZE_K stride to move to the next block over K axis, but +# for blocked encoding we use BLOCK_SIZE_M * BLOCK_SIZE_K stride. In both cases we then can +# advance using the same (0, 1, 0, 0) offset in the loop. +# +# Reshape is used to remove the heading (1, 1) dimensions, but CPU backend folds it with the load +# operation and it doesn't prevent direct vector loads from the input memory. +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # number of blocks in a group + GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): + # TRANSPOSED_BLOCK_A means that each block in A is transposed. + # It is allowed only for blocked input. + assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) + # TRANSPOSED_B means that blocks of B are reordered but blocks + # itself are not transpoed. It is allowed only for blocked input. + assert (BLOCKED_B or not TRANSPOSED_B) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + block_m = first_pid_m + (pid % group_size_m) + block_n = (pid % num_pid_in_group) // group_size_m + + A_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M + A_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K + A_BLOCKS_M = M // BLOCK_SIZE_M + A_BLOCKS_K = K // BLOCK_SIZE_K + a_stride_k = 1 + a_stride_m = A_BLOCK_SIZE_K if BLOCKED_A else K + a_stride_block_k = A_BLOCK_SIZE_M * A_BLOCK_SIZE_K if BLOCKED_A else A_BLOCK_SIZE_K + a_stride_block_m = BLOCK_SIZE_M * K + + b_stride_n = 1 + b_stride_k = BLOCK_SIZE_N if BLOCKED_B else N + if TRANSPOSED_B: + b_stride_block_n = BLOCK_SIZE_N * K + b_stride_block_k = BLOCK_SIZE_K * BLOCK_SIZE_N + else: + b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else BLOCK_SIZE_N + b_stride_block_k = BLOCK_SIZE_K * N + + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(A_BLOCKS_M, A_BLOCKS_K, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), + strides=(a_stride_block_m, a_stride_block_k, a_stride_m, a_stride_k), + offsets=(block_m, 0, 0, 0), block_shape=(1, 1, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), + order=(3, 2, 1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, + shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), + offsets=(0, block_n, 0, 0), block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(3, 2, 1, 0)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(N, 1), + offsets=(block_m * BLOCK_SIZE_M, block_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_block_ptr).reshape((A_BLOCK_SIZE_M, A_BLOCK_SIZE_K)) + b = tl.load(b_block_ptr).reshape((BLOCK_SIZE_K, BLOCK_SIZE_N)) + + if TRANSPOSED_BLOCK_A: + a = a.T + + c += tl.dot(a, b, out_dtype=tl.float32) + + a_block_ptr = tl.advance(a_block_ptr, (0, 1, 0, 0)) + b_block_ptr = tl.advance(b_block_ptr, (1, 0, 0, 0)) + + tl.store(c_block_ptr, c) + + +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): + #TODO: Currently masked load is not supported yet. + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( + K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + # 1D launch kernel where each block gets its own program. + grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) + if (BLOCKED_A or BLOCKED_B) and not PREPACKED: + block_transpose_combined_kernel[grid]( + a, ab, b, bb, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B) + if BLOCKED_A: + a = ab + if BLOCKED_B: + b = bb + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, num_threads=num_threads) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +a = torch.randn((512, 512), device='cpu', dtype=torch.float32) +b = torch.randn((512, 512), device='cpu', dtype=torch.float32) +c = torch.empty((512, 512), device='cpu', dtype=torch.float32) +torch_output = torch.matmul(a, b) +rtol = 0 +a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) +b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + assert False +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU pre-packed and TorchCPU match") +else: + print("❌ TritonCPU pre-packed and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + assert False + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + + +def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" + + +def encode_torch_provider(single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + return f"torch-cpu-native{'-st' if single_thread else ''}-{dtype}" + + +def decode_provider(provider): + if '-bfloat16' in provider: + dtype = torch.bfloat16 + if '-float16' in provider: + dtype = torch.float16 + elif '-float32' in provider: + dtype = torch.float32 + if 'triton-cpu' in provider: + backend = 'triton-cpu' + elif 'torch-cpu-native' in provider: + backend = 'torch-cpu-native' + elif 'torch-cpu-compile' in provider: + backend = 'torch-cpu-compile' + return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-prepack' in provider, '-st' in provider, dtype + + +BLOCK_TRANSPOSE_A_OPTS = [(False, False)] +BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] +PREPACK_OPTS = [False, True] +SINGLE_THREAD_OPTS = [False] +DTYPE_OPTS = [DTYPE] +LINE_VALS = [ + encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) + for single_thread in SINGLE_THREAD_OPTS + for blocked_a, transposed_a in BLOCK_TRANSPOSE_A_OPTS + for blocked_b, transposed_b in BLOCK_TRANSPOSE_B_OPTS + for prepack in PREPACK_OPTS + for dtype in DTYPE_OPTS + if blocked_a or blocked_b or not prepack +] + [encode_torch_provider(single_thread, dtype) for dtype in DTYPE_OPTS for single_thread in SINGLE_THREAD_OPTS] +LINE_NAMES = LINE_VALS +LINE_STYLES = None + +default_num_threads = torch.get_num_threads() + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 21)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'matmul-performance-{DTYPE} (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(M, N, K, provider): + + device = 'cpu' if 'cpu' in provider else 'cuda' + backend, blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype = decode_provider(provider) + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((K, N), device=device, dtype=dtype) + + if single_thread: + torch.set_num_threads(1) + else: + torch.set_num_threads(default_num_threads) + + if backend == 'triton-cpu': + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + a_tmp = torch.zeros((M * K + (M // BLOCK_SIZE_M) * (K // BLOCK_SIZE_K) * 64), device=device, dtype=dtype) + b_tmp = torch.zeros((K * N + (K // BLOCK_SIZE_K) * (N // BLOCK_SIZE_N) * 64), device=device, dtype=dtype) + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + if prepack and (blocked_a or blocked_b): + grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) + block_transpose_combined_kernel[grid]( + a, a_tmp, b, b_tmp, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=blocked_a, TRANSPOSED_BLOCK_A=transposed_a, # + BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b) + if blocked_a: + a = a_tmp + if blocked_b: + b = b_tmp + else: + c = torch.zeros((M, N), device=a.device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + if backend == 'torch-cpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles) + elif backend == 'torch-cpu-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) + elif backend == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, + num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000) + perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/python/tutorials/matrix-vector-multiplication-bf16.py b/python/tutorials/matrix-vector-multiplication-bf16.py new file mode 100644 index 000000000000..9927d2be956a --- /dev/null +++ b/python/tutorials/matrix-vector-multiplication-bf16.py @@ -0,0 +1,192 @@ +import torch + +import triton +import triton.language as tl + +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 64 +USE_GPU = False +""" +Kernel for computing Y = A @ X, where A is a dense matrix with +M rows and N columns. +- Input X has shape (N,) +- A has shape (M, N) +- Output has shape (M,) +""" + + +@triton.jit +def gemv_kernel( + Y, + A, + X, + M, + N, + stride_am, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + start_m = tl.program_id(0) + rm = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = tl.arange(0, BLOCK_SIZE_N) + + A = A + (rm[:, None] * stride_am + rn[None, :]) + X = X + rn + + acc = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + for n in range(N, 0, -BLOCK_SIZE_N): + a = tl.load(A) + x = tl.load(X) + acc += tl.sum(a * x[None, :], axis=1) + A += BLOCK_SIZE_N + X += BLOCK_SIZE_N + + y = acc.to(tl.bfloat16) + Y = Y + rm + tl.store(Y, y) + + +def gemv( + weight: torch.Tensor, + x: torch.Tensor, + output: torch.Tensor, + num_threads=0, +): + assert weight.shape[1] == x.shape[0], "Incompatible dimensions" + assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + + M, N = weight.shape + + # TODO: Currently masked load is not supported yet. + assert M % BLOCK_SIZE_M == 0 and N % BLOCK_SIZE_N == 0, "Masking currently not supported, Matrix dimensions must be multiples of block size" + + if output is None: + # Allocates output. + output = torch.empty((M, ), device=x.device, dtype=x.dtype) + else: + assert output.shape == (M, ) and output.dtype == x.dtype, "Incompatible output" + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + num_threads=num_threads) + + return output + + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +weight = torch.randn((512, 1024), device='cpu', dtype=torch.bfloat16) +x = torch.randn((1024), device='cpu', dtype=torch.bfloat16) +triton_output = gemv(weight, x, None) +compiled_matmul = torch.compile(torch.matmul) +# Note: torch.matmul for bf16 on Arm Linux will trigger error on old torch versions: +# RuntimeError: could not create a primitive descriptor for a matmul primitive +# So we recommend using torch 2.4.0 onwards. +torch_output = torch.matmul(weight, x) +#print(f"triton_cpu_output_with_{weight.dtype}_inputs={triton_output}") +#print(f"torch_cpu_output_with_{weight.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + +LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu', 'torch-cpu-native-single', 'torch-cpu-native', 'torch-cpu-compile-single', + 'torch-cpu-compile' +] +LINE_NAMES = [ + 'TritonCPU 1', 'TritonCPU', 'TorchCPU (native) 1', 'TorchCPU (native)', 'TorchCPU (compile) 1', 'TorchCPU (compile)' +] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-'), ('red', '--'), ('red', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + weight = weight.to('cuda') + x = x.to('cuda') + triton_output = gemv(weight, x, None) + torch_output = torch.matmul(weight, x) + #print(f"triton_gpu_output_with_{weight.dtype}_inputs={triton_output}") + #print(f"torch_gpu_output_with_{weight.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('pink', '-'), ('cyan', '-')] + +default_num_threads = torch.get_num_threads() + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N"], # Argument names to use as an x-axis for the plot + x_vals=[(512 * i, 4096) for i in range(10, 51, 4)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'gemv-performance-bf16 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N})', + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(M, N, provider): + + device = 'cpu' if 'cpu' in provider else 'cuda' + weight = torch.randn((M, N), device=device, dtype=torch.bfloat16) + x = torch.randn((N), device=device, dtype=torch.bfloat16) + + if device == 'cpu': + output = torch.empty((M), device=x.device, dtype=x.dtype) + triton.runtime.driver.set_active_to_cpu() + num_threads = 0 + if 'single' in provider: + num_threads = 1 + torch.set_num_threads(1) + else: + torch.set_num_threads(default_num_threads) + else: + output = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + elif 'torch-cpu-native' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles) + elif 'torch-cpu-compile' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output), + quantiles=quantiles) + elif 'triton-cpu' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=num_threads), + quantiles=quantiles) + + perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/python/tutorials/matrix-vector-multiplication.py b/python/tutorials/matrix-vector-multiplication.py new file mode 100644 index 000000000000..06feca82893f --- /dev/null +++ b/python/tutorials/matrix-vector-multiplication.py @@ -0,0 +1,204 @@ +import torch + +import triton +import triton.language as tl + +BLOCK_SIZE_M = 1 +BLOCK_SIZE_N = 512 +USE_GPU = False +""" +Kernel for computing Y = A @ X, where A is a dense matrix with +M rows and N columns. +- Input X has shape (N,) +- A has shape (M, N) +- Output has shape (M,) +""" + + +@triton.jit +def gemv_kernel( + Y, + A, + X, + M, + N, + stride_am, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + start_m = tl.program_id(0) + rm = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = tl.arange(0, BLOCK_SIZE_N) + + A = A + (rm[:, None] * stride_am + rn[None, :]) + X = X + rn + + acc = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + for n in range(N, 0, -BLOCK_SIZE_N): + a = tl.load(A) + x = tl.load(X) + acc += tl.sum(a * x[None, :], axis=1) + A += BLOCK_SIZE_N + X += BLOCK_SIZE_N + + Y = Y + rm + tl.store(Y, acc) + + +def gemv( + weight: torch.Tensor, + x: torch.Tensor, + output: torch.Tensor, + num_threads=0, +): + assert weight.shape[1] == x.shape[0], "Incompatible dimensions" + assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + + M, N = weight.shape + + # TODO: Currently masked load is not supported yet. + assert M % BLOCK_SIZE_M == 0 and N % BLOCK_SIZE_N == 0, "Masking currently not supported, Matrix dimensions must be multiples of block size" + + if output is None: + # Allocates output. + output = torch.empty((M, ), device=x.device, dtype=x.dtype) + else: + assert output.shape == (M, ) and output.dtype == x.dtype, "Incompatible output" + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + num_threads=num_threads) + + return output + + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +weight = torch.randn((512, 1024), device='cpu', dtype=torch.float32) +x = torch.randn((1024), device='cpu', dtype=torch.float32) +triton_output = gemv(weight, x, None) +# torch.matmul will select bf16 kernels on Linux Arm if x is 1-d, which has lower precision. +# So we reshape x to be 2-d, which will invoke different kernels. +torch_output = torch.matmul(weight, x[:, None]).reshape(-1) +#print(f"triton_cpu_output_with_{weight.dtype}_inputs={triton_output}") +#print(f"torch_cpu_output_with_{weight.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + +LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu', 'triton-cpu-linear', 'torch-cpu-native', 'torch-cpu-compile', + 'torch-cpu-2d-native', 'torch-cpu-2d-compile', 'torch-cpu-transpose-native', 'torch-cpu-transpose-compile', + 'torch-cpu-linear' +] +LINE_NAMES = [ + 'TritonCPU 1', 'TritonCPU', 'TritonCPU Linear', 'TorchCPU (native)', 'TorchCPU (compile)', 'TorchCPU 2D (native)', + 'TorchCPU 2D (compile)', 'TorchCPU Transpose (native)', 'TorchCPU Transpose (compile)', 'TorchCPU Linear' +] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', ':'), ('green', '--'), ('green', '-'), ('red', '--'), + ('red', '-'), ('yellow', '--'), ('yellow', '-'), ('purple', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + weight = weight.to('cuda') + x = x.to('cuda') + triton_output = gemv(weight, x, None) + torch_output = torch.matmul(weight, x) + #print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}") + #print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('pink', '-'), ('cyan', '-')] + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["N"], # Argument names to use as an x-axis for the plot + x_vals=[512 * i for i in range(10, 21)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'gemv-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N})', + args={'M': 4096}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(M, N, provider): + + device = 'cpu' if 'cpu' in provider else 'cuda' + weight = torch.randn((M, N), device=device, dtype=torch.float32) + x = torch.randn((N), device=device, dtype=torch.float32) + + if device == 'cpu': + output = torch.empty((M), device=x.device, dtype=x.dtype) + triton.runtime.driver.set_active_to_cpu() + + if 'transpose' in provider: + weight = torch.transpose(weight, 0, 1) + x = x[None, :] + output = output[None, :] + elif '2d' in provider: + x = x[:, None] + output = output[:, None] + else: + output = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + elif provider == 'torch-cpu-native' or provider == 'torch-cpu-2d-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles) + elif provider == 'torch-cpu-compile' or provider == 'torch-cpu-2d-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles) + elif provider == 'torch-cpu-transpose-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles) + elif provider == 'torch-cpu-transpose-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles) + elif provider == 'torch-cpu-linear': + weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles) + elif provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=1), + quantiles=quantiles) + elif provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + elif provider == 'triton-cpu-linear': + # torch.nn.Linear.forward does not take preallocated output buffer, so we also do no provide output buffer for fair comparison + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles) + perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/raw-to-csv.py b/raw-to-csv.py new file mode 100755 index 000000000000..4be687ff1a01 --- /dev/null +++ b/raw-to-csv.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python + +import sys +import re + +def convert_all(in_file): + metadata = None + csv_lines = [] + + for line in in_file: + if line.startswith("RUN") or line.startswith("BENCHMARK"): + # If this is not the first RUN line, yield the previous bundle + if metadata and csv_lines: + yield (csv_lines, metadata) + metadata = None + csv_lines = [] + metadata = {'line': line} + if m := re.match(r" +M.*", line): # header row + csv_line = re.subn("[, ][, ]+", ",", line) + l = csv_line[0][1:].split(",") + k_onwards = l[2:] + csv_lines += [",".join(k_onwards)] + if m := re.match(r"\d+ +(\d.*)", line): # data row + csv_line = re.subn("[, ]+", ",", m[1]) + l = csv_line[0][1:].split(",") + k_onwards = l[2:] + csv_lines += [",".join(k_onwards)+"\n"] + + # The last bundle gets yielded here + if metadata and csv_lines: + yield (csv_lines, metadata) + + +if __name__ == "__main__": + in_file = sys.stdin if (len(sys.argv) < 2 or sys.argv[1] == '-') else open(sys.argv[1]) + + converted = convert_all(in_file) + + for csv_lines, metadata in converted: + print(metadata['line'], end='') + for csv_line in csv_lines: + print(csv_line, end='') diff --git a/run_all_benchmarks.sh b/run_all_benchmarks.sh new file mode 100755 index 000000000000..3987ecda678b --- /dev/null +++ b/run_all_benchmarks.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +source ../miniforge/bin/activate triton + +THREADS=$(lscpu | grep --color=never "Core.*socket" | grep -o "[0-9]\+") + +commit=$(git describe --always --dirty) + +for datatype in f32 bf16 bf8; do + for num_threads in 1 $THREADS; do + for backend in torch-cpu-compile torch-cpu-native; do + config=baseline + echo -e "BENCHMARK: {'backend': '$backend', 'config': '$config', 'threads': $num_threads, 'type': '$datatype', 'commit': '$commit'}" + time $SCRIPT_DIR/python/tutorials/03-matrix-multiplication-cpu.sh $config $num_threads --datatype $datatype --backend $backend 2>&1 + echo -e "\n\n" + done + + backend=triton-cpu + for config in baseline-scalar baseline-block; do + echo -e "BENCHMARK: {'backend': '$backend', 'config': '$config', 'threads': $num_threads, 'type': '$datatype', 'commit': '$commit'}" + time $SCRIPT_DIR/python/tutorials/03-matrix-multiplication-cpu.sh $config $num_threads --datatype $datatype --backend $backend 2>&1 + echo -e "\n\n" + done + + # Triton-XSMM + backend=triton-xsmm + for config in xsmm-scalar xsmm-block; do + echo -e "BENCHMARK: {'backend': '$backend', 'config': '$config', 'threads': $num_threads, 'type': '$datatype', 'commit': '$commit'}" + time $SCRIPT_DIR/python/tutorials/03-matrix-multiplication-cpu.sh $config $num_threads --datatype $datatype --backend $backend 2>&1 + echo -e "\n\n" + done + for config in xsmm-pad-k xsmm-loop-collapse-pad-b; do + for external_pad in "" "--external-pad"; do + echo -e "BENCHMARK: {'backend': '$backend', 'config': '$config', 'threads': $num_threads, 'type': '$datatype', 'pad': '$external_pad', 'commit': '$commit'}" + time $SCRIPT_DIR/python/tutorials/03-matrix-multiplication-cpu.sh $config $num_threads --datatype $datatype --backend $backend $external_pad 2>&1 + echo -e "\n\n" + done + done + done +done + +conda deactivate diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 4f3af58e85f2..adea6e3b6688 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -1,15 +1,15 @@ // RUN: triton-opt %s --mlir-disable-threading -test-print-alias -split-input-file 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: matmul_loop // CHECK-NOT: -> @@ -26,9 +26,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> @@ -41,7 +41,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: %0 -> %0 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } @@ -49,40 +49,40 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @alloc_init(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %0 - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: subview -tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { +tt.func @subview(%A : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory>) { %index = arith.constant 0 : i32 // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %0 - %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.memdesc_subview %a[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: if_alias tt.func @if_alias(%i1 : i1) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %0,%1 - %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { - scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = scf.if %i1 -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> { + scf.yield %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { - scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -90,11 +90,11 @@ tt.func @if_alias(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: %2 -> %2 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg6 -> %0 // CHECK-NEXT: %arg7 -> %1 // CHECK-NEXT: %arg8 -> %2 @@ -102,8 +102,8 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> - (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -111,11 +111,11 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -123,14 +123,14 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 // CHECK-NEXT: %4 -> %0,%1 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -138,11 +138,11 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -150,23 +150,23 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %3#1 -> %1 // CHECK-NEXT: %3#2 -> %2,%6,%6 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { // CHECK-NEXT: %arg11 -> %2,%6,%6 // CHECK-NEXT: %4 -> %2,%6,%6 - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { // CHECK-NEXT: %5 -> %6,%6 - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -175,29 +175,29 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { %idx = arith.constant 0 : i32 // CHECK: %0 -> %0 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %2 -> %0 - %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> gpu.barrier // CHECK-NEXT: %3 -> %3 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: %5 -> %0,%1,%3 // CHECK-NEXT: %6 -> %0,%1,%3 // CHECK-NEXT: %7 -> %0,%1,%3 - cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) -^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %3: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %4: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) +^bb1(%1: index, %2: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %3: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %4: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 gpu.barrier %8 = arith.addi %1, %arg2 : index - cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) + cf.br ^bb1(%8, %4, %2, %3 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) ^bb3: // pred: ^bb1 gpu.barrier // CHECK-NEXT: %10 -> %0 - %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %9 = ttg.memdesc_subview %0[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a0719c974f9c..a4dfb20bcbe8 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -1,21 +1,28 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +// Check there are no lines with a size different to 128 and we have at least a line with size 128. -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} +// CHECK-128: scratch offset = {{.*}}, size = 128 +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: empty tt.func @empty(%A : !tt.ptr) { %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> tt.return // CHECK: size = 0 } @@ -37,10 +44,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> // CHECK: offset = 0, size = 4608 - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> // CHECK-NEXT: offset = 0, size = 4352 - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> @@ -65,17 +72,17 @@ tt.func @reusable(%A : !tt.ptr) { %b_ptr = tt.splat %A : !tt.ptr -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 4608 - %a1 = triton_gpu.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a1 = ttg.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 1088 - %a2 = triton_gpu.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %a2 = ttg.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 4608 - %a3 = triton_gpu.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a3 = ttg.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %c = tt.dot %a1, %a2, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 1088 - %a4 = triton_gpu.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %a4 = ttg.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %c1 = tt.dot %a3, %a4, %c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return // CHECK-NEXT: size = 4608 @@ -88,47 +95,47 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate tt.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 6144, size = 2048 - %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %e = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 2048 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %d = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %b : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 10240, size = 2048 - %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %f = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %c : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %g = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %e : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %h = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %d : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %i = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %f : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst5 : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 12288 } @@ -138,11 +145,11 @@ tt.func @preallocate(%A : !tt.ptr) { tt.func @unused(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK: size = 1024 } @@ -151,33 +158,33 @@ tt.func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive tt.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst6 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %d = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 4096 } @@ -186,43 +193,43 @@ tt.func @longlive(%A : !tt.ptr) { // CHECK-LABEL: multi_color tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1536, size = 32 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1664, size = 128 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> // CHECK-NEXT: offset = 0, size = 256 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 64 - %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> - %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %cst_5 = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> + %4 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1792, size = 128 - %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_8 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 32 - %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_9 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> - %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x32xf16, #AL> + %cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> - %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> + %10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> // CHECK-NEXT: size = 1920 @@ -233,25 +240,25 @@ tt.func @multi_color(%A : !tt.ptr) { // CHECK-LABEL: multi_color_multi_rounds tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1280, size = 128 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 8192 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1152, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> - %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<1024x4xf16, #AL> + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> - %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> + %5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> // CHECK-NEXT: size = 10240 tt.return } @@ -260,10 +267,10 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -272,10 +279,10 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-LABEL: dealloc tt.func @dealloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: offset = 1024, size = 1024 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 2048 } @@ -296,8 +303,8 @@ tt.func @scratch() { // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> tt.return } @@ -305,9 +312,9 @@ tt.func @trans(%A : !tt.ptr) { // CHECK-LABEL: extract_slice tt.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %index = arith.constant 0 : i32 - %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.memdesc_subview %cst0[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -319,9 +326,9 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { // CHECK: size = 8196 %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return %4 : i32 } @@ -331,9 +338,9 @@ tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { // CHECK: size = 8192 %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -342,25 +349,25 @@ tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { // CHECK-LABEL: if tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 3072 } @@ -370,28 +377,28 @@ tt.func @if(%i1 : i1) { // CHECK-LABEL: if_else tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 5120 } @@ -401,13 +408,13 @@ tt.func @if_else(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -416,18 +423,18 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_slice tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -437,16 +444,16 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK-LABEL: for_use_ancestor tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c0 = ttg.memdesc_trans %c_shared_init {order=array} : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #A_SHARED_T, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 24576, size = 8192 - %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %c1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %b_shared, %a_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 32768 @@ -457,40 +464,40 @@ tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst1 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 40960 } } -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-warps" = 4 : i32} { // CHECK-LABEL: alloc1 tt.func @alloc1(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -498,7 +505,7 @@ tt.func @alloc1(%A : !tt.ptr) { // CHECK-LABEL: alloc2 tt.func @alloc2(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return // CHECK-NEXT: size = 1024 } @@ -507,10 +514,10 @@ tt.func @alloc2(%A : !tt.ptr) { tt.func @alloc3(%cond : i1) { scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } else { // CHECK-NEXT: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 1024 @@ -532,7 +539,7 @@ tt.func @alloc4(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: single_call tt.func @single_call(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () @@ -543,7 +550,7 @@ tt.func @single_call(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -558,9 +565,9 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 1024 - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst1 = ttg.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () } else { @@ -575,7 +582,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -591,7 +598,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc3(%cond) : (i1) -> () tt.return @@ -601,7 +608,7 @@ tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_2 tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () tt.return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 2054853b30c1..29e0b253be02 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1,15 +1,15 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf --allocate-shared-memory -test-print-membar 2>&1 | FileCheck %s -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: matmul_loop // There shouldn't be any membar with the dot op encoding. @@ -28,9 +28,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> @@ -46,10 +46,10 @@ tt.func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -59,14 +59,14 @@ tt.func @war_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - // CHECK: triton_gpu.local_alloc + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.local_alloc // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: %4 = triton_gpu.local_alloc - %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: %4 = ttg.local_alloc + %4 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> tt.return } @@ -76,25 +76,25 @@ tt.func @war_single_block_local_store(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_alloc + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_alloc // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_store - triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttg.local_store + ttg.local_store %1, %2 : tensor<128x32xf16, #AL> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> tt.return } // CHECK-LABEL: scratch tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load + // CHECK-NEXT: ttg.local_load // CHECK: gpu.barrier // CHECK: tt.reduce - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f16, %arg2: f16): %add = arith.addf %arg1, %arg2 : f16 @@ -105,34 +105,34 @@ tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { // CHECK-LABEL: async_wait tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - // CHECK: triton_gpu.async_wait - triton_gpu.async_wait {num = 4 : i32} + %cst0 = ttg.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.async_wait + ttg.async_wait {num = 4 : i32} // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<32x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<32x16xf16, #AL> tt.return } // CHECK-LABEL: subview tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> - %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> %index = arith.constant 0 : i32 - %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %0 = ttg.memdesc_subview %a[%index, %index] : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } // CHECK-LABEL: trans -tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { +tt.func @trans(%a: !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> + %b = ttg.memdesc_trans %a {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory> tt.return } @@ -142,31 +142,31 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %subview = ttg.memdesc_subview %alloc[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %subview : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks tt.func @multi_blocks(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -174,21 +174,21 @@ tt.func @multi_blocks(%i1 : i1) { // CHECK-LABEL: multi_blocks_join_barrier tt.func @multi_blocks_join_barrier(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK-NOT: gpu.barrier // CHECK: tt.return - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -196,25 +196,25 @@ tt.func @multi_blocks_join_barrier(%i1 : i1) { // CHECK-LABEL: multi_blocks_yield tt.func @multi_blocks_yield(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - // CHECK: triton_gpu.local_load + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -222,27 +222,27 @@ tt.func @multi_blocks_yield(%i1 : i1) { // CHECK-LABEL: multi_blocks_entry_no_shared tt.func @multi_blocks_entry_no_shared(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc + // CHECK-NEXT: ttg.local_alloc // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load + // CHECK-NEXT: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %0 = ttg.local_load %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK-NOT: gpu.barrier - // CHECK: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -250,16 +250,16 @@ tt.func @multi_blocks_entry_no_shared(%i1 : i1) { // CHECK-LABEL: multi_blocks_noelse tt.func @multi_blocks_noelse(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -267,39 +267,39 @@ tt.func @multi_blocks_noelse(%i1 : i1) { // CHECK-LABEL: multi_blocks_nested_scf tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> scf.if %i1 { scf.if %i2 { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } scf.yield } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -309,24 +309,24 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_alias tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %a1 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -335,63 +335,63 @@ tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // CHECK-LABEL: for_reuse tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for_reuse_nested tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -399,25 +399,25 @@ tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } else { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -426,30 +426,30 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: for_if_for tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier - %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %c_blocked = ttg.local_load %c_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } else { - %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + // CHECK-NEXT: ttg.local_load + %c_blocked_next = ttg.local_load %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } - scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_ : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } // CHECK-NOT: gpu.barrier - %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_blocked_next = ttg.local_load %b_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %a_shared, %b_shared, %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> } tt.return } @@ -457,65 +457,65 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: cf_if tt.func @cf_if(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: cf_if_else tt.func @cf_if_else(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) ^bb2: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) -^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb1, ^bb2 + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) +^bb3(%arg: !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 - // CHECK: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %5 = ttg.local_load %arg : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: cf_if_else_return tt.func @cf_if_else_return(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> - %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %b = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return ^bb2: // pred: ^bb0 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -524,9 +524,9 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { // CHECK-NOT: gpu.barrier %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return %4 : i32 } @@ -534,53 +534,53 @@ tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> tt.return } } -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: convert_layout1 tt.func @convert_layout1(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout2 tt.func @convert_layout2(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK: triton_gpu.local_load - %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> - %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %4 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout3 tt.func @convert_layout3(%cond : i1) { scf.if %cond { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NOT: gpu.barrier - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #AL> + %1 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x64xf16, #AL> } else { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - // CHECK: triton_gpu.local_load + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_gpu.local_alloc - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttg.local_alloc + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> } tt.return } @@ -602,7 +602,7 @@ tt.func @single_call_sync(%A : !tt.ptr) { // CHECK: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %1 = triton_gpu.convert_layout %0 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.convert_layout %0 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> tt.return } @@ -612,14 +612,14 @@ tt.func @single_call_no_sync(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier %0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #BL> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #BL> tt.return } // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.call @convert_layout1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () @@ -631,12 +631,12 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { scf.if %cond { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst1 = ttg.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> } else { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK: tt.call @@ -649,7 +649,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -665,7 +665,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call tt.call @convert_layout3(%cond) : (i1) -> () tt.return @@ -677,7 +677,7 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () // CHECK: tt.call // CHECK-NEXT: gpu.barrier - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> tt.return } @@ -685,28 +685,28 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { tt.func public @kernel(%arg3: !tt.ptr, %arg4: !tt.ptr, %arg12: tensor<32x128xf16, #blocked>, %arg13: tensor<32x128xf32, #blocked>, %arg14: tensor<32x32xf16, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> - %37 = triton_gpu.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> - %58 = triton_gpu.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> + %37 = ttg.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> + %58 = ttg.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> cf.br ^bb1 ^bb1: // 2 preds: ^bb0, ^bb1 %59 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 %60 = arith.cmpi eq, %59, %c0_i32 : i32 cf.cond_br %60, ^bb1, ^bb2 ^bb2: // pred: ^bb1 - %72 = triton_gpu.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> - %73 = triton_gpu.local_load %37 : !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %74 = triton_gpu.local_load %58 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> - %76 = triton_gpu.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> + %72 = ttg.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> + %73 = ttg.local_load %37 : !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %74 = ttg.local_load %58 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> + %76 = ttg.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> %77 = arith.truncf %76 : tensor<32x128xf32, #blocked> to tensor<32x128xf16, #blocked> %78 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> tt.store %78, %77 : tensor<32x128x!tt.ptr, #blocked> @@ -716,54 +716,54 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %c0 = arith.constant 0 : i32 - %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - // CHECK: triton_nvidia_gpu.init_barrier - // CHECK-NEXT: triton_nvidia_gpu.init_barrier - triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + // CHECK: ttng.init_barrier + // CHECK-NEXT: ttng.init_barrier + ttng.init_barrier %barrier, 1 : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.init_barrier %barrier, 1 : <1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.wait_barrier + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> - // CHECK-NEXT: triton_gpu.local_load - %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #blocked> - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: ttng.barrier_expect // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> // CHECK-NEXT: gpu.barrier - // CHECK-NEXT: triton_nvidia_gpu.inval_barrier - // CHECK-NEXT: triton_nvidia_gpu.inval_barrier - triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + // CHECK-NEXT: ttng.inval_barrier + // CHECK-NEXT: ttng.inval_barrier + ttng.inval_barrier %barrier : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.inval_barrier %barrier : <1xi64, #shared1, #ttg.shared_memory, mutable> tt.return %t : tensor<256x64xf16, #blocked> } @@ -771,38 +771,38 @@ tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocke // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { // CHECK-LABEL: tma_special_cases_cf tt.func @tma_special_cases_cf(%arg1: !tt.ptr, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ %true = arith.constant 1 : i1 %c0 = arith.constant 0 : i32 - %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> // CHECK: cf.cond_br scf.if %i1 { // CHECK-NOT: gpu.barrier - // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local - // CHECK-NEXT: triton_nvidia_gpu.barrier_expect - // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + // CHECK: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.wait_barrier // CHECK-NEXT: cf.br - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #ttg.shared_memory, mutable> -> <256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : <1xi64, #shared1, #ttg.shared_memory, mutable> scf.yield } else { // CHECK-NOT: gpu.barrier - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK-NEXT: cf.br - triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + ttg.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> scf.yield } // CHECK: gpu.barrier - // CHECK-NEXT: triton_gpu.local_load - %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #blocked> tt.return %t : tensor<256x64xf16, #blocked> } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8028d099fe37..57397810efbb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,6 +13,7 @@ configure_lit_site_cfg( set(TRITON_TEST_DEPENDS triton-opt + triton-tensor-layout ) set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck") diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir new file mode 100644 index 000000000000..345714f5b2b3 --- /dev/null +++ b/test/Conversion/allocate_shared_memory.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt %s --allocate-shared-memory | FileCheck %s + +// CHECK-LABEL: module +// CHECK-SAME: ttg.shared = 131072 : i32 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK-LABEL: @gather_op +// TODO(jeff): Optimize the lowering to reduce shared memory usage. +tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { + // CHECK-NEXT: allocation.offset = 0 : i32 + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> + tt.return +} + +} diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir index 209c7065d8a6..70abc555945f 100644 --- a/test/Conversion/amd/buffer_load_store.mlir +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load tt.func @buffer_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 @@ -14,8 +14,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_mask tt.func @buffer_load_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -36,8 +36,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_mask_other tt.func @buffer_load_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -60,8 +60,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_store tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 @@ -74,8 +74,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_store_mask tt.func @buffer_store_mask(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { %c256_i32 = arith.constant 256 : i32 @@ -97,8 +97,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec4 tt.func @buffer_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -123,8 +123,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec1 tt.func @buffer_load_store_vec1(%arg0: !tt.ptr , %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -151,8 +151,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: buffer_load_store_vec2 tt.func @buffer_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr{tt.divisibility = 4 : i32}, %arg2: !tt.ptr{tt.divisibility = 4: i32}, %arg3: i32{tt.divisibility = 4: i32}) { %c256_i32 = arith.constant 256 : i32 diff --git a/test/Conversion/amd/builtin_func_to_llvm.mlir b/test/Conversion/amd/builtin_func_to_llvm.mlir new file mode 100644 index 000000000000..6458817302ab --- /dev/null +++ b/test/Conversion/amd/builtin_func_to_llvm.mlir @@ -0,0 +1,12 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", pure = true, symbol = "__triton_hip_fast_expf"} : (tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/amd/compute-base-ptr.mlir b/test/Conversion/amd/compute-base-ptr.mlir index e8376b1d8bf7..4c74e95d8ad0 100644 --- a/test/Conversion/amd/compute-base-ptr.mlir +++ b/test/Conversion/amd/compute-base-ptr.mlir @@ -1,18 +1,20 @@ -// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> +#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @local_load_offset tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { - %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> - %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) + %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> loc(#loc2) // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. - // CHECK: llvm.sub - // CHECK-NEXT: llvm.getelementptr - // CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 - %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0 + %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) tt.return } } +#loc1 = loc("conert_layout":1:0) +#loc2 = loc("local_alloc":2:0) +#loc3 = loc("local_load":3:0) diff --git a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir index f30e0aa6d98e..848e13118e85 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir @@ -1,33 +1,33 @@ // RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s -// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> -// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> // CHECK: large_tensor_conversion -#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> -#dst = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) { - // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> - // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> - %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> tt.return } } // ----- -// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> -// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> // CHECK: large_tensor_3d_conversion -#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> -#dst = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) { - // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> - // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> - %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> tt.return } } diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index 1bd288449f28..983d16e8d6bf 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,33 +1,35 @@ // RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s -// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot_op -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #smem> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } // ----- -// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot3d_op -#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> - %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #smem> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } @@ -35,13 +37,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.local_alloc - // CHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.local_alloc - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -49,13 +51,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.local_alloc - // CHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.local_alloc - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -63,13 +65,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -77,14 +79,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> tt.return } } @@ -92,14 +94,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: triton_gpu.local_alloc - // CHECK: triton_gpu.local_load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> tt.return } } diff --git a/test/Conversion/amd/dedup-by-constancy.mlir b/test/Conversion/amd/dedup-by-constancy.mlir index 8340cce6d151..66a224bcefb2 100644 --- a/test/Conversion/amd/dedup-by-constancy.mlir +++ b/test/Conversion/amd/dedup-by-constancy.mlir @@ -13,13 +13,13 @@ // only allows duplication within each group of 4 elemnets. Therefore, we expect 4 icmp, one // for each group of 4 elements. // In the future, we can reduce the icmp to 2 in such case. -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @dedup_by_constancy_mfma(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %1 = tt.splat %arg0 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %2 = arith.cmpi slt, %0, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi1, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %1 = tt.splat %arg0 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %2 = arith.cmpi slt, %0, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi1, #mma> %4 = tt.broadcast %3 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma> %cst = arith.constant dense<0.100000e+00> : tensor<32x32xf16, #mma> %5 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #mma> diff --git a/test/Conversion/amd/fp_to_fp.mlir b/test/Conversion/amd/fp_to_fp.mlir index aaa70564fd79..959158ab49e1 100644 --- a/test/Conversion/amd/fp_to_fp.mlir +++ b/test/Conversion/amd/fp_to_fp.mlir @@ -1,11 +1,11 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s // CHECK-LABEL: f16_to_f32 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { // CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}v_cvt_f32_f16 {{.*}}: (f16) -> f32 - %0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -13,11 +13,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: bf16_to_f32 -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>) { +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) { // CHECK-COUNT-8: llvm.bitcast - %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> tt.return } } diff --git a/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir new file mode 100644 index 000000000000..9730f9eace72 --- /dev/null +++ b/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -0,0 +1,111 @@ +// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + +// Invalid size +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero source dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{source tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero result dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, not multiple of shapePerTile +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, out of bounds for dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{invalid offset 128 at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result layout +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result layout must match source layout}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> + tt.return +} + +// ----- + +// Invalid result element type +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result element type must match source element type}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> + tt.return +} + +// ----- + +// Invalid result rank +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result rank must be equal to source rank}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result shape +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid rank +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{currently only 2D tensors are supported}} + %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid non static offset +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+2 {{expected ']'}} + // expected-error @+1 {{expected integer value}} + %2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} diff --git a/test/Conversion/amd/load_store.mlir b/test/Conversion/amd/load_store.mlir index 93796439b012..25336d25528f 100644 --- a/test/Conversion/amd/load_store.mlir +++ b/test/Conversion/amd/load_store.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Load 8 elements from A with two vectorized load instruction - // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32> + // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32> %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> // Load 8 elements from B with two vectorized load instruction - // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32> + // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32> %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> @@ -27,3 +27,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: global_store_mfma_vec16 + tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %1 = math.exp2 %0 : tensor<32x32xf32, #mma> + %2 = arith.truncf %1 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %c32_i32 = arith.constant 32 : i32 + %100 = tt.get_program_id x : i32 + %101 = arith.muli %100, %c32_i32 : i32 + %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma> + %200 = tt.broadcast %300 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma> + %103 = tt.splat %101 : i32 -> tensor<32x32xi32, #mma> + %104 = arith.addi %103, %200 : tensor<32x32xi32, #mma> + %105 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + %106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr, #mma>, tensor<32x32xi32, #mma> + // Store 16 elements with four vectorized store instruction + // CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr<1> + tt.store %106, %2 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir index 520f44db933d..86c08ca2ae2b 100644 --- a/test/Conversion/amd/math-denorm-handling.mlir +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -2,8 +2,8 @@ // RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { // LLVM_FTZ: llvm.amdgcn.exp2.f32 // LLVM_NO_FTZ: llvm.exp2.f32 @@ -14,8 +14,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { // LLVM_FTZ: llvm.exp2.f32 // LLVM_NO_FTZ: llvm.exp2.f32 diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 83c9e535d8c0..9a9764d992e3 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,27 +1,29 @@ // RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: shortcut_mfma16 tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } } // ----- -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: no_shortcut_mfma16 tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK: store // CHECK: load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } } diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 876dc0d76982..8f4fbee399b4 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.cond_br @@ -18,19 +18,194 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.cond_br // CHECK: llvm.atomicrmw // CHECK: llvm.atomicrmw - // CHECK: %[[ADDR1:.*]] = llvm.addrspacecast - // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]] - // CHECK: %[[ADDR2:.*]] = llvm.addrspacecast - // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]] + // CHECK: llvm.intr.masked.store + // CHECK: llvm.intr.masked.store %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> tt.store %arg0, %0 : tensor<256x!tt.ptr, #blocked0> tt.return } } + +// ----- + +// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16 +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +#dotop1 = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: small_mfma_tensor_conversions + tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr, #mfma>) { + // CHECK-NOT: ttg.convert_layout + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + // CHECK-4: store {{.*}} vector<4xf16> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop0> + // CHECK-2: load {{.*}} vector<4xf16> + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop1> + // CHECK-8: load {{.*}} vector<1xf16> + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #mfma> + // CHECK-4: load {{.*}} vector<4xf16> + %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma> + + %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma> + // Store result to prevent DCE from removing all conversion related code + %6 = ttg.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !ttg.memdesc<16x16xf32, #shared, #smem> + tt.return + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16x2 + tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16x2 + tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16_dpp + tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16_dpp + tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked3 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_dpp_max + tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 274, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 273, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 322, 10, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 323, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK: llvm.amdgcn.readlane + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64xf32, #blocked3>) -> f32 + tt.return + } +} + +// ----- + +#blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_xor_max + tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { + // CHECK: rocdl.ds_swizzle + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 12, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 264, 15, 3, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 10, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 260, 15, 5, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 78, 15, 15, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 177, 15, 15, false : i32 + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32xf32, #blocked4>) -> f32 + tt.return + } +} diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index 5eb856bb9952..68eb76afdb72 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -1,37 +1,39 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma1_dot_operand - tt.func @wmma1_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + tt.func @wmma1_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared, #smem>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma2_dot_operand - tt.func @wmma2_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + tt.func @wmma2_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared, #smem>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> tt.return } // CHECK-LABEL: wmma1_dot - tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { + tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<16xf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -39,7 +41,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: wmma1_dot_bf16 - tt.func @wmma1_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { + tt.func @wmma1_dot_bf16(%arg0: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> @@ -48,12 +50,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef : vector<16xbf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> tt.return } // CHECK-LABEL: wmma1_dot_int8_32 - tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> @@ -62,13 +64,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } // CHECK-LABEL: wmma1_dot_int4_32 - tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> @@ -77,13 +79,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } // CHECK-LABEL: wmma2_dot - tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { + tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> @@ -91,30 +93,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<8xf16> // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v8f16"{{.*}} : (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> tt.return } + + // CHECK-LABEL: blocked_to_wmma1 + tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1> + tt.return + } + + // CHECK-LABEL: slice_blocked_to_wmma1 + tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> + tt.return + } + + // CHECK-LABEL: wmma1_to_blocked + tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) { + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked> + tt.return + } + + // CHECK-LABEL: slice_wmma1_to_blocked + tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) { + // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return + } + + // CHECK-LABEL: blocked_to_wmma2 + tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2> + tt.return + } + + // CHECK-LABEL: slice_blocked_to_wmma2 + tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> + tt.return + } + + // CHECK-LABEL: wmma2_to_blocked + tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) { + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked> + tt.return + } + + // CHECK-LABEL: slice_wmma2_to_blocked + tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) { + // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return + } } // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> -#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_dot_operand3d - tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) { + tt.func @wmma_dot_operand3d(%arg0: !ttg.memdesc<4x16x32xf16, #shared, #smem>) { // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_dot3d - tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %arg0 // CHECK-COUNT-32: llvm.insertelement // CHECK-COUNT-32: llvm.extractvalue %arg1 @@ -122,7 +189,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-8: llvm.extractvalue %arg2 // CHECK-COUNT-8: llvm.insertelement // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement // CHECK-COUNT-8: llvm.insertvalue tt.return diff --git a/test/Conversion/dedup-by-constancy.mlir b/test/Conversion/dedup-by-constancy.mlir index 96131eae87f5..dc2cda84a763 100644 --- a/test/Conversion/dedup-by-constancy.mlir +++ b/test/Conversion/dedup-by-constancy.mlir @@ -10,8 +10,8 @@ // CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]] // CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]] // CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<256> : tensor<1024xi32, #blocked> %c1024_i32 = arith.constant 1024 : i32 @@ -48,8 +48,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]] // CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]] // CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]] -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<4> : tensor<1024xi32, #blocked> %c1024_i32 = arith.constant 1024 : i32 diff --git a/test/Conversion/divide-by-0.mlir b/test/Conversion/divide-by-0.mlir index 8f920fcc05f0..f12fd1bc78ca 100644 --- a/test/Conversion/divide-by-0.mlir +++ b/test/Conversion/divide-by-0.mlir @@ -3,12 +3,12 @@ // CHECK-LABEL: dont_divide_0 // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NOT: llvm.urem %{{.*}}, %[[C0]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @dont_divide_0() attributes {noinline = false} { %zero = arith.constant dense<0.000000e+00> : tensor<16x1xf32, #mma> - %cvt = triton_gpu.convert_layout %zero : tensor<16x1xf32, #mma> -> tensor<16x1xf32, #blocked> + %cvt = ttg.convert_layout %zero : tensor<16x1xf32, #mma> -> tensor<16x1xf32, #blocked> tt.return } } diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 2c9e8beb4199..28c6b6ba7e8f 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=cuda:80 num-warps=2' | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @ops() { - // CHECK: module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {{.*}} + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> @@ -13,7 +13,7 @@ tt.func @ops() { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if LoadOp is lowered properly (see #771) %ptrs = tt.splat %ptr : !tt.ptr -> tensor<128x!tt.ptr> @@ -34,36 +34,36 @@ tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if the total number of threadsPerWarp is 32 // Test if the total number of warps is 2 - // CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> - // CHECK: module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {{.*}} + // CHECK: #[[blocked0:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked1:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked2:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32> %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32> %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32> - // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>> + // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #ttg.slice<{dim = 0, parent = #[[blocked0]]}>> %c0_ = "tt.reduce" (%c0) ({ ^bb0(%arg1: f32, %arg2: f32): %add = arith.addf %arg1, %arg2 : f32 tt.reduce.return %add : f32 }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32> - // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #ttg.slice<{dim = 0, parent = #[[blocked1]]}> %c1_ = "tt.reduce" (%c1) ({ ^bb0(%arg3: f32, %arg4: f32): %add = arith.addf %arg3, %arg4 : f32 tt.reduce.return %add : f32 }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32> - // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #ttg.slice<{dim = 1, parent = #[[blocked1]]}>> %c2_ = "tt.reduce" (%c1) ({ ^bb0(%arg5: f32, %arg6: f32): %add = arith.addf %arg5, %arg6 : f32 tt.reduce.return %add : f32 }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32> - // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>> + // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[blocked2]]}>> %c3_ = "tt.reduce" (%c2) ({ ^bb0(%arg7: f32, %arg8: f32): %add = arith.addf %arg7, %arg8 : f32 @@ -77,7 +77,7 @@ tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i1) attributes {noinline = false} { // CHECK-LABEL: select_op %cst = arith.constant dense<0.000000e+00> : tensor<128xf32> @@ -95,3 +95,27 @@ tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg tt.return } } + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func @arith_splat_bool(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-LABEL: arith_splat_bool + + // Test arith.constant with splatted bool. + // CHECK-NEXT: arith.constant dense : tensor<128xi1, #{{.*}}> + %mask = arith.constant dense : tensor<128xi1> + tt.return +} +} + +// ----- + +// CHECK-LABEL: gather_op +tt.func @gather_op() { + %cst = arith.constant dense<1.0> : tensor<128x4xf32> + %cst_0 = arith.constant dense<1> : tensor<256x4xi32> + // CHECK: tt.gather %{{.*}}[%{{.*}}] {axis = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> + %0 = tt.gather %cst[%cst_0] {axis = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> + tt.return +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e1a2ec68bd5a..a97ac476cbad 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) // Here the 128 comes from the 4 in module attribute multiples 32 // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { @@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -27,8 +27,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: vectorized_load tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -42,8 +42,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: vectorized_load_f16 tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { // CHECK: llvm.inline_asm @@ -58,8 +58,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: masked_load_const_other tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -71,8 +71,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // TODO: masked load with vectorization is pending on TODO -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: masked_load_const_other_vec tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> @@ -83,8 +83,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_with_cache_attr tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -98,8 +98,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_no_vec tt.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -150,8 +150,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_vec4 tt.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -187,9 +187,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // This test verifies the vectorization of Load and Store Ops. -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1. -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id x : i32 @@ -217,8 +217,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -262,8 +262,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -307,8 +307,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 @@ -349,9 +349,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_view_broadcast tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { // CHECK: llvm.mlir.undef @@ -374,8 +374,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: basic_make_range tt.func @basic_make_range() { // CHECK: nvvm.read.ptx.sreg.tid.x @@ -389,8 +389,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addf tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { // CHECK: llvm.fadd @@ -402,8 +402,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addi tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.add @@ -415,7 +415,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_program_id tt.func @basic_program_id() { // CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32 @@ -426,8 +426,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addptr tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr @@ -439,23 +439,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor tt.func @basic_alloc_tensor() { // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.getelementptr // CHECK-NEXT: llvm.mlir.constant - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #smem, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_subview tt.func @basic_subview() { @@ -477,30 +479,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> + %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable> tt.return } } // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_async_wait tt.func @basic_async_wait() { // CHECK: cp.async.wait_group 0x4 - triton_gpu.async_wait {num = 4: i32} + ttg.async_wait {num = 4: i32} tt.return } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}> -#shared1D = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> -#shared2D = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#slice1d0 = #ttg.slice<{dim = 0, parent = #blocked1}> +#shared1D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> +#shared2D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: basic_insert_slice_async_1d tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %c0_i32 = arith.constant 0 : i32 @@ -509,10 +512,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> - %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> - %subview = triton_gpu.memdesc_subview %71[%c0_i32, %c0_i32] : - !tt.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> -> - !tt.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> + %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> + %subview = ttg.memdesc_subview %71[%c0_i32, %c0_i32] : + !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> -> + !ttg.memdesc<64xi64, #shared1D, #smem, mutable> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 @@ -523,23 +526,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.commit_group - %73 = triton_gpu.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group %73 + %73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #smem, mutable> + ttg.async_commit_group %73 tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -551,35 +555,36 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<16x64xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x64xi32, #block3> -> tensor<16x64xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<16x64xi32, #block2> -> tensor<16x64xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : tensor<16x64xi32, #block3> -> tensor<16x64xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x64xi32, #block2> -> tensor<16x64xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x64xi32, #block3> -> tensor<16x64xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> + ttg.async_commit_group tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> @@ -591,12 +596,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<16x32xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<16x32xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<16x32xi32, #block2> -> tensor<16x32xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : tensor<16x32xi32, #block3> -> tensor<16x32xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x32xi32, #block2> -> tensor<16x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x32xi32, #block3> -> tensor<16x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm @@ -609,22 +614,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> + ttg.async_commit_group tt.return } } // ----- -#block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> -#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> @@ -636,12 +642,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = tt.splat %cst_scalar : i32 -> tensor<32x32xi32, #block2> %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2> %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<32x32xi32, #block3> - %broadcast_off0 = triton_gpu.convert_layout %broadcast_off0_ : tensor<32x32xi32, #block2> -> tensor<32x32xi32, #AL> - %broadcast_off1 = triton_gpu.convert_layout %broadcast_off1_ : tensor<32x32xi32, #block3> -> tensor<32x32xi32, #AL> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<32x32xi32, #block2> -> tensor<32x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<32x32xi32, #block3> -> tensor<32x32xi32, #AL> %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.mlir.constant(0 : i32) : i32 @@ -665,16 +671,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> - triton_gpu.async_commit_group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> + ttg.async_commit_group tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: basic_splat tt.func @basic_splat(%ptr: !tt.ptr) { // CHECK: llvm.mlir.undef @@ -687,8 +693,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store tt.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm @@ -702,9 +708,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { @@ -712,16 +718,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared // CHECK-: nvvm.barrier0 // CHECK-COUNT-8: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked_vec tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { @@ -733,16 +739,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.load // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { @@ -758,29 +764,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.load // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { - %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> - %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> // CHECK: llvm.inline_asm // CHECK: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_b> + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -794,19 +801,48 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : } // TODO: problems in MLIR's parser on slice layout -// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -// module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +// #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // tt.func @make_range_sliced_layout() { -// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> +// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked0}>> // tt.return // } // } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot_fp8 + tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf8E5M2, #dot_operand_a> * tensor<16x16xf8E5M2, #dot_operand_b> -> tensor<16x16xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav2_block tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { @@ -816,55 +852,210 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> + %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK: llvm.mlir.global external @global_smem - // CHECK-LABEL: convert_layout_mmav1_block - tt.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #ttg.slice<{dim = 0, parent = #mma}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg + tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav3_transpose tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { // CHECK-COUNT-128: st.shared.b8 // CHECK: nvvm.barrier0 // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32> - %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> + %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_shared tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { @@ -872,42 +1063,42 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.store // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { // CHECK: llvm.load {{.*}} -> vector<4xi32> - %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-8: llvm.load {{.*}} -> i32 - %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked_to_blocked_ptr tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { // CHECK: llvm.ptrtoint @@ -915,28 +1106,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.barrier0 // CHECK: llvm.inttoptr // CHECK-COUNT-4: llvm.insertvalue - %cvt = triton_gpu.convert_layout %src : tensor<32x!tt.ptr, #blocked0> -> tensor<32x!tt.ptr, #blocked1> + %cvt = ttg.convert_layout %src : tensor<32x!tt.ptr, #blocked0> -> tensor<32x!tt.ptr, #blocked1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x32xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x256xf16, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> @@ -947,42 +1139,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory>, %b:!tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory>) { - %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> - // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory> -> tensor<32x64xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf16, #dot_operand_b> - - %28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked> - %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> - %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> - tt.store %36, %38 : tensor<32x64x!tt.ptr, #blocked> - tt.return - } -} - -// ----- - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> // CHECK: llvm.intr.fmuladd - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -994,15 +1162,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32dot tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -1010,8 +1179,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 // CHECK-SAME: (i32, i32, i32, i32) - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 @@ -1022,7 +1191,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> @@ -1033,8 +1202,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1048,7 +1217,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1061,8 +1230,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1076,8 +1245,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_nomask // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 @@ -1089,8 +1258,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_withmask // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 @@ -1104,8 +1273,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 tt.func @store_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1119,7 +1288,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32_scalar tt.func @store_f32_scalar(%arg0 : !tt.ptr, %arg1 : f32) { // CHECK: llvm.icmp "eq" @@ -1132,8 +1301,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { %blockidx = tt.get_program_id x: i32 @@ -1154,8 +1323,8 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { %blockidx = tt.get_program_id x: i32 @@ -1176,8 +1345,8 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_num_program tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs x : i32 @@ -1197,8 +1366,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs x : i32 %blockdimy = tt.get_num_programs y : i32 @@ -1216,8 +1385,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_index_cache tt.func @test_index_cache() { // CHECK: nvvm.read.ptx.sreg.tid.x @@ -1228,29 +1397,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> cf.cond_br %arg1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 tt.return @@ -1259,22 +1430,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32_cst_b tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { // CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to i32 - // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to f32 + // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a, %b_mat, %c, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> - %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> tt.store %36, %38 : tensor<32x32x!tt.ptr, #blocked> @@ -1284,34 +1455,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_f16_cst_operands tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK: %[[C1f:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16 - // CHECK: %[[Ci16:.+]] = llvm.bitcast %[[C1f]] : f16 to i16 - // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xi16> + // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xf16> // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[V0:.+]] = llvm.insertelement %[[Ci16]], %[[U]][%[[C0]] : i32] : vector<2xi16> + // CHECK: %[[V0:.+]] = llvm.insertelement %{{.*}}, %[[U]][%[[C0]] : i32] : vector<2xf16> // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[V1:.+]] = llvm.insertelement %[[Ci16]], %[[V0]][%[[C1]] : i32] : vector<2xi16> - // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xi16> to i32 - // CHECK: %[[SU:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SU]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: %[[V1:.+]] = llvm.insertelement %{{.*}}, %[[V0]][%[[C1]] : i32] : vector<2xf16> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xf16> to i32 + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %4 = arith.muli %3, %cst_2 : tensor<32x1xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %11 = tt.addptr %9, %10 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> @@ -1322,8 +1489,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_conversion tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) { // We can't vectorize if we only process @@ -1336,9 +1503,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_s8_to_bf16_vectorized_conversion tt.func @test_s8_to_bf16_vectorized_conversion(%in: tensor<16x16xi8, #mma>) { // CHECK-NOT: llvm.sitofp @@ -1360,19 +1527,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: nvvm.shfl.sync bfly // CHECK: nvvm.shfl.sync bfly // CHECK: nvvm.barrier0 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sum_reduction(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<1024> : tensor<1x1xi32, #blocked> %0 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> - %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked> + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked> %3 = arith.muli %2, %cst : tensor<1x1xi32, #blocked> %4 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> - %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked> + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked> %8 = tt.broadcast %5 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %7 : tensor<1x1024x!tt.ptr, #blocked>, tensor<1x1024xi32, #blocked> %10 = tt.load %9 : tensor<1x1024x!tt.ptr, #blocked> @@ -1380,8 +1547,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg2: i32, %arg3: i32): %15 = arith.addi %arg2, %arg3 : i32 tt.reduce.return %15 : i32 - }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %12 = triton_gpu.convert_layout %11 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = ttg.convert_layout %11 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr, #blocked1> %14 = tt.addptr %13, %0 : tensor<1x!tt.ptr, #blocked1>, tensor<1xi32, #blocked1> tt.store %14, %12 : tensor<1x!tt.ptr, #blocked1> @@ -1390,9 +1557,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { // CHECK-LABEL: reduce_bools tt.func public @reduce_bools(%arg: tensor<256x2xi1, #blocked>) { // CHECK: llvm.mlir.addressof @global_smem @@ -1408,8 +1575,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: inline_asm tt.func public @inline_asm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> @@ -1427,8 +1594,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: inline_asm_pack_16bit tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> @@ -1449,16 +1616,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: reduce_slice // CHECK-NOT: st.shared // CHECK-NOT: ld.shared -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> -#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#sliced2 = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @reduce_slice() attributes {noinline = false} { %cst = arith.constant dense : tensor<4x1xi1, #sliced2> %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ ^bb0(%arg0: i1, %arg1: i1): %1 = arith.ori %arg0, %arg1 : i1 tt.reduce.return %1 : i1 - }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>> + }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #ttg.slice<{dim = 1, parent = #sliced2}>> tt.return } } @@ -1470,56 +1637,42 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: st.shared // CHECK: ld.shared // CHECK: st.shared -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 2], order = [2, 1, 0]}> -#sliced = #triton_gpu.slice<{dim = 2, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 2], order = [2, 1, 0]}> +#sliced = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func public @reduce_md_slice(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<2x128xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<2x128xf32, #ttg.slice<{dim = 2, parent = #blocked}>> %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %18 = arith.maxnumf %arg1, %arg2 : f32 tt.reduce.return %18 : f32 - }) {allocation.offset = 0 : i32} : (tensor<2x128xf32, #sliced>) -> tensor<2xf32, #triton_gpu.slice<{dim = 1, parent = #sliced}>> - tt.return - } -} - -// ----- - -// CHECK-LABEL: volta_dot -#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16]}> -module attributes {"triton_gpu.target" = "cuda:70", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - tt.func @volta_dot() { - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %a = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %b = arith.constant dense<0.000000e+00> : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> - - %87 = tt.dot %a, %b, %cst : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<32x32xf32, #mma> + }) {allocation.offset = 0 : i32} : (tensor<2x128xf32, #sliced>) -> tensor<2xf32, #ttg.slice<{dim = 1, parent = #sliced}>> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @i16_mma_layout(%f16_inp: tensor<16x16xf16, #blocked0>, %i16_inp: tensor<16x16xi16, #blocked0>) { // CHECK-LABEL: @i16_mma_layout - %f16_shared = triton_gpu.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> - %i16_shared = triton_gpu.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !tt.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> + %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %f16_dot = triton_gpu.local_load %f16_shared : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %i16_dot = triton_gpu.local_load %i16_shared : !tt.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xi16, #dot_operand_b> + %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b> // CHECK: llvm.sitofp %{{.*}} : i16 to f16 @@ -1539,26 +1692,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: convert_single_element // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load - // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue + // CHECK: llvm.return tt.func public @convert_single_element() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> - %0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: convert_single_element_and_add // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load @@ -1567,7 +1719,7 @@ module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : tt.func public @convert_single_element_and_add() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> %cst2 = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked> - %0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> %1 = arith.addf %0, %cst2 : tensor<1xf32, #blocked> tt.return } @@ -1575,38 +1727,40 @@ module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_load // CHECK: llvm.load // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> // CHECK-NOT: llvm.load - tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory>) { - %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory> -> tensor<16x16xi8, #blocked> + tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #smem>) { + %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #smem> -> tensor<16x16xi8, #blocked> tt.return } } // ----- -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_store // CHECK: llvm.store // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> // CHECK-NOT: llvm.store tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { - %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #smem> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: abs_is_int_min_poison // CHECK: %{{.*}} = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32 tt.func @abs_is_int_min_poison(%arg0 : tensor<256xi32, #blocked0>) { @@ -1616,54 +1770,57 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_load_bf16 // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 - %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> - %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> - %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> + %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> + %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> + %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store // CHECK: llvm.store tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store_subview // CHECK: llvm.store tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> tt.return } } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: print_ptr // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr, #blocked0>) { @@ -1673,8 +1830,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // Test that %u format specifier is used if isSigned is false // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %u{{.*}}") // CHECK-LABEL: print_int32_tensor_issigned_off @@ -1686,8 +1843,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // Test that %i format specifier is used if isSigned is true // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %i{{.*}}") // CHECK-LABEL: print_int32_tensor_issigned_on @@ -1700,8 +1857,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @int32_to_bf16(%arg0: tensor<256xi32, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @int32_to_bf16 // CHECK: llvm.sitofp %{{.*}} : i32 to bf16 @@ -1712,8 +1869,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @bf16_to_int32(%arg0: tensor<256xbf16, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @bf16_to_int32 // CHECK: llvm.fptosi %{{.*}} : bf16 to i32 @@ -1724,11 +1881,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32} -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: llvm.call @__assertfail +// CHECK: nvvm.barrier0 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) { tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5) tt.return @@ -1739,3 +1898,149 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #loc3 = loc("inner_call":29:28) #loc4 = loc(callsite(#loc3 at #loc1)) #loc5 = loc(callsite(#loc4 at #loc2)) + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} { + // CHECK: log1pf_scan + // non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable. + // CHECK-NOT: llvm.cond_br + %40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32 + %44 = arith.addf %43, %43 : f32 + tt.scan.return %44 : f32 + }) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked> + tt.return + } +} + +// ----- + +// CHECK: inline_asm_pack +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 + tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { + // CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)> + %83:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %80 : tensor<64x64xi8, #blocked> -> tensor<64x64xbf16, #blocked>, tensor<64x64xbf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { + // CHECK-LABEL: gather_in_shared + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x4xi32, #blocked1>) -> tensor<16x4xf32, #blocked1> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}> +#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { + // CHECK-LABEL: gather_in_shared_dot_input + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + // CHECK: [[S1:%.*]] = llvm.extractvalue %arg1[1] + // CHECK: [[S2:%.*]] = llvm.extractvalue %arg1[2] + // CHECK: [[S3:%.*]] = llvm.extractvalue %arg1[3] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK: store [[S1]] + // CHECK: store [[S2]] + // CHECK: store [[S3]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x4xi32, #blocked>) -> tensor<16x4xf32, #blocked> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + + tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return +} + +} + +// ----- + +#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) { + // CHECK-LABEL: upcast_mxfp + // CHECK-COUNT-4: llvm.inline_asm + // CHECK-COUNT-2: nvvm.shfl.sync + // CHECK-COUNT-32: llvm.fmul + %0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + tt.return +} + +} diff --git a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir index 49128064a83e..f45143678ce2 100644 --- a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir +++ b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir @@ -1,10 +1,10 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s // CHECK-LABEL: blocked_to_dot_op_shortcut_warp32 -#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } @@ -13,10 +13,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot_op_shortcut_warp64 -#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } @@ -25,10 +25,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32 -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } @@ -37,10 +37,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> // CHECK-NOT: load tt.return } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 83653d57b65e..56e078463d20 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,11 +1,12 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_high_precision_acc - tt.func @dot_high_precision_acc(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_high_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd // CHECK: nvgpu.wgmma @@ -14,21 +15,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-128: llvm.fadd // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c + %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} : - !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_low_precision_acc - tt.func @dot_low_precision_acc(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_low_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: nvgpu.wgmma @@ -38,21 +40,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: llvm.return - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c + %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} : - !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_mix_precision_acc - tt.func @dot_mix_precision_acc(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_mix_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: nvgpu.wgmma @@ -62,71 +65,90 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd // CHECK: llvm.return - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c + %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : - !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_zero_acc // Generate a wgmma with 2 sources. // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { - tt.func @dot_zero_acc(%a: !tt.memdesc<128x64xf16, #shared>, %b: !tt.memdesc<64x64xf16, #shared1>) { + tt.func @dot_zero_acc(%a: !ttg.memdesc<128x64xf16, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared1, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : - !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : + !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %opA = ttg.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %m = ttng.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: + tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !ttg.memdesc<128x256xf8E5M2, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + %m = ttng.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : + tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<128x256xf8E5M2, #shared, #smem> -> tensor<128x256xf32, #mma1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: dot_reg_operand_upcast + tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func @test_fp8_to_f16_conversion( %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, @@ -153,9 +175,9 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // CHECK-LABEL: clamp -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> @@ -168,23 +190,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 16]}> // CHECK-LABEL: convert_mma_to_blocked -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) { // CHECK-COUNT-16: nvgpu.stmatrix // CHECK: nvvm.barrier0 - %c = triton_gpu.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> + %c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: cvt_mma_to_dot_fp8 // CHECK: prmt.b32 // CHECK: prmt.b32 @@ -193,23 +215,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: prmt.b32 // CHECK: prmt.b32 tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { - %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = ttg.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_zero_acc_operand // CHECK-COUNT-128: llvm.fadd - tt.func @dot_zero_acc_operand(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x128xf8E5M2, #shared1>) { + tt.func @dot_zero_acc_operand(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x128xf8E5M2, #shared1, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : - !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x128xf8E5M2, #shared1, #smem> -> tensor<128x128xf32, #mma> tt.return } } @@ -217,22 +240,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory // CHECK-LABEL: distribute_to_shared_st_matrix -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) { // CHECK-COUNT-16: nvgpu.stmatrix // CHECK: llvm.return - %b = triton_gpu.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !tt.memdesc<128x128xf16, #shared, mutable> + %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { // CHECK-LABEL: @fp8_const // CHECK: llvm.mlir.constant(0.000000e+00 : f8E4M3FNUZ) : i8 @@ -244,8 +268,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_nomask // CHECK: atom.global.gpu.acq_rel.add.v4.f32 @@ -256,8 +280,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f32_withmask // CHECK: atom.global.gpu.acq_rel.add.v2.f32 @@ -269,8 +293,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { // CHECK-LABEL: atomic_add_f16_withmask // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 @@ -279,3 +303,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: test_fp8_to_fp16_dot_operand + // CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2 + tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) { + %r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir new file mode 100644 index 000000000000..1003f321d6f3 --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir @@ -0,0 +1,44 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_volta.mlir b/test/Conversion/tritongpu_to_llvm_volta.mlir index 26010b88bd78..a5a428129416 100644 --- a/test/Conversion/tritongpu_to_llvm_volta.mlir +++ b/test/Conversion/tritongpu_to_llvm_volta.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=70 2>&1 | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> // CHECK-LABEL: clamp -module attributes {"triton_gpu.target" = "cuda:70", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 0bcab369f79f..127c4951e383 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,25 +1,27 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier - tt.func @init_barrier(%alloc: !tt.memdesc<1xi64, #shared0>) { + tt.func @init_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) { // CHECK: "@$0 mbarrier.init.shared::cta.b64 [$1], 1;", "b,r" %{{.*}}, %{{.*}} : (i1, !llvm.ptr<3>) -> !llvm.void - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier - tt.func @wait_barrier(%alloc: !tt.memdesc<1xi64, #shared0>, %phase: i32) { + tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %phase: i32) { // CHECK: waitLoop: // CHECK: mbarrier.try_wait.parity.shared.b64 // CHECK: @!P1 bra.uni waitLoop - triton_nvidia_gpu.wait_barrier %alloc, %phase : !tt.memdesc<1xi64, #shared0> + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0, #smem> tt.return } } @@ -27,62 +29,65 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_global_to_local // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.shared // CHECK: return - tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !tt.memdesc<128x128xf32, #shared1, mutable>, %x: i32, %barrier: !tt.memdesc<1xi64, #shared0>, %pred: i1) { - triton_nvidia_gpu.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !tt.memdesc<1xi64, #shared0> -> !tt.memdesc<128x128xf32, #shared1, mutable> + tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> tt.return } } // ----- -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_local_to_global // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.global.shared::cta.bulk_group // CHECK: cp.async.bulk.commit_group - tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !tt.memdesc<128x128xf32, #shared1>, %x: i32) { - triton_nvidia_gpu.async_tma_copy_local_to_global %tma[%x, %x] %alloc : , <128x128xf32, #shared1> + tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.ptr, !ttg.memdesc<128x128xf32, #shared1, #smem> tt.return } } // ----- -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: async_tma_store_wait // CHECK: "cp.async.bulk.wait_group.read 0x0;", "" : () -> !llvm.void tt.func @async_tma_store_wait() { - triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} + ttng.async_tma_store_wait {pendings = 0 : i32} tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: expect_barrier // CHECK: @$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], 16384; - tt.func @expect_barrier(%barrier: !tt.memdesc<1xi64, #shared0, mutable>, %pred: i1) { - triton_nvidia_gpu.barrier_expect %barrier, 16384, %pred : <1xi64, #shared0, mutable> + tt.func @expect_barrier(%barrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>, %pred: i1) { + ttng.barrier_expect %barrier, 16384, %pred : <1xi64, #shared0, #smem, mutable> tt.return } } // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: byval_tma_desc // CHECK: llvm.align = 64 // CHECK: llvm.byval = !llvm.array<128 x i8> @@ -95,7 +100,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // CHECK-LABEL: device_tensormap_create1d -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @device_tensormap_create1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c256_i32 = arith.constant 256 : i32 %c1_i32 = arith.constant 1 : i32 @@ -120,7 +125,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: device_tensormap_create2d -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @device_tensormap_create2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c256_i32 = arith.constant 256 : i32 %c1_i32 = arith.constant 1 : i32 @@ -150,7 +155,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: tensormap_fenceproxy_acquire -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @tensormap_fenceproxy_acquire(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { // CHECK: fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80; tt.experimental_tensormap_fenceproxy_acquire %arg0 : !tt.ptr diff --git a/test/Proton/ops.mlir b/test/Proton/ops.mlir new file mode 100644 index 000000000000..22a17e3f0f58 --- /dev/null +++ b/test/Proton/ops.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s + +module { + // CHECK-LABEL: proton_record + tt.func @proton_record() { + // CHECK: proton.record() {isStart = true, regionId = 1 : i32} + // CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32} + // CHECK-NEXT: tt.return + proton.record() {isStart = true, regionId = 1 : i32} + proton.record() {isStart = false, regionId = 1 : i32} + tt.return + } +} // end module + +// ----- diff --git a/test/Tools/tensor_layout_print.mlir b/test/Tools/tensor_layout_print.mlir index 80c01959341d..9f802d2e3bfe 100644 --- a/test/Tools/tensor_layout_print.mlir +++ b/test/Tools/tensor_layout_print.mlir @@ -2,19 +2,19 @@ // RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA -// RUN: triton-tensor-layout -l "#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA +// RUN: triton-tensor-layout -l "#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA // RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" -use-hw-view | FileCheck %s --check-prefix=CHECK-HW -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> tt.func @print(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked> %cst1 = arith.constant dense<0.00e+00> : tensor<16x16xf16, #mfma> tt.return } -// CHECK-BLOCKED: Print layout attribute: #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-BLOCKED: Print layout attribute: #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> // CHECK-BLOCKED: T0:0| T4:0, T0:1| T4:1, T0:2| T4:2, T0:3| T4:3, T1:0| T5:0, T1:1| T5:1, T1:2| T5:2, T1:3| T5:3, T2:0| T6:0, T2:1| T6:1, T2:2| T6:2, T2:3| T6:3, T3:0| T7:0, T3:1| T7:1, T3:2| T7:2, T3:3| T7:3 // CHECK-BLOCKED: T8:0| T12:0, T8:1| T12:1, T8:2| T12:2, T8:3| T12:3, T9:0| T13:0, T9:1| T13:1, T9:2| T13:2, T9:3| T13:3, T10:0| T14:0, T10:1| T14:1, T10:2| T14:2, T10:3| T14:3, T11:0| T15:0, T11:1| T15:1, T11:2| T15:2, T11:3| T15:3 // CHECK-BLOCKED: T16:0| T20:0, T16:1| T20:1, T16:2| T20:2, T16:3| T20:3, T17:0| T21:0, T17:1| T21:1, T17:2| T21:2, T17:3| T21:3, T18:0| T22:0, T18:1| T22:1, T18:2| T22:2, T18:3| T22:3, T19:0| T23:0, T19:1| T23:1, T19:2| T23:2, T19:3| T23:3 @@ -33,7 +33,7 @@ tt.func @print(%A : !tt.ptr) { // CHECK-BLOCKED: T120:0|T124:0, T120:1|T124:1, T120:2|T124:2, T120:3|T124:3, T121:0|T125:0, T121:1|T125:1, T121:2|T125:2, T121:3|T125:3, T122:0|T126:0, T122:1|T126:1, T122:2|T126:2, T122:3|T126:3, T123:0|T127:0, T123:1|T127:1, T123:2|T127:2, T123:3|T127:3 -// CHECK-MFMA: Print layout attribute: {{.*}}#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// CHECK-MFMA: Print layout attribute: {{.*}}#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> // CHECK-MFMA: T0:0| T64:0|T128:0|T192:0, T0:1| T64:1|T128:1|T192:1, T0:2| T64:2|T128:2|T192:2, T0:3| T64:3|T128:3|T192:3, T16:0| T80:0|T144:0|T208:0, T16:1| T80:1|T144:1|T208:1, T16:2| T80:2|T144:2|T208:2, T16:3| T80:3|T144:3|T208:3, T32:0| T96:0|T160:0|T224:0, T32:1| T96:1|T160:1|T224:1, T32:2| T96:2|T160:2|T224:2, T32:3| T96:3|T160:3|T224:3, T48:0|T112:0|T176:0|T240:0, T48:1|T112:1|T176:1|T240:1, T48:2|T112:2|T176:2|T240:2, T48:3|T112:3|T176:3|T240:3 // CHECK-MFMA: T1:0| T65:0|T129:0|T193:0, T1:1| T65:1|T129:1|T193:1, T1:2| T65:2|T129:2|T193:2, T1:3| T65:3|T129:3|T193:3, T17:0| T81:0|T145:0|T209:0, T17:1| T81:1|T145:1|T209:1, T17:2| T81:2|T145:2|T209:2, T17:3| T81:3|T145:3|T209:3, T33:0| T97:0|T161:0|T225:0, T33:1| T97:1|T161:1|T225:1, T33:2| T97:2|T161:2|T225:2, T33:3| T97:3|T161:3|T225:3, T49:0|T113:0|T177:0|T241:0, T49:1|T113:1|T177:1|T241:1, T49:2|T113:2|T177:2|T241:2, T49:3|T113:3|T177:3|T241:3 // CHECK-MFMA: T2:0| T66:0|T130:0|T194:0, T2:1| T66:1|T130:1|T194:1, T2:2| T66:2|T130:2|T194:2, T2:3| T66:3|T130:3|T194:3, T18:0| T82:0|T146:0|T210:0, T18:1| T82:1|T146:1|T210:1, T18:2| T82:2|T146:2|T210:2, T18:3| T82:3|T146:3|T210:3, T34:0| T98:0|T162:0|T226:0, T34:1| T98:1|T162:1|T226:1, T34:2| T98:2|T162:2|T226:2, T34:3| T98:3|T162:3|T226:3, T50:0|T114:0|T178:0|T242:0, T50:1|T114:1|T178:1|T242:1, T50:2|T114:2|T178:2|T242:2, T50:3|T114:3|T178:3|T242:3 diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index 8888271e3c2b..ef448d500e68 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -11,6 +11,8 @@ tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr>) { tt.return } +// ----- + // CHECK-LABEL: make_range tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { // CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32> @@ -25,6 +27,32 @@ tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32> } +// ----- + +// CHECK-LABEL: fold_addptr +tt.func @fold_addptr(%arg: tensor<64x64x!tt.ptr>) -> (tensor<64x64x!tt.ptr>) { + // CHECK-NOT: tt.addptr + // CHECK-NOT: arith.constant + // CHECK: tt.return %arg + %c0_i32 = arith.constant dense<0> : tensor<64x64xi32> + %0 = tt.addptr %arg, %c0_i32 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + tt.return %0 : tensor<64x64x!tt.ptr> +} + +// ----- + +// CHECK-LABEL: fold_addptr_scalar +tt.func @fold_addptr_scalar(%arg: !tt.ptr) -> (!tt.ptr) { + // CHECK-NOT: tt.addptr + // CHECK-NOT: arith.constant + // CHECK: tt.return %arg + %c0_i32 = arith.constant 0 : i32 + %0 = tt.addptr %arg, %c0_i32 : !tt.ptr, i32 + tt.return %0 : !tt.ptr +} + +// ----- + // CHECK-LABEL: fold_advance tt.func @fold_advance(%arg: !tt.ptr>) -> (!tt.ptr>) { %c0_i32 = arith.constant 0 : i32 @@ -34,14 +62,13 @@ tt.func @fold_advance(%arg: !tt.ptr>) -> (!tt.ptr> } - // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#sliced0 = #triton_gpu.slice<{dim = 1, parent = #blocked0}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#sliced0 = #ttg.slice<{dim = 1, parent = #blocked0}> // CHECK-LABEL: fn -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ // CHECK: %[[a:.*]] = tt.expand_dims // CHECK: tt.broadcast %[[a]] @@ -50,3 +77,110 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ tt.return %b : tensor<32x1xf32, #blocked0> } } // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_pos_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fp_to_fp_pos_zero_fold + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_pos_zero_fold_scalar() -> f8E4M3FNUZ { + // CHECK-LABEL: fp_to_fp_pos_zero_fold_scalar + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant 0.000000e+00 : f8E4M3FNUZ + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant 0.00e+00 : f32 + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : f32 -> f8E4M3FNUZ + tt.return %cst_converted : f8E4M3FNUZ + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> { + // CHECK-LABEL: fp_to_fp_neg_zero_fold + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<-0.000000e+00> : tensor<32x128xf8E4M3FN, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FN, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FN, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fp_to_fp_neg_zero_fold + // We fold to the positive zero here given by definition f8E4M3FNUZ does not have negative zero encoding. + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fold_fp_to_fp_non_zero_nofold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fold_fp_to_fp_non_zero_nofold + // CHECK-NEXT: %[[cst:.+]] = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked> + // CHECK-NEXT: %[[cst_cvt:.+]] = tt.fp_to_fp %[[cst]] + // CHECK-NEXT: tt.return %[[cst_cvt]] + %cst = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fold_fp_to_fp_non_constant_nofold(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fold_fp_to_fp_non_constant_nofold + // CHECK-NEXT: %[[arg_cvt:.+]] = tt.fp_to_fp %arg0 + // CHECK-NEXT: tt.return %[[arg_cvt]] + %cst_converted = tt.fp_to_fp %arg0, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +// CHECK-LABEL: @fold_broadcast_constant_pattern +tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8x1xf32> + %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> + + // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> + tt.return %bst_out : tensor<8x2xf32> +} + +// ----- + +// CHECK-LABEL: @fold_transpose_constant +tt.func @fold_transpose_constant() -> tensor<128x16xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<128x16xf32> + %cst = arith.constant dense<1.0> : tensor<16x128xf32> + %r = tt.trans %cst {order = array} : tensor<16x128xf32> -> tensor<128x16xf32> + // CHECK-NEXT: tt.return %[[cst]] : tensor<128x16xf32> + tt.return %r : tensor<128x16xf32> +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 41a3ba15a8ee..ecaa60e53c7d 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -208,16 +208,6 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr, tensor<8xf32>, tensor<8xf32> } -// CHECK-LABEL: @test_combine_broadcast_constant_pattern -tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { - // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> - %const = arith.constant dense<1.0> : tensor<8x1xf32> - %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> - - // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> - tt.return %bst_out : tensor<8x2xf32> -} - // CHECK-LABEL: @test_canonicalize_masked_load_pattern tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { %true_mask = arith.constant dense : tensor<8xi1> diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c7df02322476..3e130c29031d 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -157,8 +157,8 @@ tt.func public @fn(%v1: tensor<4x128xf32>, %v2: tensor<4x128xi64>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // expected-error @+2 {{op failed to infer returned types}} // expected-error @+1 {{incompatible with return type}} @@ -170,9 +170,9 @@ tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // ----- // Bad order; should be [1,0] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { // expected-error @+2 {{order}} // expected-error @+1 {{op failed to infer returned types}} @@ -215,11 +215,11 @@ tt.func public @fn(%arg0: tensor<2xf32>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> // Bad order, should be [1,0]. -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // expected-error @+2 {{op inferred type}} // expected-error @+1 {{op failed to infer returned types}} @@ -230,11 +230,11 @@ tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> // bad sizePerThread; should be [1,1]. -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [0,1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [0,1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // expected-error @+2 {{op inferred type}} // expected-error @+1 {{op failed to infer returned types}} @@ -246,7 +246,7 @@ tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { // ----- // Valid ops. -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32>) { %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<16x32x64xf32> %b = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<32x16x64xf32> @@ -257,11 +257,11 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32>) { // ----- // Valid op with blocked encoding. -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2,3,4], threadsPerWarp = [2,4,2,2], warpsPerCTA = [4,2,4,2], order = [3,2,1,0], CTAsPerCGA = [1,2,2,2], CTASplitNum = [1,2,4,8], CTAOrder = [3,2,1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2,4,3,1], threadsPerWarp = [4,2,2,2], warpsPerCTA = [2,2,4,4], order = [1,2,0,3], CTAsPerCGA = [2,2,2,1], CTASplitNum = [2,8,4,1], CTAOrder = [1,2,0,3]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1,2,3,4], threadsPerWarp = [2,4,2,2], warpsPerCTA = [4,2,4,2], order = [3,2,1,0], CTAsPerCGA = [1,2,2,2], CTASplitNum = [1,2,4,8], CTAOrder = [3,2,1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2,4,3,1], threadsPerWarp = [4,2,2,2], warpsPerCTA = [2,2,4,4], order = [1,2,0,3], CTAsPerCGA = [2,2,2,1], CTASplitNum = [2,8,4,1], CTAOrder = [1,2,0,3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64xf32, #blocked2>) { %a = tt.trans %arg0 {order = array} : tensor<2x4x8x16xf32, #blocked> -> tensor<4x16x8x2xf32, #blocked1> %b = tt.trans %arg1 {order = array} : tensor<16x32x64xf32, #blocked2> -> tensor<32x16x64xf32, #blocked3> @@ -272,14 +272,15 @@ tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64 // ----- // Valid op with shared encoding. -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}> -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> -#shared3 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -tt.func public @fn(%arg0: !tt.memdesc<2x4x8x16xf32, #shared>, %arg1: !tt.memdesc<16x32xf32, #shared2>) { - %a = tt.trans %arg0 {order = array} : !tt.memdesc<2x4x8x16xf32, #shared> -> !tt.memdesc<4x16x8x2xf32, #shared1> - %b = tt.trans %arg1 {order = array} : !tt.memdesc<16x32xf32, #shared2> -> !tt.memdesc<32x16xf32, #shared3> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}> +#shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared3 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared, #smem>, %arg1: !ttg.memdesc<16x32xf32, #shared2, #smem>) { + %a = ttg.memdesc_trans %arg0 {order = array} : !ttg.memdesc<2x4x8x16xf32, #shared, #smem> -> !ttg.memdesc<4x16x8x2xf32, #shared1, #smem> + %b = ttg.memdesc_trans %arg1 {order = array} : !ttg.memdesc<16x32xf32, #shared2, #smem> -> !ttg.memdesc<32x16xf32, #shared3, #smem> tt.return } } // end module @@ -287,9 +288,9 @@ tt.func public @fn(%arg0: !tt.memdesc<2x4x8x16xf32, #shared>, %arg1: !tt.memdesc // ----- // Invalid blocked encoding. -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { // expected-error @+1 {{type}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #blocked> -> tensor<32x16x64xf32, #blocked1> @@ -300,9 +301,9 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { // ----- // Invalid shared encoding. -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // expected-error @+1 {{type}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> @@ -312,7 +313,7 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -322,7 +323,7 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -332,7 +333,7 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32xf32>) { // expected-error @+1 {{order must be a permutation}} %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> @@ -343,12 +344,63 @@ tt.func public @fn(%arg0: tensor<16x32xf32>) { // ----- // Invalid tensor with shared encoding. -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { // expected-error @+1 {{has an invalid layout: Shared layout is not allowed on tensor type.}} %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> tt.return } } // end module + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{indices and output shapes must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512xf32> + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked>) { + // expected-error @below {{indices and output encodings must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> + tt.return +} +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf16>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and output element types must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf16>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and indices ranks must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { + // expected-error @below {{indices dimension 1 must match the corresponding input dimension}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> + tt.return +} +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{gather dimension must be less than the input rank}} + %0 = tt.gather %arg0[%arg1] {axis = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} diff --git a/test/Triton/loop-unroll.mlir b/test/Triton/loop-unroll.mlir index 9166630281e6..531a14fffad3 100644 --- a/test/Triton/loop-unroll.mlir +++ b/test/Triton/loop-unroll.mlir @@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { // CHECK: scf.for // CHECK: tt.load // CHECK-NOT: tt.load + // CHECK: tt.num_stages = 1 : i32 %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { %3 = tt.load %arg5 : tensor<256x!tt.ptr> %4 = arith.addf %arg4, %3 : tensor<256xf32> diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index c3b92b7ee403..eb7a63c340a7 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -244,9 +244,16 @@ tt.func @histogram(%0: tensor<512xi32>) { } // CHECK-LABEL: experimental_descriptor_load -tt.func @experimental_descriptor_load(%0: !tt.ptr) { - // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.ptr -> tensor<128xf32> +tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { + // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc> -> tensor<128xf32> %c0_i32 = arith.constant 0 : i32 - %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.ptr -> tensor<128xf32> + %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.tensordesc> -> tensor<128xf32> tt.return } + +// CHECK-LABEL: @gather_op +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x16xi32>) -> tensor<512x16xf32> { + // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + tt.return %0 : tensor<512x16xf32> +} diff --git a/test/Triton/reproducer.mlir b/test/Triton/reproducer.mlir index f2c3a0f8e8d3..5a6747d217a9 100644 --- a/test/Triton/reproducer.mlir +++ b/test/Triton/reproducer.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt --verify-diagnostics --dump-pass-pipeline --run-reproducer %s 2>&1 | FileCheck %s -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @triton__() attributes {noinline = false} { tt.return } diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 551c1f67b52a..1f7de7d6d939 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -51,80 +51,80 @@ module { // %c256_i32 = arith.constant 256 : i32 // %0 = tt.get_program_id x : i32 // %1 = arith.muli %0, %c256_i32 : i32 -// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %3 = tt.broadcast %1 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %5 = tt.broadcast %arg3 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %7 = tt.broadcast %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %9 = tt.broadcast %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %11 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> +// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg<"coalesced encoding">> +// %3 = tt.broadcast %1 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %4 = arith.addi %3, %2 : tensor<256xi32, #ttg<"coalesced encoding">> +// %5 = tt.broadcast %arg3 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #ttg<"coalesced encoding">>, tensor<256xi32, #ttg<"coalesced encoding">>) -> tensor<256xi1, #ttg<"coalesced encoding">> +// %7 = tt.broadcast %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %9 = tt.broadcast %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %11 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> // %12 = arith.index_cast %arg4 : i32 to index // %13 = arith.cmpi slt, %c0, %12 : index -// %14 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %15 = tt.broadcast %13 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %16 = arith.andi %6, %15 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %17 = triton_gpu.copy_async %8, %16, %14 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %18 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %19 = tt.broadcast %13 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %21 = triton_gpu.copy_async %10, %20, %18 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %22 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %24 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %14 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %15 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %16 = arith.andi %6, %15 : tensor<256xi1, #ttg<"coalesced encoding">> +// %17 = ttg.copy_async %8, %16, %14 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %18 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %19 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %20 = arith.andi %6, %19 : tensor<256xi1, #ttg<"coalesced encoding">> +// %21 = ttg.copy_async %10, %20, %18 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %22 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %24 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %26 = arith.cmpi slt, %c32, %12 : index -// %27 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %28 = tt.broadcast %26 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %29 = arith.andi %6, %28 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %30 = triton_gpu.copy_async %23, %29, %27 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %31 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %32 = tt.broadcast %26 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %34 = triton_gpu.copy_async %25, %33, %31 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %35 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %37 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %27 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %28 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %29 = arith.andi %6, %28 : tensor<256xi1, #ttg<"coalesced encoding">> +// %30 = ttg.copy_async %23, %29, %27 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %31 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %32 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %33 = arith.andi %6, %32 : tensor<256xi1, #ttg<"coalesced encoding">> +// %34 = ttg.copy_async %25, %33, %31 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %35 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %37 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %39 = arith.cmpi slt, %c64, %12 : index -// %40 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %41 = tt.broadcast %39 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %42 = arith.andi %6, %41 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %43 = triton_gpu.copy_async %36, %42, %40 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %44 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %45 = tt.broadcast %39 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %47 = triton_gpu.copy_async %38, %46, %44 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %48 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %50 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { -// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %57 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %59 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> +// %40 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %41 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %42 = arith.andi %6, %41 : tensor<256xi1, #ttg<"coalesced encoding">> +// %43 = ttg.copy_async %36, %42, %40 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %44 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %45 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %46 = arith.andi %6, %45 : tensor<256xi1, #ttg<"coalesced encoding">> +// %47 = ttg.copy_async %38, %46, %44 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %48 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %50 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index) { +// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #ttg<"coalesced encoding">> +// %56 = arith.addf %arg7, %55 : tensor<256xf32, #ttg<"coalesced encoding">> +// %57 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %59 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> // %61 = arith.addi %arg18, %c32 : index // %62 = arith.cmpi slt, %61, %12 : index -// %63 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %64 = tt.broadcast %62 : i1 -> tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %65 = arith.andi %64, %6 : tensor<256xi1, #triton_gpu<"coalesced encoding">> -// %66 = triton_gpu.copy_async %arg17, %65, %63 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %67 = tt.broadcast %cst : f32 -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %68 = triton_gpu.copy_async %arg16, %65, %67 : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">> -// %69 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// %71 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, index +// %63 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %64 = tt.broadcast %62 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %65 = arith.andi %64, %6 : tensor<256xi1, #ttg<"coalesced encoding">> +// %66 = ttg.copy_async %arg17, %65, %63 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %67 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %68 = ttg.copy_async %arg16, %65, %67 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %69 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %71 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index // } -// %53 = tt.broadcast %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> -// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> -// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding">> +// %53 = tt.broadcast %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// tt.store %54, %52#0, %6 : tensor<256xf32, #ttg<"coalesced encoding">> // tt.return // } // } diff --git a/test/TritonCPU/canonicalize.mlir b/test/TritonCPU/canonicalize.mlir new file mode 100644 index 000000000000..9e14645861bb --- /dev/null +++ b/test/TritonCPU/canonicalize.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-canonicalize | FileCheck %s + +// Fold transfer read and shape cast. + +// CHECK-LABEL: @fold_transfer_read_shape_cast +// CHECK: %[[VAL:.+]] = vector.transfer_read +// CHECK: vector.transfer_write %[[VAL]] + +module { + tt.func public @fold_transfer_read_shape_cast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %in_p = tt.make_tensor_ptr %arg0, [%c2_i64, %c2_i64, %c16_i64, %c16_i64], [%c512_i64, %c256_i64, %c16_i64, %c1_i64], [%c0_i32, %c0_i32, %c0_i32, %c0_i32] {order = array} : > + %out_p = tt.make_tensor_ptr %arg1, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %memref1 = triton_cpu.extract_memref %in_p : > -> memref<2x2x16x16xbf16, strided<[512, 256, 16, 1]>> + %indices1:4 = triton_cpu.extract_indices %in_p : > -> index, index, index, index + %val1 = vector.transfer_read %memref1[%indices1#0, %indices1#1, %indices1#2, %indices1#3], %cst {in_bounds = [true, true, true, true]} : memref<2x2x16x16xbf16, strided<[512, 256, 16, 1]>>, vector<1x1x16x16xbf16> + %val2 = vector.shape_cast %val1 : vector<1x1x16x16xbf16> to vector<16x16xbf16> + %memref2 = triton_cpu.extract_memref %out_p : > -> memref<16x16xbf16, strided<[16, 1]>> + %indices2:2 = triton_cpu.extract_indices %out_p : > -> index, index + vector.transfer_write %val2, %memref2[%indices2#0, %indices2#1] {in_bounds = [true, true]} : vector<16x16xbf16>, memref<16x16xbf16, strided<[16, 1]>> + tt.return + } +} diff --git a/test/TritonCPU/convert-atomic.mlir b/test/TritonCPU/convert-atomic.mlir new file mode 100644 index 000000000000..b0cad10f0d8f --- /dev/null +++ b/test/TritonCPU/convert-atomic.mlir @@ -0,0 +1,36 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-atomic-ops | FileCheck %s + +// Convert atomic ops with non-constant masks into scf.if + maskless atomic op. +// Check that the final tt.atomic_rmw only has 5 parameters (the 6th would be the mask). + +// CHECK-LABEL: @atomic_mask +// CHECK: %[[COND:.+]] = vector.extract %{{.+}}[[[#IDX:]]] : i1 from vector<16xi1> +// CHECK-NEXT: scf.if %[[COND]] -> (f32) { +// CHECK-NEXT: %[[OLD:.+]] = tt.atomic_rmw fadd, acq_rel, gpu, %{{[^%]+}} %{{[^%]+}} : (!tt.ptr, f32) -> f32 +// CHECK-NEXT: scf.yield %[[OLD]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: scf.yield %[[CST]] : f32 +// CHECK-NEXT: } + +module { + tt.func public @atomic_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xi64> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<5.000000e-01> : vector<16xf32> + %cst_1 = arith.constant dense<3.000000e+00> : vector<16xf32> + %0 = builtin.unrealized_conversion_cast %cst_1 : vector<16xf32> to tensor<16xf32> + %1 = tt.ptr_to_int %arg0 : !tt.ptr -> i64 + %2 = vector.splat %1 : vector<16xi64> + %3 = arith.addi %2, %cst : vector<16xi64> + %4 = builtin.unrealized_conversion_cast %3 : vector<16xi64> to tensor<16x!tt.ptr> + %5 = vector.extract %3[0] : i64 from vector<16xi64> + %6 = tt.int_to_ptr %5 : i64 -> !tt.ptr + %7 = triton_cpu.ptr_to_memref %6 : -> memref<16xf32> + %8 = vector.load %7[%c0] : memref<16xf32>, vector<16xf32> + %9 = arith.cmpf olt, %8, %cst_0 : vector<16xf32> + %10 = builtin.unrealized_conversion_cast %9 : vector<16xi1> to tensor<16xi1> + %11 = tt.atomic_rmw fadd, acq_rel, gpu, %4, %0, %10 : (tensor<16x!tt.ptr>, tensor<16xf32>, tensor<16xi1>) -> tensor<16xf32> + tt.return + } +} diff --git a/test/TritonCPU/convert-memory-ops.mlir b/test/TritonCPU/convert-memory-ops.mlir new file mode 100644 index 000000000000..710b76279610 --- /dev/null +++ b/test/TritonCPU/convert-memory-ops.mlir @@ -0,0 +1,81 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-gather-scatter=true -cse | FileCheck %s + +// Convert strided masked loads to gather. + +// CHECK-LABEL: @strided_masked_loads +// CHECK: %[[PTR:.+]] = triton_cpu.ptr_to_memref %[[BASE:.+]] : -> memref +// CHECK: %[[VAL:.+]] = vector.gather %[[PTR]][] [%[[INDEX_VEC:.+]]], %[[MASK:.+]], %[[OTHER:.+]] : memref, vector<32xi32>, vector<32xi1>, vector<32xi32> into vector<32xi32> + +module { + tt.func public @strided_masked_loads(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<2> : tensor<32xi32> + %cst_0 = arith.constant dense<16> : tensor<32xi32> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = arith.cmpi slt, %0, %cst_0 : tensor<32xi32> + %2 = arith.muli %0, %cst : tensor<32xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<32x!tt.ptr>, tensor<32xi32> + scf.for %arg1 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %5 = tt.load %4, %1 : tensor<32x!tt.ptr> + tt.store %4, %5 : tensor<32x!tt.ptr> + } + tt.return + } +} + +// ----- + +// Convert strided masked stores to scatter. + +// CHECK-LABEL: @strided_masked_stores +// CHECK: %[[PTR:.+]] = triton_cpu.ptr_to_memref %[[BASE:.+]] : -> memref +// CHECK: vector.scatter %[[PTR]][] [%[[INDEX_VEC:.+]]], %[[MASK:.+]], %[[VALS:.+]] : memref, vector<32xi32>, vector<32xi1>, vector<32xi32> + +module { + tt.func public @strided_masked_stores(%arg0: !tt.ptr {tt.divisibility = 16 : i32} ) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<32xi32> + %cst_0 = arith.constant dense<2> : tensor<32xi32> + %cst_1 = arith.constant dense<16> : tensor<32xi32> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = arith.cmpi slt, %0, %cst_1 : tensor<32xi32> + %2 = arith.muli %0, %cst_0 : tensor<32xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<32x!tt.ptr>, tensor<32xi32> + %5 = arith.subi %cst, %2 : tensor<32xi32> + %6 = tt.addptr %3, %5 : tensor<32x!tt.ptr>, tensor<32xi32> + scf.for %arg1 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %7 = tt.load %4 : tensor<32x!tt.ptr> + tt.store %6, %7, %1 : tensor<32x!tt.ptr> + } + tt.return + } +} + +// ----- + +// Check that pointer for vector load/store is not extracted from a vector + +// CHECK-LABEL: @scalar_ptrs +// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64> +// CHECK: {{.+}} = vector.load {{.+}} : memref<128xf32>, vector<128xf32> +// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64> +// CHECK: vector.store {{.+}}, {{.+}} : memref<128xf32>, vector<128xf32> + +module { + tt.func public @scalar_ptrs(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %5, %3 : tensor<128x!tt.ptr> + tt.return + } +} diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir new file mode 100644 index 000000000000..da501849f723 --- /dev/null +++ b/test/TritonCPU/dot-to-amx.mlir @@ -0,0 +1,238 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-dot-to-amx="convert-bf16=true convert-fp16=true convert-i8=true" -canonicalize | FileCheck %s + +// Replacement of a contraction operation with a single tile_mulf operation. + +// CHECK-LABEL: @test_single_mulf +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x32xbf16> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC:.+]] = amx.tile_zero : !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS:.+]] = amx.tile_load %3[%4#0, %4#1] +// CHECK-NEXT: %[[RHS:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] +// CHECK-NEXT: %[[RES:.+]] = amx.tile_mulf %[[LHS]], %[[RHS]], %[[ACC]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES]] : memref<16x16xf32, strided<[16, 1]>>, !amx.tile<16x16xf32> + +#loc = loc(unknown) +module { + tt.func public @test_single_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf32> loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x32xbf16, strided<[32, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x32xbf16, strided<[32, 1]>>, vector<16x32xbf16> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<32x16xbf16, strided<[16, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<32x16xbf16, strided<[16, 1]>>, vector<32x16xbf16> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x32xbf16> * vector<32x16xbf16> -> vector<16x16xf32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[16, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// Replacement of a contraction operation with multiple tile_muli operations. + +// CHECK-LABEL: @test_single_tile_two_muli +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xi8> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC:.+]] = amx.tile_zero : !amx.tile<16x16xi32> +// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] +// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_muli %[[LHS1]], %[[RHS1]], %[[ACC]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> +// CHECK-NEXT: %[[IDX1:.+]] = arith.addi %4#1, %c64{{.*}} : index +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x128xi8, strided<[128, 1]>> into !amx.tile<16x64xi8> +// CHECK-NEXT: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xi8> into !amx.tile<16x64xi8> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_muli %[[LHS2]], %[[RHS2]], %[[RES1]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES2]] : memref<16x16xi32, strided<[16, 1]>>, !amx.tile<16x16xi32> + +#loc = loc(unknown) +module { + tt.func public @test_single_tile_two_muli(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %c0_i8 = arith.constant 0 : i8 loc(#loc) + %cst = arith.constant dense<0> : vector<16x16xi32> loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x128xi8, strided<[128, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %c0_i8 {in_bounds = [true, true]} : memref<16x128xi8, strided<[128, 1]>>, vector<16x128xi8> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<128x16xi8, strided<[16, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %c0_i8 {in_bounds = [true, true]} : memref<128x16xi8, strided<[16, 1]>>, vector<128x16xi8> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst, inputPrecision = ieee : vector<16x128xi8> * vector<128x16xi8> -> vector<16x16xi32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32, strided<[16, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// Replacement of a contraction operation with multiple tile_mulf operations +// and multiple output tiles. + +// CHECK-LABEL: @test_two_tiles_four_mulf +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xbf16> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC1:.+]] = amx.tile_zero : !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC2:.+]] = amx.tile_zero : !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] : memref<16x64xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS1]], %[[ACC1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS2]], %[[ACC2]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK: %[[IDX1:.+]] = arith.addi %4#1, %c32{{.*}} : index +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x64xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> +// CHECK: %[[RHS3:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES3:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS3]], %[[RES1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES3]] : memref<16x32xf32, strided<[32, 1]>>, !amx.tile<16x16xf32> +// CHECK: %[[RHS4:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES4:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS4]], %[[RES2]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK: %[[IDX2:.+]] = arith.addi %[[OUT_INDICES]]#1, %c16{{.*}} : index +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[IDX2]]], %[[RES4]] : memref<16x32xf32, strided<[32, 1]>>, !amx.tile<16x16xf32> + +#loc = loc(unknown) +module { + tt.func public @test_two_tiles_four_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x64xbf16, strided<[64, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x64xbf16, strided<[64, 1]>>, vector<16x64xbf16> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<64x32xbf16, strided<[32, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x64xbf16> * vector<64x32xbf16> -> vector<16x32xf32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x32xf32>, memref<16x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// More complicated case with a loop, input casts, and accumulator that +// cannot fit tile register file. + +// CHECK-LABEL: @test_loop_acc_two_blocks +// CHECK: %[[LHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<64x64xbf16> +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xbf16> +// CHECK: %[[ACC_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<64x32xf32> +// CHECK: vector.transfer_write %cst{{.+}}, %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] {in_bounds = [true, true]} : vector<64x32xf32>, memref<64x32xf32> +// CHECK: %3:2 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr>, !tt.ptr>) : i32 +// CHECK: %[[LHS:.+]] = vector.transfer_read %{{.+}}[%{{.+}}#0, %{{.+}}#1], %{{.+}} {in_bounds = [true, true]} : memref<64x128xf8E5M2, strided<[128, 1]>>, vector<64x64xf8E5M2> +// CHECK: %[[RHS:.+]] = vector.transfer_read %{{.+}}[%{{.+}}#0, %{{.+}}#1], %{{.+}} {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> +// CHECK-NEXT: %[[LHS1:.+]] = arith.extf %[[LHS]] : vector<64x64xf8E5M2> to vector<64x64xbf16> +// CHECK-NEXT: vector.transfer_write %[[LHS1]], %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16> +// CHECK-NEXT: %[[RHS1:.+]] = arith.extf %[[RHS]] : vector<64x32xf8E5M2> to vector<64x32xbf16> +// CHECK-COUNT-32: vector.store %{{.+}}, %[[RHS_BUF]][%{{.+}}, %{{.+}}] : memref<32x64xbf16>, vector<64xbf16> +// CHECK-NEXT: %[[ACC_0_0:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_0_1:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_1_0:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_1_1:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_0_0:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_0:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_0:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_0]], %[[ACC_0_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_1_0:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_0]], %[[ACC_1_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_1:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_1]], %[[ACC_0_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_1_1:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_1]], %[[ACC_1_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_0_1:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_1:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_0_0:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_0]], %[[TMP_0_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %[[RES_0_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_1_0:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_0]], %[[TMP_1_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}], %[[RES_1_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_0_1:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_1]], %[[TMP_0_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}], %[[RES_0_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_1_1:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_1]], %[[TMP_1_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}], %[[RES_1_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_2_0:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_2_1:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_3_0:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_3_1:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_2_0:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_0:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_0:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_0]], %[[ACC_2_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_3_0:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_0]], %[[ACC_3_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_1:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_1]], %[[ACC_2_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_3_1:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_1]], %[[ACC_3_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_2_1:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_1:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_2_0:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_0]], %[[TMP_2_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}], %[[RES_2_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_3_0:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_0]], %[[TMP_3_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}], %[[RES_3_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_2_1:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_1]], %[[TMP_2_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}], %[[RES_2_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_3_1:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_1]], %[[TMP_3_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}], %[[RES_3_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK: %[[RES:.+]] = vector.transfer_read %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<64x32xf32>, vector<64x32xf32> + +#loc = loc(unknown) +module { + tt.func public @test_loop_acc_two_blocks(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f8E5M2 loc(#loc) + %c2_i32 = arith.constant 2 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %c64_i32 = arith.constant 64 : i32 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<64x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3:3 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %cst_0, %arg5 = %0, %arg6 = %1) -> (vector<64x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %6 = triton_cpu.extract_memref %arg5 : > -> memref<64x128xf8E5M2, strided<[128, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %arg5 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x128xf8E5M2, strided<[128, 1]>>, vector<64x64xf8E5M2> loc(#loc) + %9 = triton_cpu.extract_memref %arg6 : > -> memref<128x32xf8E5M2, strided<[32, 1]>> loc(#loc) + %10:2 = triton_cpu.extract_indices %arg6 : > -> index, index loc(#loc) + %11 = vector.transfer_read %9[%10#0, %10#1], %cst {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> loc(#loc) + %12 = triton_cpu.dot %8, %11, %arg4, inputPrecision = ieee : vector<64x64xf8E5M2> * vector<64x32xf8E5M2> -> vector<64x32xf32> loc(#loc) + %13 = tt.advance %arg5, [%c0_i32, %c64_i32] : > loc(#loc) + %14 = tt.advance %arg6, [%c64_i32, %c0_i32] : > loc(#loc) + scf.yield %12, %13, %14 : vector<64x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %4 = triton_cpu.extract_memref %2 : > -> memref<64x32xf32, strided<[32, 1]>> loc(#loc) + %5:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %3#0, %4[%5#0, %5#1] {in_bounds = [true, true]} : vector<64x32xf32>, memref<64x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) diff --git a/test/TritonCPU/optimize-masks.mlir b/test/TritonCPU/optimize-masks.mlir new file mode 100644 index 000000000000..470e0f6b3419 --- /dev/null +++ b/test/TritonCPU/optimize-masks.mlir @@ -0,0 +1,93 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-optimize-masks -canonicalize | FileCheck %s + +// Convert strided masked loads to scalar loads. + +// CHECK-LABEL: @remove_masks_in_for_loop +// CHECK: %[[VAL:.+]] = vector.load {{.+}} : memref<16xf32>, vector<16xf32> +// CHECK: vector.store %[[VAL]], {{.+}} : memref<16xf32>, vector<16xf32> + +module { + tt.func public @remove_masks_in_for_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %c15_i32 = arith.constant 15 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = arith.addi %arg2, %c15_i32 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = vector.splat %arg2 : vector<16xi32> + scf.for %arg3 = %c0_i32 to %1 step %c1_i32 : i32 { + %3 = arith.muli %arg3, %c16_i32 : i32 + %4 = vector.splat %3 : vector<16xi32> + %5 = arith.addi %4, %cst : vector<16xi32> + %6 = arith.cmpi slt, %5, %2 : vector<16xi32> + %7 = tt.addptr %arg0, %3 : !tt.ptr, i32 + %8 = triton_cpu.ptr_to_memref %7 : -> memref<16xf32> + %9 = vector.maskedload %8[%c0], %6, %cst_0 : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %10 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %11 = triton_cpu.ptr_to_memref %10 : -> memref<16xf32> + vector.maskedstore %11[%c0], %6, %9 : memref<16xf32>, vector<16xi1>, vector<16xf32> + } + tt.return + } +} + +// ----- + +// Replace masked load with a regular load and optimize out arith.select. + +// CHECK-LABEL: @optimize_select +// CHECK: vector.load +// CHECK-NEXT: arith.addf +// CHECK-NEXT: arith.addf +// CHECK-NEXT: scf.yield + +module { + tt.func public @optimize_select(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %cst_1 = arith.constant dense<1.000000e+00> : vector<16xf32> + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = vector.splat %arg2 : vector<16xi32> + %1 = scf.for %arg3 = %c0_i32 to %arg2 step %c16_i32 iter_args(%arg4 = %cst_2) -> (vector<16xf32>) : i32 { + %3 = vector.splat %arg3 : vector<16xi32> + %4 = arith.addi %3, %cst_0 : vector<16xi32> + %5 = arith.cmpi slt, %4, %0 : vector<16xi32> + %6 = tt.addptr %arg0, %arg3 : !tt.ptr, i32 + %7 = triton_cpu.ptr_to_memref %6 : -> memref<16xf32> + %8 = vector.maskedload %7[%c0], %5, %cst_2 : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %9 = arith.addf %8, %cst_1 : vector<16xf32> + %10 = arith.select %5, %9, %cst_2 : vector<16xi1>, vector<16xf32> + %11 = arith.addf %arg4, %10 : vector<16xf32> + scf.yield %11 : vector<16xf32> + } + %2 = vector.multi_reduction , %1, %cst [0] : vector<16xf32> to f32 + tt.store %arg1, %2 : !tt.ptr + tt.return + } +} + +// ----- + +// Regression test for the infinite optimization loop bug. + +module { + tt.func public @remove_masks_in_for_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %c15_i32 = arith.constant 15 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = arith.addi %arg1, %c15_i32 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} diff --git a/test/TritonCPU/reduction.mlir b/test/TritonCPU/reduction.mlir new file mode 100644 index 000000000000..b3c1430e7b41 --- /dev/null +++ b/test/TritonCPU/reduction.mlir @@ -0,0 +1,18 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-reduction -canonicalize + +// Regression test: Check that we handle consecutive calls to tt.reduce with +// different types & number of arguments. + +module { + tt.func public @triton_(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>) { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: f32): + tt.reduce.return %arg3 : f32 + }) : (tensor<1x4xf32>) -> tensor<1xf32> + %1:2 = "tt.reduce"(%arg0, %arg1) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): + tt.reduce.return %arg3, %arg4 : f32, i32 + }) : (tensor<1x4xf32>, tensor<1x4xi32>) -> (tensor<1xf32>, tensor<1xi32>) + tt.return + } +} diff --git a/test/TritonCPU/scalarize-memory-ops.mlir b/test/TritonCPU/scalarize-memory-ops.mlir new file mode 100644 index 000000000000..f1934d9ffc14 --- /dev/null +++ b/test/TritonCPU/scalarize-memory-ops.mlir @@ -0,0 +1,113 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-scalarize -cse -canonicalize | FileCheck %s + +// Convert strided masked load and store to loops. Pointer and mask should be scalarized. +// TODO: There is an optimization opportunity to fuse loops. +// TODO: There is an optimization opportunity to reuse temp buffers. + +// CHECK-LABEL: @strided_masked_load_store +// CHECK: %[[ALLOCA1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: scf.for %[[IV1:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[IV1_I32:.*]] = arith.index_castui %[[IV1]] : index to i32 +// CHECK-NEXT: %[[IDX1:.*]] = arith.muli %[[IV1_I32]], %c3_i32 : i32 +// CHECK-NEXT: %[[PTR1:.*]] = tt.addptr %arg0, %[[IDX1]] : !tt.ptr, i32 +// CHECK-NEXT: %[[MASK1:.*]] = arith.cmpi slt, %[[IDX1]], %arg2 : i32 +// CHECK-NEXT: scf.if %[[MASK1]] { +// CHECK-NEXT: %[[VAL1:.*]] = tt.load %[[PTR1]] : !tt.ptr +// CHECK-NEXT: memref.store %[[VAL1]], %[[ALLOCA1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %[[TENSOR_VAL:.*]] = triton_cpu.load %[[ALLOCA1]] : memref<128xf32> -> tensor<128xf32> +// CHECK-NEXT: %[[ALLOCA2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: triton_cpu.store %[[TENSOR_VAL]], %[[ALLOCA2]] : tensor<128xf32>, memref<128xf32> +// CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[IV2_I32:.*]] = arith.index_castui %[[IV2]] : index to i32 +// CHECK-NEXT: %[[IDX2:.*]] = arith.muli %[[IV2_I32]], %c3_i32 : i32 +// CHECK-NEXT: %[[PTR2:.*]] = tt.addptr %arg1, %[[IDX2]] : !tt.ptr, i32 +// CHECK-NEXT: %[[MASK2:.*]] = arith.cmpi slt, %[[IDX2]], %arg2 : i32 +// CHECK-NEXT: %[[VAL2:.*]] = memref.load %[[ALLOCA2]][%[[IV2]]] : memref<128xf32> +// CHECK-NEXT: scf.if %[[MASK2]] { +// CHECK-NEXT: tt.store %[[PTR2]], %[[VAL2]] : !tt.ptr +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { + tt.func public @strided_masked_load_store(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %cst = arith.constant dense<1.000000e+00> : tensor<128xf32> + %cst_0 = arith.constant dense<3> : tensor<128xi32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = arith.muli %0, %cst_0 : tensor<128xi32> + %2 = tt.splat %arg2 : i32 -> tensor<128xi32> + %3 = arith.cmpi slt, %1, %2 : tensor<128xi32> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %1 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.load %5, %3, %cst : tensor<128x!tt.ptr> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %8 = tt.addptr %7, %1 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %8, %6, %3 : tensor<128x!tt.ptr> + tt.return + } +} + +// ----- + +// Convert indirect masked load and store. Pointer and mask are bufferized. +// TODO: There is an optimization opportunity to fuse loops. +// TODO: There is an optimization opportunity to reuse temp buffers. + +// CHECK-LABEL: @indirect_masked_load_store +// CHECK: %[[ALLOCA_VALS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: %[[ALLOCA_PTRS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> +// CHECK-NEXT: triton_cpu.store %{{.*}}, %[[ALLOCA_PTRS1]] : tensor<128x!tt.ptr>, memref<128xi64> +// CHECK-NEXT: %[[EXT_MASK:.*]] = arith.extui %{{.*}} : tensor<128xi1> to tensor<128xi8> +// CHECK-NEXT: %[[ALLOCA_MASK1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> +// CHECK-NEXT: triton_cpu.store %[[EXT_MASK]], %[[ALLOCA_MASK1]] : tensor<128xi8>, memref<128xi8> +// CHECK-NEXT: scf.for %[[IV1:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[PTR1_INT:.*]] = memref.load %[[ALLOCA_PTRS1]][%[[IV1]]] : memref<128xi64> +// CHECK-NEXT: %[[PTR1:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[MASK1_I8:.*]] = memref.load %[[ALLOCA_MASK1]][%[[IV1]]] : memref<128xi8> +// CHECK-NEXT: %[[MASK1:.*]] = arith.trunci %[[MASK1_I8]] : i8 to i1 +// CHECK-NEXT: scf.if %[[MASK1]] { +// CHECK-NEXT: %[[VAL1:.*]] = tt.load %[[PTR1]] : !tt.ptr +// CHECK-NEXT: memref.store %[[VAL1]], %[[ALLOCA_VALS1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA_VALS1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %[[TENSOR_VAL:.*]] = triton_cpu.load %[[ALLOCA_VALS1]] : memref<128xf32> -> tensor<128xf32> +// CHECK: %[[ALLOCA_PTRS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> +// CHECK-NEXT: triton_cpu.store %{{.*}}, %[[ALLOCA_PTRS2]] : tensor<128x!tt.ptr>, memref<128xi64> +// CHECK-NEXT: %[[ALLOCA_MASK2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> +// CHECK-NEXT: triton_cpu.store %[[EXT_MASK]], %[[ALLOCA_MASK2]] : tensor<128xi8>, memref<128xi8> +// CHECK-NEXT: %[[ALLOCA_VALS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: triton_cpu.store %[[TENSOR_VAL]], %[[ALLOCA_VALS2]] : tensor<128xf32>, memref<128xf32> +// CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[PTR2_INT:.*]] = memref.load %[[ALLOCA_PTRS2]][%[[IV2]]] : memref<128xi64> +// CHECK-NEXT: %[[PTR2:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[MASK2_I8:.*]] = memref.load %[[ALLOCA_MASK2]][%[[IV2]]] : memref<128xi8> +// CHECK-NEXT: %[[MASK2:.*]] = arith.trunci %[[MASK2_I8]] : i8 to i1 +// CHECK-NEXT: %[[VAL2:.*]] = memref.load %[[ALLOCA_VALS2]][%[[IV2]]] : memref<128xf32> +// CHECK-NEXT: scf.if %[[MASK2]] { +// CHECK-NEXT: tt.store %[[PTR2]], %[[VAL2]] : !tt.ptr +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { + tt.func public @indirect_masked_load_store(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg3 : i32 -> tensor<128xi32> + %5 = arith.cmpi slt, %3, %4 : tensor<128xi32> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.load %7, %5, %cst : tensor<128x!tt.ptr> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %10 = tt.addptr %9, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %10, %8, %5 : tensor<128x!tt.ptr> + tt.return + } +} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 85b37f3ed3a9..17180a392440 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,20 +1,20 @@ // RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s -// CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -// CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -// CHECK: #[[MMA2:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +// CHECK: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +// CHECK: #[[MMA2:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: mma_chain_loop tt.func public @mma_chain_loop( - %170: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %171: tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %179: tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, - %164: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>, - %165: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>, - %173: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %170: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %171: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %179: tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, + %164: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>, + %165: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>, + %173: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, %153: tensor<128x64x!tt.ptr, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %c8_i32 = arith.constant 8 : i32 @@ -23,21 +23,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { - %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> - %178 = triton_gpu.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %180 = tt.dot %178, %179, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %178 = ttg.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %180 = tt.dot %178, %179, %arg16 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %180 : tensor<128x64xf16, #blocked1> } // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> - // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { - %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> - %172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %174 = tt.dot %172, %173, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %172 = ttg.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %174 = tt.dot %172, %173, %arg16 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %174 : tensor<128x64xf16, #blocked1> } tt.store %153, %149 : tensor<128x64x!tt.ptr, #blocked1> @@ -47,79 +47,106 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: chained_dot tt.func public @chained_dot( - %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : - tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> - %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> // CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> %r = tt.dot %c, %arg2, %cst_1 : - tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> } } // ----- -// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +// CHECK: #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: chained_dot + tt.func public @chained_dot_wgmma( + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1> + %r = tt.dot %c, %arg2, %cst_1 : + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tt.return %r : tensor<64x128xf32, #blocked1> + } +} + +// ----- + +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fp8_dot tt.func public @fp8_dot( - %arg0: tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, - %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { + %arg0: tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> - // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : - tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> tt.return %d : tensor<64x64xf32, #blocked> } } // ----- -// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -// CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> +// CHECK-DAG: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +// CHECK-DAG: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: kernel_ tt.func public @kernel_() attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<2x16x16xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> - %0 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %1 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> - %2 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %1 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> + %2 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> - %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> - %4 = triton_gpu.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> - %6 = triton_gpu.convert_layout %5 : tensor<1x16x16xf32, #blocked2> -> tensor<1x16x16xf32, #blocked> + %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> + %6 = ttg.convert_layout %5 : tensor<1x16x16xf32, #blocked2> -> tensor<1x16x16xf32, #blocked> %7 = tt.broadcast %6 : tensor<1x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked> - %8 = triton_gpu.convert_layout %7 : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> - %9 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> - %10 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> + %8 = ttg.convert_layout %7 : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %9 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %10 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> - %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> - %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> + %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> + %12 = ttg.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> tt.print ": " {hex = false, isSigned = array} : %12 : tensor<2x16x16xf32, #blocked> tt.return } @@ -127,17 +154,17 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: check_instrShape_per_warps tt.func @check_instrShape_per_warps(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { %mask = arith.constant dense : tensor<128x128xi1, #blocked> %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %b = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> - %result = tt.dot %a, %b, %zero_f32 : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> %result_ptr = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> tt.store %result_ptr, %result, %mask : tensor<128x128x!tt.ptr, #blocked> tt.return @@ -148,37 +175,89 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // Verify that we use mmav2 when the k dim is too small for mmav3. -// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: small_k_size tt.func @small_k_size( - %a: tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %b: tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) + %a: tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<128x128xf32, #blocked> { %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } // ----- -// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: dot_scaled +// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it fallsback to mmav2 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[LINEAR:.+]] = #ttg.linear<{{.*}}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK: dot_scaled tt.func @dot_scaled( - %a: tensor<128x64xi8, #blocked2>, + %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, - %b: tensor<64x128xi8, #blocked>) - -> tensor<128x128xf32, #blocked> { - // CHECK: triton_gpu.upcast_mxfp - // CHECK: tt.dot + %b_bf16: tensor<64x128xbf16, #blocked> + ) -> tensor<128x128xf32, #blocked> { + // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}> + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> + // CHECK: ttng.warp_group_dot + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } + + // Verify that dot_scaled (mxfp4 x fp8) decomposes into mmav2 + // CHECK: dot_scaled_fp8 + tt.func @dot_scaled_fp8( + %a: tensor<128x32xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b_fp8: tensor<64x128xf8E4M3FN, #blocked> + ) -> tensor<128x128xf32, #blocked> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> + // CHECK: tt.dot + %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: dot_scale_transpose + tt.func public @dot_scale_transpose(%arg0: tensor<128x64xf8E4M3FN, #blocked>, %arg1: tensor<32x32xi8, #blocked1>, %arg2: tensor<32x2xi8, #blocked2>, %arg3: tensor<128x32x!tt.ptr, #blocked3>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1> + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked3> + %cst_1 = arith.constant dense<2> : tensor<32x1xi32, #blocked2> + // CHECK: scf.for + %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 { + // CHECK-DAG: tt.trans %{{.*}} {order = array} : tensor<128x64xf8E4M3FN, #{{.*}}> -> tensor<64x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: tt.trans %a{{.*}} {order = array} : tensor<32x32xi8, #{{.*}}> -> tensor<32x32xi8, #{{.*}}> + %3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1> + // CHECK: tt.dot + // CHECK-NOT: tt.trans + // CHECK: scf.yield + scf.yield %3 : tensor<128x32xf32, #blocked1> + } + // CHECK: arith.truncf + // CHECK: ttg.convert_layout + // CHECK: tt.trans + %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3> + tt.store %arg3, %2 : tensor<128x32x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir index 72ef11dcafd7..7ed7db0c1e3e 100644 --- a/test/TritonGPU/accumulator-init.mlir +++ b/test/TritonGPU/accumulator-init.mlir @@ -1,50 +1,67 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @constant_init // CHECK-DAG: %[[FALSE:.+]] = arith.constant false -// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] - tt.func @constant_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> } +// CHECK-LABEL: @constant_init_integer +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init_integer(%A: !ttg.memdesc<128x64xi8, #shared, #smem>, %B: !ttg.memdesc<64x16xi8, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0> : tensor<128x16xi32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xi32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xi8, #shared, #smem> * !ttg.memdesc<64x16xi8, #shared1, #smem> -> tensor<128x16xi32, #mma1> + scf.yield %acc: tensor<128x16xi32, #mma1> + } + tt.return %17 : tensor<128x16xi32, #mma1> + } + // CHECK-LABEL: @if_after_mma // CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> // CHECK-DAG: %[[TRUE:.+]] = arith.constant true // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -61,21 +78,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[TRUE]], %[[FALSE]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %acc : tensor<128x16xf32, #mma1> } else { @@ -97,9 +114,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -111,7 +128,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -128,9 +145,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -142,7 +159,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -154,17 +171,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @sel_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -178,9 +195,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) // CHECK: %[[CND:.+]] = arith.cmpi // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @sel_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -188,7 +205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -208,13 +225,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.yield %[[ACC]] // CHECK: else // CHECK: scf.yield %[[ACC]] -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.if %[[CND]] // CHECK: scf.yield %[[C0_TENSOR]] // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_and_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_and_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -226,7 +243,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_0 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_0 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -243,7 +260,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) // CHECK: %[[CND:.+]] = arith.cmpi -// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] // CHECK: scf.yield %[[C0_TENSOR]] // CHECK: else @@ -254,14 +271,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: else // CHECK: scf.yield %[[ACC_CND]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @two_ifs_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @two_ifs_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -280,15 +297,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that we bail out in unsupported cases // CHECK-LABEL: @non_zero_init -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc - tt.func @non_zero_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -296,15 +313,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @zero_init_dist_2 -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc - tt.func @zero_init_dist_2(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg5 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -312,8 +329,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @if_defines_alternative -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc - tt.func @if_defines_alternative(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> @@ -321,7 +338,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -334,18 +351,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } // CHECK-LABEL: @non_cond_override -// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc - tt.func @non_cond_override(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { - %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> } + +// If the condition is a tensor skip the optimization. +// CHECK-LABEL: @negative_sel_tensor +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @negative_sel_tensor(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } } diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir index 7854a4eed7a5..260dddb954d9 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -1,19 +1,19 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=0' | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> // CHECK-LABEL: mfma_dot_fp8e5m2 -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @mfma_dot_fp8e5m2( - %arg0: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - %arg1: tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<128x256x!tt.ptr, #blocked> ) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> - // CHECK: %[[A0:.+]] = triton_gpu.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - // CHECK: %[[B0:.+]] = triton_gpu.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> // CHECK: tt.dot %[[A1]], %[[B1]] - %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> tt.return } diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir index 7d3e8c23bed3..b68fe93493ed 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir @@ -1,27 +1,27 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1100 matrix-instruction-size=0' | FileCheck %s -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf32( - // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x256x!tt.ptr, #blocked>) { // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT0_OP_C:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_C]] + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> - // CHECK: %[[DOT0_OP_A:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_A]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] - // CHECK: %[[DOT0_OP_B:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_B]] - // CHECK-SAME: -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT0_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> tt.return @@ -30,28 +30,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf16( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -60,32 +60,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_ab8_cf16( - // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x64x!tt.ptr, #blocked>) { // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT2_OP_C:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_C]] + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> - // CHECK: %[[DOT2_OP_A_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_A]] - // CHECK-SAME: -> tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] - // CHECK-SAME: -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 16}>> - // CHECK: %[[DOT2_OP_B_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_B]] - // CHECK-SAME: -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 16}>> // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> tt.return @@ -94,28 +94,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_i8_i32( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -124,26 +124,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fma_dot_i16_i16( - // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT3_ARG_C:.+]] = arith.constant dense<0> : tensor<128x32xi16, #[[DOT_OP_PARENT]]> %3 = arith.constant dense<0> : tensor<128x32xi16, #blocked> // CHECK: %[[DOT3_OP_A:.+]] = arith.sitofp %[[DOT3_ARG_A]] - // CHECK-SAME: to tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]] + // CHECK-SAME: to tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_OP_B:.+]] = arith.sitofp %[[DOT3_ARG_B]] - // CHECK-SAME: to tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] + // CHECK-SAME: to tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_OP_C:.+]] = arith.sitofp %[[DOT3_ARG_C]] // CHECK-SAME: to tensor<128x32xf32, #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_FMA_RES:.+]] = tt.dot %[[DOT3_OP_A]], %[[DOT3_OP_B]], %[[DOT3_OP_C]] // CHECK-SAME: -> tensor<128x32xf32, #[[DOT_OP_PARENT]]> - %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> + %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> // CHECK: arith.fptosi %[[DOT3_FMA_RES]] // CHECK-SAME: to tensor<128x32xi16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x32x!tt.ptr, #blocked> diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir index a8683a5d3923..a5bf857dfb84 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir @@ -1,27 +1,27 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1200 matrix-instruction-size=0' | FileCheck %s -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf32( - // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x256x!tt.ptr, #blocked>) { // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT0_OP_C:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_C]] + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> - // CHECK: %[[DOT0_OP_A:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_A]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] - // CHECK: %[[DOT0_OP_B:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_B]] - // CHECK-SAME: -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT0_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> tt.return @@ -30,28 +30,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf16( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return @@ -60,32 +60,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_ab8_cf16( - // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x64x!tt.ptr, #blocked>) { // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT2_OP_C:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_C]] + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> - // CHECK: %[[DOT2_OP_A_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_A]] - // CHECK-SAME: -> tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] - // CHECK-SAME: -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> - // CHECK: %[[DOT2_OP_B_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_B]] - // CHECK-SAME: -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] - // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] - %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> tt.return @@ -94,28 +94,28 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_i8_i32( - // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> - %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, - // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> - %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<32x32x!tt.ptr, #blocked>) { // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> - // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> - // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] - // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] - // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] - // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] - %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> - // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index eda2dd8d9958..ed47e1512da9 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion1 tt.func @conversion1(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -22,8 +22,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion2 tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -50,8 +50,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion3 tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -89,8 +89,48 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // + // This is the same as conversion3, but now the `arith.extsi` operations + // disappeared and all the offsets are 32 bits. + // + // CHECK-LABEL: tt.func @conversion4 + tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + //CHECK: tt.load %[[newPtr]] + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOp tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -135,8 +175,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOp2 tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -181,8 +221,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forNested tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -227,8 +267,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @ifOp tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -269,8 +309,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @whileOp tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -305,8 +345,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @condBranch tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -349,8 +389,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @branch tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -388,30 +428,30 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So // we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform // offset will be A*B+D -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @tile_offset tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { %c128_i32 = arith.constant 128 : i32 %c256_i32 = arith.constant 256 : i32 %1 = tt.get_program_id x : i32 %20 = arith.muli %1, %c256_i32 : i32 - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %24 = tt.splat %20 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %26 = arith.addi %24, %22 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> - %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 - // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> @@ -443,21 +483,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // = (U + N)*U + N // Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) // The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func public @matmul_kernel tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { %c128_i32 = arith.constant 128 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c128_i32 : i32 - %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.splat %1 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = arith.addi %3, %2 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> - %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> @@ -469,8 +509,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 // CHECK: %[[makerange:.*]] = tt.make_range - // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> @@ -490,8 +530,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @select tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 @@ -523,8 +563,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1100", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @where_kernel tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ %c0_i8 = arith.constant 0 : i8 @@ -549,8 +589,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @forOpWithHints tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c0 = arith.constant 0: index @@ -580,8 +620,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: scalar_pointers tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %0 = tt.get_program_id x : i32 @@ -608,8 +648,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: @scalar_if tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ %0 = tt.get_program_id x : i32 @@ -638,8 +678,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_while tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ %c1024_i32 = arith.constant 1024 : i32 @@ -667,8 +707,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_cond_branch tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ %c1024_i32 = arith.constant 1024 : i32 diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir new file mode 100644 index 000000000000..18922b15aaef --- /dev/null +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -0,0 +1,124 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: simple + tt.func @simple(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 :i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_load %arg0[%[[offset]]] + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + // CHECK: buffer_load %arg1[%[[offset]]] + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + // CHECK: %[[data:.*]] = arith.addf + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_store %[[data]], %arg2[%[[offset]]] + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset + tt.func @assume_positive_offset(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits + tt.func @offset_64_bits(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits_narrow + tt.func public @offset_64_bits_narrow(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.splat %1: i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset_32_bit:.*]] = arith.trunci + %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> + %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: non_canonical_ptr + tt.func @non_canonical_ptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked>{ + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %arg1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} diff --git a/test/TritonGPU/amd/amd-extractslice-op.mlir b/test/TritonGPU/amd/amd-extractslice-op.mlir new file mode 100644 index 000000000000..bde77b475e2b --- /dev/null +++ b/test/TritonGPU/amd/amd-extractslice-op.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice + // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 000000000000..8cc3ae64f44c --- /dev/null +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,103 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 + +module { + // INSERT_IGLP0-LABEL: @test_dot_op + // INSERT_IGLP1-LABEL: @test_dot_op + // INSTR_COUNT_NS1-LABEL: @test_dot_op + // INSTR_COUNT_NS2-LABEL: @test_dot_op + // LABELING_PS_1-LABEL: @test_dot_op + // LABELING_PS_2-LABEL: @test_dot_op + tt.func @test_dot_op(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + + %a_mask = arith.constant dense : tensor<128x32xi1> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16> + %b_mask = arith.constant dense : tensor<32x128xi1> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32> + + %a_off = arith.constant dense<4> : tensor<128x32xi32> + %b_off = arith.constant dense<4> : tensor<32x128xi32> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32>) { + %a = tt.load %a_ptr : tensor<128x32x!tt.ptr> + %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr> + + // INSERT_IGLP0: rocdl.iglp.opt 0 + // INSERT_IGLP1: rocdl.iglp.opt 1 + + // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // INSTR_COUNT_NS2: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints] + // USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions + + // LABELING_PS_1: scf.for + // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_1: %[[REG1_OP0:.+]] = ttg.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = ttg.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} + + // LABELING_PS_2: scf.for + // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32> + } + + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32> + %c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + + tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr> + tt.return +} +} diff --git a/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/test/TritonGPU/amd/amd-optimize-epilogue.mlir index 8939562d0cb9..8cc467e77337 100644 --- a/test/TritonGPU/amd/amd-optimize-epilogue.mlir +++ b/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -1,17 +1,17 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-epilogue | FileCheck %s // CHECK-LABEL: one_op_in_chain -// CHECK-NOT: triton_gpu.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @one_op_in_chain(%arg0: !tt.ptr) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %2 = arith.truncf %1 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> tt.store %3, %2 : tensor<32x32x!tt.ptr, #blocked> @@ -22,17 +22,17 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- // CHECK-LABEL: two_ops_in_chain -// CHECK-NOT: triton_gpu.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @two_ops_in_chain(%arg0: !tt.ptr) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %2 = math.exp2 %1 : tensor<32x32xf32, #blocked> %3 = arith.truncf %2 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 686e5a24e8dd..de4c46c794c0 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -7,14 +7,15 @@ // CHECK-LABEL: hoist_q_out_of_the_loop // CHECK: %[[TRUNCF:.+]] = arith.truncf -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK-NEXT: %[[ALLOC:.+]] = ttg.local_alloc %[[TRUNCF]] +// CHECK-NEXT: ttg.local_load %[[ALLOC]] // CHECK: scf.for -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant 1.44269502 : f32 @@ -34,11 +35,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> - %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> - %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> - %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem> + %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> %107 = arith.addi %arg26, %c128_i64 : i64 scf.yield %107 : i64 } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} @@ -54,11 +55,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: scf.for // CHECK: %[[TRUNCF:.+]] = arith.truncf // CHECK-NEXT: arith.constant -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 %cst = arith.constant 1.44269502 : f32 @@ -78,11 +80,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> - %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> - %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> - %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> - %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem> + %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> %107 = arith.addi %arg26, %c128_i64 : i64 scf.yield %107 : i64 } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} @@ -91,25 +93,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> - +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory // CHECK-LABEL: order_load_alloc_local_load_local_store // CHECK: %[[LOAD:.+]] = tt.load -// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc -// CHECK: triton_gpu.local_store %[[LOAD]], %[[ALLOC]] -// CHECK: triton_gpu.local_load %[[ALLOC]] -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %[[ALLOC:.+]] = ttg.local_alloc +// CHECK: ttg.local_store %[[LOAD]], %[[ALLOC]] +// CHECK: ttg.local_load %[[ALLOC]] +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #shared, mutable> - triton_gpu.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, mutable> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttg.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -167,15 +169,16 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // yield // } -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32, ttg.target = "hip:gfx942"} { // CHECK-LABEL: tt.func @matmul_loop // CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) @@ -189,57 +192,57 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] // CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] // Stage 1 -// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %[[ARG11]] // CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} // CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] // Stage 0.b // CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} // CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} // CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] -// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] // CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] // CHECK: } tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> %12 = arith.cmpi slt, %arg0, %arg1 : index %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) { %20 = arith.subi %arg1, %arg2 : index %21 = arith.cmpi slt, %arg5, %20 : index - %22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> @@ -249,14 +252,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %32 = arith.addi %arg9, %c1_i32 : i32 %33 = arith.cmpi slt, %32, %c1_i32 : i32 %34 = arith.select %33, %32, %c0_i32 : i32 - %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %35 = ttg.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> } - triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> tt.return %19#2 : tensor<128x128xf32, #mma> } @@ -267,30 +270,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tt.func @matmul_loop_mb // CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// Stage 0 -// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} -// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] -// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] -// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] -// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] // Stage 1 -// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} -// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} -// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] -// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]] +// Stage 1 +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]] +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]] +// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] // Stage 2 -// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = ttg.local_load %[[ARG11]] // CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} // CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] -// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] +// CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]] // CHECK: } tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { @@ -298,23 +301,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %10 = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> %12 = arith.cmpi slt, %arg0, %arg1 : index %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> @@ -328,18 +331,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %25 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { %28 = arith.muli %arg2, %c2 : index %29 = arith.subi %arg1, %28 : index %30 = arith.cmpi slt, %arg5, %29 : index - %31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %31 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> @@ -349,14 +352,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %41 = arith.addi %arg9, %c1_i32 : i32 %42 = arith.cmpi slt, %41, %c2_i32 : i32 %43 = arith.select %42, %41, %c0_i32 : i32 - %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> } - triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> tt.return %27#2 : tensor<128x128xf32, #mma> } @@ -382,522 +385,100 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] // CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] // Stage 2 -// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] -// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[LOCAL_LOAD_36:.*]] = ttg.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG12]] // CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] // Stage 1.b // CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} // CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} // CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] -// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] // CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] // CHECK: } - tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { %c2 = arith.constant 2 : index %c0_i32 = arith.constant 0 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c1_i32 = arith.constant 1 : i32 - %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %cst_0 = arith.constant dense<1> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> %2 = arith.cmpi sgt, %arg1, %c0 : index - %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %2 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> %5 = arith.cmpi sgt, %arg1, %c1 : index - %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> - %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> - %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %15 = tt.splat %5 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) { %20 = arith.subi %arg1, %c2 : index %21 = arith.cmpi slt, %arg6, %20 : index %22 = arith.subi %arg1, %c1 : index %23 = arith.cmpi slt, %arg6, %22 : index - %24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %24 = ttg.local_load %arg11 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = ttg.local_load %arg12 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> - %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> - %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> %39 = arith.addi %arg10, %c1_i32 : i32 %40 = arith.cmpi slt, %39, %c1_i32 : i32 %41 = arith.select %40, %39, %c0_i32 : i32 - %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %42 = ttg.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> } - triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> tt.return %19#0 : tensor<16x16xf32, #mma> } } -// ----- -// This test ensures that loads will not be moved across `for` loops. - -// CHECK-LABEL: tt.func public @_attn_bwd -// CHECK: tt.load -// CHECK: tt.load -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// Moved before the independent `tt.store` ops but not before the `for` ops. -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.store -// CHECK: tt.store -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// CHECK: tt.store - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c32_i32 = arith.constant 32 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c16_i32 = arith.constant 16 : i32 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> - %0 = tt.get_program_id z : i32 - %1 = arith.muli %0, %arg14 : i32 - %2 = arith.extsi %1 : i32 to i64 - %3 = arith.remsi %0, %arg13 : i32 - %4 = arith.muli %arg11, %3 : i32 - %5 = arith.divsi %0, %arg13 : i32 - %6 = arith.muli %arg10, %5 : i32 - %7 = arith.addi %4, %6 : i32 - %8 = arith.extsi %7 : i32 to i64 - %9 = tt.get_program_id x : i32 - %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 - %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 - %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 - %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 - %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 - %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 - %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 - %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 - %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 - %19 = arith.muli %9, %c128_i32 : i32 - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> - %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> - %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> - %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> - %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> - %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> - %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> - %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> - %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> - %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> - %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> - %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> - %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> - %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> - %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> - %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> - %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> - %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> - %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> - %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> - %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> - %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> - %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> - %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %90 = arith.muli %arg12, %c16_i32 : i32 - %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> - %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> - %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { - %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> - %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> - %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> - %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> - %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %250 = arith.addi %arg18, %c16_i32 : i32 - %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> - } - %94 = arith.addi %19, %c128_i32 : i32 - %95 = arith.subi %arg14, %94 : i32 - %96 = arith.divsi %95, %c32_i32 : i32 - %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> - %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> - %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> - %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> - %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> - %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> - %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> - %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %122 = arith.muli %arg12, %c32_i32 : i32 - %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> - %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> - %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { - %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> - %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> - %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> - %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> - %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> - %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %244 = arith.addi %arg18, %c32_i32 : i32 - %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> - } - %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> - %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> - %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> - %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> - %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> - %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> - %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %174 = arith.muli %arg12, %c16_i32 : i32 - %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> - %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { - %206 = arith.addi %arg20, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> - %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> - %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> - %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> - %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %238 = arith.addi %arg17, %c16_i32 : i32 - %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 - } - triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %178 = arith.divsi %19, %c32_i32 : i32 - %179 = arith.muli %178, %c32_i32 : i32 - %180 = arith.subi %19, %179 : i32 - %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> - %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %196 = arith.muli %arg12, %c32_i32 : i32 - %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> - %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { - %206 = arith.addi %arg19, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> - %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> - %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> - %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> - %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 - } - triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> - %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> - tt.return - } -} - // ----- // CHECK-LABEL: sink_convert_dealloc -// CHECK-COUNT-2: triton_gpu.local_dealloc %{{.+}} : !tt.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-COUNT-2: ttg.local_dealloc %{{.+}} : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> - %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> - triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } @@ -908,17 +489,55 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: anchor_barrier // CHECK: gpu.barrier // CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> gpu.barrier %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + %1 = ttg.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> tt.return } } + + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: dont_hoist_scf_ops + // Make sure we don't hoist scf ops above its dependencies. + tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>, + %base: tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, + %p1: tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + // CHECK: scf.for + %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 { + // CHECK: arith.addi + %f = arith.addi %arg21, %c128_i32 : i32 + // CHECK: scf.if + // CHECK: tt.load + %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>{ + %t = tt.splat %f : i32 -> tensor<256x128xi32> + %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32> + scf.yield %padd : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + } else { + scf.yield %base : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + } + %l = tt.load %p0 : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %r = tt.load %p1 : tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + scf.yield %acc : tensor<256x128xf32, #mfma> + } + tt.return %54 : tensor<256x128xf32, #mfma> + } +} diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 000000000000..24139f66be5e --- /dev/null +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,256 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// + +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x32xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<32x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x32xf16, #shared, #smem, mutable> -> tensor<256x32xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<32x256xf16, #shared1, #smem, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !ttg.memdesc<256x32xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !ttg.memdesc<32x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Category 2: single dot with two loads and tile size is large enough (128x128x128). +// We make sure the move is legal. +// Should NOT apply: the 2nd load has a user before the dot +// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.store +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x128xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x128xf16, #shared1, #smem, mutable> -> tensor<128x128xf16, #dotOp1> + tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + scf.yield %3 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 3: two dots in the for loop. Make sure the optimization is not applied +// should NOT apply: two dots +// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot +// CHECK: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: ttg.local_load +// CHECK-NEXT: ttg.local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store +// CHECK-NEXT: ttg.local_store +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir index 0ec2ad5382ec..38f2f21eeef6 100644 --- a/test/TritonGPU/amd/optimize-lds-usage.mlir +++ b/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -4,18 +4,19 @@ // Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS // CHECK-LABEL: alloc_convert_load // CHECK-32KLIMIT-LABEL: alloc_convert_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> - %3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -26,18 +27,19 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // in case of relatively small scratch buffer // CHECK-LABEL: alloc_convert_small_load // CHECK-32KLIMIT-LABEL: alloc_convert_small_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -48,18 +50,19 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // in case of relatively small scratch buffer // CHECK-LABEL: alloc_convert_3d_load // CHECK-32KLIMIT-LABEL: alloc_convert_3d_load -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 -// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<1x128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -68,22 +71,23 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // Check that optimization triggers with custom LDS limit and do not triggers with default one // CHECK-LABEL: alloc_convert_32k_limit -// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> // CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit -// CHECK-32KLIMIT: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK-32KLIMIT: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 -// CHECK-32KLIMIT: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma -// CHECK-32KLIMIT: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK-32KLIMIT: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK-32KLIMIT: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} { - %1 = triton_gpu.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> - %2 = triton_gpu.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> - %3 = triton_gpu.local_load %1 : !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> + %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> tt.return } } @@ -91,30 +95,51 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // ----- // Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion) -// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> -// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> // CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}}) -// CHECK: [[ALLOC:%[0-9]+]] = triton_gpu.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !tt.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> -// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = triton_gpu.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> -// CHECK: [[CONVERT_1:%[0-9]+]] = triton_gpu.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> -// CHECK: [[CONVERT_2:%[0-9]+]] = triton_gpu.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> -// CHECK: [[LOAD:%[0-9]+]] = triton_gpu.local_load [[ALLOC]] : !tt.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> -#mma2 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> -#dotop1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> -#dotop2 = #triton_gpu.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem> +// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> +// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> +// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> +// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> +#dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} { - %alloc = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> - %convert_1 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> - %convert_2 = triton_gpu.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> - %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #dotop1> + %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> + %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1> + tt.return + } +} + +// ----- + +// Checks that optimization do not crash on 1d tensor +// CHECK-LABEL: convert_1d +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.local_load +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_1d(%arg0: tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { + %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #smem> + %1 = ttg.convert_layout %arg0 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #smem> -> tensor<128x128xf32, #mma> tt.return } } diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 9422bb0f8530..b47005b56978 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -3,16 +3,16 @@ // CHECK-LABEL: @test_canonicalize_convert_view // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder // CHECK: tt.return %[[V]] -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { - %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> + %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } @@ -24,15 +24,15 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> // is an expensive view which would require moving data across threads. // CHECK-LABEL: @test_canonicalize_convert_expensive_view // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] +// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]] // CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder // CHECK: tt.return %[[V]] -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { - %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> + %c = ttg.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } @@ -42,18 +42,18 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo // CHECK-LABEL: @test_canonicalize_convert_histogram // CHECK-SAME: (%[[ARG:.+]]: tensor<256xi32 -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[V:.+]] = tt.histogram %[[ARG]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %[[V]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) -> tensor<512xi32, #blocked2> { - %0 = triton_gpu.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> %1 = tt.histogram %0 : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked> - %2 = triton_gpu.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2> tt.return %2 : tensor<512xi32, #blocked2> } } // end module @@ -62,74 +62,78 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) // CHECK-LABEL: @test_canonicalize_convert_local_load // CHECK-NOT: gpu.barrier -// CHECK: %[[V:.+]] = triton_gpu.local_load +// CHECK: %[[V:.+]] = ttg.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: tt.return %[[V]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.compute-capability" = 80} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} { tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<256xi32, #shared, mutable> - %1 = triton_gpu.local_load %0 : !tt.memdesc<256xi32, #shared, mutable> -> tensor<256xi32, #blocked> + %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, #smem, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked> gpu.barrier - %2 = triton_gpu.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> tt.return %2 : tensor<256xi32, #blocked1> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: local_alloc_nofold1 - tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { - // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc - // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] - // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] // CHECK-NEXT: tt.return %[[ARG3]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: local_alloc_nofold2 - tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> { - // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc - // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] - // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] // CHECK-NEXT: tt.return %[[ARG3]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> - tt.return %2 : !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared1, #smem> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> { // CHECK-LABEL: local_alloc_fold - // CHECK-NEXT: %[[ARG:.+]] = triton_gpu.local_alloc + // CHECK-NEXT: %[[ARG:.+]] = ttg.local_alloc // CHECK-NEXT: tt.return %[[ARG]] - %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> - tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem> } } // end module diff --git a/test/TritonGPU/coalesce-async-copy.mlir b/test/TritonGPU/coalesce-async-copy.mlir new file mode 100644 index 000000000000..e0e4f0077b07 --- /dev/null +++ b/test/TritonGPU/coalesce-async-copy.mlir @@ -0,0 +1,37 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s + +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr, #blocked>, + %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>, + %mask: tensor<64x16xi1, #blocked>, + %other: tensor<64x16xi8, #blocked>) { + %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #smem, mutable> + tt.return +} +} + +// ----- + +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr, #blocked>, + %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>) { + %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #smem, mutable> + tt.return +} +} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index cf93c37b840d..25e136514b01 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -1,22 +1,22 @@ // RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - -// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> -// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> -// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: [[row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: [[load_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[load_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> +// CHECK: [[load_other:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> // CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] : tensor<64x64x!tt.ptr, [[row_layout]]> -// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> -// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> -// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> +// CHECK: [[store_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> +// CHECK: [[store_val:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> +// CHECK: [[store_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, @@ -34,7 +34,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> @@ -42,7 +42,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %17 = triton_gpu.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %19 = tt.load %10, %cst, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> tt.store %18, %19, %cst : tensor<64x64x!tt.ptr, #blocked1> @@ -53,12 +53,12 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { -// CHECK: [[NARROW_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -// CHECK: [[WIDE_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %0 = tt.get_program_id x : i32 @@ -87,11 +87,11 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-NOT: sizePerThread = [4] -// CHECK: #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-NOT: sizePerThread = [4] tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 @@ -124,10 +124,39 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 // COM: Reproducer for issue #3866 // CHECK-LABEL: @test_3866 // CHECK: tt.load {{.*}} : !tt.ptr -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { tt.func public @test_3866(%arg0: !tt.ptr, %arg1: i32, %arg2: i64) { %0 = tt.make_tensor_ptr %arg0, [%arg2, %arg2], [%arg2, %arg2], [%arg1, %arg1] {order = array} : > %1 = tt.load %0 : !tt.ptr> tt.return } } + +// ----- + +// COM: Reproducer for issue #5122 +// CHECK-LABEL: @test_5122 +module { + tt.func public @test_5122(%arg0: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32 + scf.if %0 { + %1 = scf.if %0 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %2 = arith.cmpi sgt, %1, %c1_i32 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 { + %5 = arith.addi %arg2, %c1_i32 : i32 + scf.yield %5 : i32 + } + } + tt.return + } +} diff --git a/test/TritonGPU/combine-select-if.mlir b/test/TritonGPU/combine-select-if.mlir index 62a9474dcb76..f00b9712358a 100644 --- a/test/TritonGPU/combine-select-if.mlir +++ b/test/TritonGPU/combine-select-if.mlir @@ -1,46 +1,77 @@ // RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @select_if_combine - tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr, #blocked>, %cnd: i1) attributes {noinline = false} { - // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> - %cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked> - // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> - %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked> - // CHECK-NOT: arith.select - %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked> - // CHECK: %[[IF_RES:.*]] = scf.if - scf.if %cnd { - tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr, #blocked> - // CHECK: scf.yield %[[CST0]] - } - // CHECK: else - // CHECK: scf.yield %[[CST1]] - // CHECK: tt.store %{{.*}}, %[[IF_RES]] - tt.store %dst_ptr, %sel : tensor<64x!tt.ptr, #blocked> - tt.return +tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr>, %cnd: i1) { + // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> + // CHECK-NOT: arith.select + %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> + // CHECK: %[[R:.+]] = scf.if %{{.*}} + // CHECK: tt.store %{{.*}}, %{{.*}} + // CHECK: scf.yield %[[CST0]] + // CHECK: } else { + // CHECK: scf.yield %[[CST1]] + // CHECK: } + scf.if %cnd { + tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr> } + // CHECK: tt.store %{{.*}}, %[[R]] + tt.store %dst_ptr, %sel : tensor<64x!tt.ptr> + tt.return } // ----- - // CHECK-LABEL: @if_multiple_sel tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){ -// CHECK-NOT: select -// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } else { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } -// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 + // CHECK-NOT: arith.select %0 = arith.select %arg0, %arg1, %arg2 : i32 %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } else { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } %2 = scf.if %arg0 -> (i32) { %3 = arith.subi %arg1, %arg2 : i32 scf.yield %3 : i32 } else { scf.yield %arg1 : i32 } + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 tt.return %0, %1, %2 : i32, f32, i32 } + +// ----- +// CHECK-LABEL: tt.func @users_in_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32 +tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) { + // CHECK: %[[CST:.*]] = arith.constant 8 : i32 + %c8_i32 = arith.constant 8 : i32 + // CHECK-NOT: arith.select + %0 = arith.select %arg0, %arg1, %arg2 : i32 + %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) { + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32 + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32 + // CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32 + // CHECK: } else { + // CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32 + // CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32 + // CHECK: } + %2:2 = scf.if %arg0 -> (i32, i32) { + %3 = arith.muli %0, %arg2 : i32 + %4 = arith.addi %0, %c8_i32 : i32 + scf.yield %3, %4 : i32, i32 + } else { + %3 = arith.subi %0, %c8_i32 : i32 + scf.yield %arg1, %3 : i32, i32 + } + // CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32 + tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32 +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 682c1cb3019d..7c956192b171 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1,20 +1,20 @@ // RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s -#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#layout3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#layout2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#layout3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { -// CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[$target_layout:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-LABEL: cst tt.func @cst() -> tensor<1024xi32, #layout1> { %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %cst : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %cst : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %cst : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -22,8 +22,8 @@ tt.func @cst() -> tensor<1024xi32, #layout1> { // CHECK-LABEL: range tt.func @range() -> tensor<1024xi32, #layout1> { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -31,8 +31,8 @@ tt.func @range() -> tensor<1024xi32, #layout1> { // CHECK-LABEL: splat tt.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { %0 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> - // CHECK-NOT: triton_gpu.convert_layout + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> tt.return %1: tensor<1024xi32, #layout1> } @@ -42,9 +42,9 @@ tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> - %3 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> %4 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> - %5 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %5 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1> tt.return %6: tensor<1024xi32, #layout1> // CHECK: %[[A:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> @@ -59,9 +59,9 @@ tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { tt.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { %0 = tt.splat %arg : !tt.ptr -> tensor<1x!tt.ptr, #layout1> %1 = tt.load %0 : tensor<1x!tt.ptr, #layout1> - // CHECK-NOT: triton_gpu.convert_layout - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #layout1> -> tensor<1xi32, #layout0> - %3 = triton_gpu.convert_layout %0 : tensor<1x!tt.ptr, #layout1> -> tensor<1x!tt.ptr, #layout0> + // CHECK-NOT: ttg.convert_layout + %2 = ttg.convert_layout %1 : tensor<1xi32, #layout1> -> tensor<1xi32, #layout0> + %3 = ttg.convert_layout %0 : tensor<1x!tt.ptr, #layout1> -> tensor<1x!tt.ptr, #layout0> tt.store %3, %2 : tensor<1x!tt.ptr, #layout0> tt.return } @@ -72,9 +72,9 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1> %2 = tt.addptr %0, %1 : tensor<16x!tt.ptr, #layout1>, tensor<16xi32, #layout1> %3 = tt.load %2 : tensor<16x!tt.ptr, #layout1> - // CHECK-NOT: triton_gpu.convert_layout - %4 = triton_gpu.convert_layout %3 : tensor<16xi32, #layout1> -> tensor<16xi32, #layout0> - %5 = triton_gpu.convert_layout %2 : tensor<16x!tt.ptr, #layout1> -> tensor<16x!tt.ptr, #layout0> + // CHECK-NOT: ttg.convert_layout + %4 = ttg.convert_layout %3 : tensor<16xi32, #layout1> -> tensor<16xi32, #layout0> + %5 = ttg.convert_layout %2 : tensor<16x!tt.ptr, #layout1> -> tensor<16x!tt.ptr, #layout0> tt.store %5, %4 : tensor<16x!tt.ptr, #layout0> tt.return } @@ -82,71 +82,71 @@ tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { // Hoist the convert on top of ext to make it cheaper. // CHECK-LABEL: hoist_above_ext tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: arith.extf %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %1 = tt.splat %arg1 : f32 -> tensor<1024xf32, #layout0> %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0> - %3 = triton_gpu.convert_layout %2 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %3 = ttg.convert_layout %2 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %3 : tensor<1024xf32, #layout1> } // CHECK-LABEL: hoist_above_ext2 tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: arith.extf %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %1 = tt.splat %arg1 : f16 -> tensor<1024xf16, #layout0> %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0> - %4 = triton_gpu.convert_layout %3 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %4 = ttg.convert_layout %3 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %4 : tensor<1024xf32, #layout1> } /// CHECK-LABEL: hoist_above_fptofp tt.func @hoist_above_fptofp(%arg0: tensor<1024xf8E4M3FNUZ, #layout0>) -> tensor<1024xf32, #layout1> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: tt.fp_to_fp %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf32, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> tt.return %1 : tensor<1024xf32, #layout1> } /// CHECK-LABEL: dont_hoist_above_trunc_fptofp tt.func @dont_hoist_above_trunc_fptofp(%arg0: tensor<1024xf32, #layout0>) -> tensor<1024xf8E4M3FNUZ, #layout1> { -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: %[[FP8:.+]] = tt.fp_to_fp -// CHECK: triton_gpu.convert_layout %[[FP8]] +// CHECK: ttg.convert_layout %[[FP8]] // CHECK: tt.return %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout0> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout1> + %1 = ttg.convert_layout %0 : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout1> tt.return %1 : tensor<1024xf8E4M3FNUZ, #layout1> } // Hoist the convert on top of broadcast to make it cheaper. // CHECK-LABEL: hoist_above_broadcast tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> { -// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: %[[CVT:.+]] = ttg.convert_layout // CHECK: tt.broadcast %[[CVT]] -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return %0 = tt.broadcast %arg0 : tensor<1024x1xf32, #layout2> -> tensor<1024x128xf32, #layout2> %1 = tt.splat %arg1 : f32 -> tensor<1024x128xf32, #layout2> %2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2> - %3 = triton_gpu.convert_layout %2 : tensor<1024x128xf32, #layout2> -> tensor<1024x128xf32, #layout3> + %3 = ttg.convert_layout %2 : tensor<1024x128xf32, #layout2> -> tensor<1024x128xf32, #layout3> tt.return %3 : tensor<1024x128xf32, #layout3> } // CHECK-LABEL: if tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1> %0 = tt.get_program_id x : i32 %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout1> @@ -155,7 +155,7 @@ tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout0> scf.if %4 { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout1> -> tensor<1024xi32, #layout0> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout1> -> tensor<1024xi32, #layout0> tt.store %5, %6 : tensor<1024x!tt.ptr, #layout0> } tt.return @@ -172,12 +172,12 @@ tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> %8 = scf.if %4 -> tensor<1024xi32, #layout1> { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %6 : tensor<1024xi32, #layout1> } else { scf.yield %9 : tensor<1024xi32, #layout1> } - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -195,10 +195,10 @@ tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %8 = scf.if %4 -> tensor<1024xi32, #layout1> { scf.yield %9 : tensor<1024xi32, #layout1> } else { - %7 = triton_gpu.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %7 : tensor<1024xi32, #layout1> } - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -213,15 +213,15 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = %4 = arith.cmpi sgt, %0, %arg0 : i32 %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> %8 = scf.if %4 -> tensor<1024xi32, #layout1> { - %6 = triton_gpu.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %6 : tensor<1024xi32, #layout1> } else { - %7 = triton_gpu.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> scf.yield %7 : tensor<1024xi32, #layout1> } // TODO(csigg): seems like the whole function is converted to layout1. - // disabledCHECK: triton_gpu.convert_layout - // CHECK-NOT: triton_gpu.convert_layout + // disabledCHECK: ttg.convert_layout + // CHECK-NOT: ttg.convert_layout tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> tt.return } @@ -230,27 +230,27 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked0a = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2a = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> - -// CHECK-DAG: [[$row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK-DAG: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> -// CHECK-DAG: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked0a = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2a = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked5 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +// CHECK-DAG: [[$row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK-DAG: [[$col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +// CHECK-DAG: [[$col_layout_novec:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK-LABEL: @transpose -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> - // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout]]> + // CHECK: [[cvt_val:%.*]] = ttg.convert_layout [[loaded_val]] : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout]]> // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64x!tt.ptr, [[$col_layout]]> // CHECK: tt.return %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> @@ -265,7 +265,7 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> @@ -273,34 +273,34 @@ tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %17 = triton_gpu.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %19 = triton_gpu.convert_layout %10 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %20 = triton_gpu.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %21 = triton_gpu.convert_layout %cst : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %19 = ttg.convert_layout %10 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %20 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %21 = ttg.convert_layout %cst : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %22 = tt.load %19, %20, %21 : tensor<64x64x!tt.ptr, #blocked3> - %23 = triton_gpu.convert_layout %22 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> - %24 = triton_gpu.convert_layout %18 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked4> - %25 = triton_gpu.convert_layout %23 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked4> - %26 = triton_gpu.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked4> + %23 = ttg.convert_layout %22 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %24 = ttg.convert_layout %18 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked4> + %25 = ttg.convert_layout %23 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked4> + %26 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked4> tt.store %24, %25, %26 : tensor<64x64x!tt.ptr, #blocked4> tt.return } } // CHECK-LABEL: loop -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]>) // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]> // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> // CHECK-NEXT: } - // CHECK-NOT: triton_gpu.convert_layout - // CHECK: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout_novec]]> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout + // CHECK: {{.*}} = ttg.convert_layout [[loop_ret]]#0 : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout_novec]]> + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> @@ -318,14 +318,14 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { - %23 = triton_gpu.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %25 = triton_gpu.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> - %27 = triton_gpu.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1> @@ -336,31 +336,31 @@ tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %18 = triton_gpu.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %20 = triton_gpu.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> - %21 = triton_gpu.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> - %22 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> tt.return } } // CHECK-LABEL: loop_if -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.for -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.if -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.yield // CHECK: else // CHECK: scf.yield -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.yield -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.convert_layout +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.store -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { %cst = arith.constant dense : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> @@ -379,16 +379,16 @@ tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i3 %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { %33 = arith.cmpi "sgt", %arg5, %c0 : index %34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) { - %23 = triton_gpu.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> - %25 = triton_gpu.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> - %27 = triton_gpu.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> scf.yield %27 : tensor<64x64xf32, #blocked1> } else { scf.yield %arg6 : tensor<64x64xf32, #blocked1> @@ -403,20 +403,20 @@ tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i3 %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> - %18 = triton_gpu.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> - %20 = triton_gpu.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> - %21 = triton_gpu.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> - %22 = triton_gpu.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> tt.return } } // CHECK-LABEL: vecadd -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -432,15 +432,15 @@ tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %11 = arith.addi %4, %5 : tensor<256xi32, #blocked5> %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> %13 = tt.load %12 : tensor<256x!tt.ptr, #blocked5> - %14 = triton_gpu.convert_layout %13 : tensor<256xf32, #blocked5> -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %14 = ttg.convert_layout %13 : tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> %16 = tt.load %15 : tensor<256x!tt.ptr, #blocked5> - %17 = triton_gpu.convert_layout %16 : tensor<256xf32, #blocked5> -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> - %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %17 = ttg.convert_layout %16 : tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %18 = arith.addf %14, %17 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> %19 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked5> %20 = arith.addi %2, %3 : tensor<256xi32, #blocked5> %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> - %22 = triton_gpu.convert_layout %18 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<256xf32, #blocked5> + %22 = ttg.convert_layout %18 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<256xf32, #blocked5> tt.store %21, %22 : tensor<256x!tt.ptr, #blocked5> tt.return } @@ -448,9 +448,9 @@ tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr // Select has args with different element types // CHECK-LABEL: select -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2> %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2> %c512 = arith.constant 512 : index @@ -460,15 +460,15 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2> %0 = tt.get_program_id x : i32 %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0> - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #blocked0> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1> - %4 = triton_gpu.convert_layout %3 : tensor<1x1xi32, #blocked1> -> tensor<1x1xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked0> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked1> -> tensor<1x1xi32, #blocked2> %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked2> %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked2> %7 = arith.cmpi "slt", %6, %cst_1 : tensor<1x1xi32, #blocked2> %8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0> - %9 = triton_gpu.convert_layout %8 : tensor<512xi32, #blocked0> -> tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x512xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<512xi32, #blocked0> -> tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x512xi32, #blocked2> %11 = arith.muli %6, %cst : tensor<1x1xi32, #blocked2> %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2> %13 = tt.splat %arg0 : !tt.ptr -> tensor<1x512x!tt.ptr, #blocked2> @@ -481,17 +481,17 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2> %21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> %22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2> - %23 = triton_gpu.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> - %24 = triton_gpu.convert_layout %22 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + %23 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %22 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> %25 = tt.load %23, %24 : tensor<1x512x!tt.ptr, #blocked3> - %26 = triton_gpu.convert_layout %25 : tensor<1x512xf64, #blocked3> -> tensor<1x512xf64, #blocked2> + %26 = ttg.convert_layout %25 : tensor<1x512xf64, #blocked3> -> tensor<1x512xf64, #blocked2> %27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2> %28 = arith.cmpf "olt", %arg4, %26 : tensor<1x512xf64, #blocked2> %29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2> %30 = arith.select %29, %26, %arg4 : tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2> - %31 = triton_gpu.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> - %32 = triton_gpu.convert_layout %30 : tensor<1x512xf64, #blocked2> -> tensor<1x512xf64, #blocked3> - %33 = triton_gpu.convert_layout %27 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + %31 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %32 = ttg.convert_layout %30 : tensor<1x512xf64, #blocked2> -> tensor<1x512xf64, #blocked3> + %33 = ttg.convert_layout %27 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> tt.store %31, %32, %33 : tensor<1x512x!tt.ptr, #blocked3> scf.yield %30 : tensor<1x512xf64, #blocked2> } @@ -501,7 +501,7 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr // Make sure the following IR doesn't hang the compiler. // CHECK-LABEL: long_func -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0> %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0> @@ -529,22 +529,22 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %5 = arith.cmpi "slt", %4, %cst_11 : tensor<1024xi32, #blocked0> %6 = tt.splat %arg5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %8 = triton_gpu.convert_layout %7 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %9 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %8 = ttg.convert_layout %7 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %9 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> %10 = tt.load %8, %9 : tensor<1024x!tt.ptr, #blocked0a> - %11 = triton_gpu.convert_layout %10 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> + %11 = ttg.convert_layout %10 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> %12 = tt.splat %arg7 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %14 = triton_gpu.convert_layout %13 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked2a> - %15 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked2a> + %14 = ttg.convert_layout %13 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked2a> + %15 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked2a> %16 = tt.load %14, %15 : tensor<1024x!tt.ptr, #blocked2a> - %17 = triton_gpu.convert_layout %16 : tensor<1024xi64, #blocked2a> -> tensor<1024xi64, #blocked0> + %17 = ttg.convert_layout %16 : tensor<1024xi64, #blocked2a> -> tensor<1024xi64, #blocked0> %18 = tt.splat %arg8 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %19 = tt.addptr %18, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %20 = triton_gpu.convert_layout %19 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %21 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %20 = ttg.convert_layout %19 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %21 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> %22 = tt.load %20, %21 : tensor<1024x!tt.ptr, #blocked0a> - %23 = triton_gpu.convert_layout %22 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> + %23 = ttg.convert_layout %22 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> %24 = arith.subf %cst_13, %11 : tensor<1024xf32, #blocked0> %25 = math.exp %24 : tensor<1024xf32, #blocked0> %26 = arith.sitofp %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xf32, #blocked0> @@ -575,7 +575,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %51 = arith.select %50, %49, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %52 = tt.splat %arg6 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %53 = tt.addptr %52, %51 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %54 = triton_gpu.convert_layout %53 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %54 = ttg.convert_layout %53 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %55 = tt.load %54 : tensor<1024x!tt.ptr, #blocked0> %56 = arith.cmpf "oge", %55, %35 :tensor<1024xf32, #blocked0> %57 = arith.cmpi "eq", %56, %cst_5 : tensor<1024xi1, #blocked0> @@ -597,7 +597,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %73 = arith.addi %60, %72 : tensor<1024xi32, #blocked0> %74 = arith.select %63, %73, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %75 = tt.addptr %52, %74 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %76 = triton_gpu.convert_layout %75 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %76 = ttg.convert_layout %75 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %77 = tt.load %76 : tensor<1024x!tt.ptr, #blocked0> %78 = arith.cmpf "oge", %77, %35 :tensor<1024xf32, #blocked0> %79 = arith.cmpi "eq", %78, %cst_5 : tensor<1024xi1, #blocked0> @@ -619,7 +619,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %95 = arith.addi %82, %94 : tensor<1024xi32, #blocked0> %96 = arith.select %85, %95, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %97 = tt.addptr %52, %96 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %98 = triton_gpu.convert_layout %97 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %98 = ttg.convert_layout %97 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %99 = tt.load %98 : tensor<1024x!tt.ptr, #blocked0> %100 = arith.cmpf "oge", %99, %35 : tensor<1024xf32, #blocked0> %101 = arith.cmpi "eq", %100, %cst_5 : tensor<1024xi1, #blocked0> @@ -641,7 +641,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %117 = arith.addi %104, %116 : tensor<1024xi32, #blocked0> %118 = arith.select %107, %117, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %119 = tt.addptr %52, %118 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %120 = triton_gpu.convert_layout %119 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %120 = ttg.convert_layout %119 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %121 = tt.load %120 : tensor<1024x!tt.ptr, #blocked0> %122 = arith.cmpf "oge", %121, %35 : tensor<1024xf32, #blocked0> %123 = arith.cmpi "eq", %122, %cst_5 : tensor<1024xi1, #blocked0> @@ -663,7 +663,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %139 = arith.addi %126, %138 : tensor<1024xi32, #blocked0> %140 = arith.select %129, %139, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %141 = tt.addptr %52, %140 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %142 = triton_gpu.convert_layout %141 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %142 = ttg.convert_layout %141 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %143 = tt.load %142 : tensor<1024x!tt.ptr, #blocked0> %144 = arith.cmpf "oge", %143, %35 : tensor<1024xf32, #blocked0> %145 = arith.cmpi "eq", %144, %cst_5 : tensor<1024xi1, #blocked0> @@ -685,7 +685,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %161 = arith.addi %148, %160 : tensor<1024xi32, #blocked0> %162 = arith.select %151, %161, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %163 = tt.addptr %52, %162 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %164 = triton_gpu.convert_layout %163 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %164 = ttg.convert_layout %163 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %165 = tt.load %164 : tensor<1024x!tt.ptr, #blocked0> %166 = arith.cmpf "oge", %165, %35 : tensor<1024xf32, #blocked0> %167 = arith.cmpi "eq", %166, %cst_5 : tensor<1024xi1, #blocked0> @@ -707,7 +707,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %183 = arith.addi %170, %182 : tensor<1024xi32, #blocked0> %184 = arith.select %173, %183, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %185 = tt.addptr %52, %184 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %186 = triton_gpu.convert_layout %185 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %186 = ttg.convert_layout %185 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %187 = tt.load %186 : tensor<1024x!tt.ptr, #blocked0> %188 = arith.cmpf "oge", %187, %35 : tensor<1024xf32, #blocked0> %189 = arith.cmpi "eq", %188, %cst_5 : tensor<1024xi1, #blocked0> @@ -729,7 +729,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %205 = arith.addi %192, %204 : tensor<1024xi32, #blocked0> %206 = arith.select %195, %205, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %207 = tt.addptr %52, %206 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %208 = triton_gpu.convert_layout %207 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %208 = ttg.convert_layout %207 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %209 = tt.load %208 : tensor<1024x!tt.ptr, #blocked0> %210 = arith.cmpf "oge", %209, %35 :tensor<1024xf32, #blocked0> %211 = arith.cmpi "eq", %210, %cst_5 : tensor<1024xi1, #blocked0> @@ -751,7 +751,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %227 = arith.addi %214, %226 : tensor<1024xi32, #blocked0> %228 = arith.select %217, %227, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %229 = tt.addptr %52, %228 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %230 = triton_gpu.convert_layout %229 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %230 = ttg.convert_layout %229 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %231 = tt.load %230 : tensor<1024x!tt.ptr, #blocked0> %232 = arith.cmpf "oge", %231, %35 : tensor<1024xf32, #blocked0> %233 = arith.cmpi "eq", %232, %cst_5 : tensor<1024xi1, #blocked0> @@ -773,7 +773,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %249 = arith.addi %236, %248 : tensor<1024xi32, #blocked0> %250 = arith.select %239, %249, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %251 = tt.addptr %52, %250 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %252 = triton_gpu.convert_layout %251 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %252 = ttg.convert_layout %251 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %253 = tt.load %252 : tensor<1024x!tt.ptr, #blocked0> %254 = arith.cmpf "oge", %253, %35 : tensor<1024xf32, #blocked0> %255 = arith.cmpi "eq", %254, %cst_5 : tensor<1024xi1, #blocked0> @@ -795,7 +795,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %271 = arith.addi %258, %270 : tensor<1024xi32, #blocked0> %272 = arith.select %261, %271, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %273 = tt.addptr %52, %272 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %274 = triton_gpu.convert_layout %273 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %274 = ttg.convert_layout %273 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %275 = tt.load %274 : tensor<1024x!tt.ptr, #blocked0> %276 = arith.cmpf "oge", %275, %35 : tensor<1024xf32, #blocked0> %277 = arith.cmpi "eq", %276, %cst_5 : tensor<1024xi1, #blocked0> @@ -817,7 +817,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %293 = arith.addi %280, %292 : tensor<1024xi32, #blocked0> %294 = arith.select %283, %293, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %295 = tt.addptr %52, %294 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %296 = triton_gpu.convert_layout %295 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %296 = ttg.convert_layout %295 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %297 = tt.load %296 : tensor<1024x!tt.ptr, #blocked0> %298 = arith.cmpf "oge", %297, %35 :tensor<1024xf32, #blocked0> %299 = arith.cmpi "eq", %298, %cst_5 : tensor<1024xi1, #blocked0> @@ -842,13 +842,13 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %318 = arith.extsi %317 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> %319 = tt.splat %arg9 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %320 = tt.addptr %319, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %321 = triton_gpu.convert_layout %320 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %321 = ttg.convert_layout %320 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %322 = tt.load %321 : tensor<1024x!tt.ptr, #blocked0> %323 = arith.extf %cst_2 : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> %324 = arith.cmpf "ogt", %322, %323 : tensor<1024xf64, #blocked0> %325 = tt.splat %arg10 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %326 = tt.addptr %325, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %327 = triton_gpu.convert_layout %326 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %327 = ttg.convert_layout %326 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %328 = tt.load %327 : tensor<1024x!tt.ptr, #blocked0> %329 = arith.divf %328, %322 : tensor<1024xf64, #blocked0> %330 = arith.truncf %329 : tensor<1024xf64, #blocked0> to tensor<1024xf32, #blocked0> @@ -857,41 +857,41 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %333 = arith.addf %331, %332 : tensor<1024xf32, #blocked0> %334 = arith.select %324, %333, %35 : tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0> %335 = tt.addptr %319, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %336 = triton_gpu.convert_layout %335 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %336 = ttg.convert_layout %335 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %337 = tt.load %336 : tensor<1024x!tt.ptr, #blocked0> %338 = arith.extf %cst : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> %339 = arith.mulf %337, %338 : tensor<1024xf64, #blocked0> %340 = tt.addptr %325, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %341 = triton_gpu.convert_layout %340 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %341 = ttg.convert_layout %340 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> %342 = tt.load %341 : tensor<1024x!tt.ptr, #blocked0> %343 = arith.mulf %342, %338 : tensor<1024xf64, #blocked0> %344 = tt.splat %arg11 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %345 = tt.addptr %344, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %346 = triton_gpu.convert_layout %345 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %347 = triton_gpu.convert_layout %28 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> - %348 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %346 = ttg.convert_layout %345 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %347 = ttg.convert_layout %28 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %348 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %346, %347, %348 : tensor<1024x!tt.ptr, #blocked0a> %349 = tt.splat %arg12 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %350 = tt.addptr %349, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %351 = triton_gpu.convert_layout %350 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %352 = triton_gpu.convert_layout %317 : tensor<1024xi32, #blocked0> -> tensor<1024xi32, #blocked0a> - %353 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %351 = ttg.convert_layout %350 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %352 = ttg.convert_layout %317 : tensor<1024xi32, #blocked0> -> tensor<1024xi32, #blocked0a> + %353 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %351, %352, %353 : tensor<1024x!tt.ptr, #blocked0a> %354 = tt.splat %arg13 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %355 = tt.addptr %354, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> - %356 = triton_gpu.convert_layout %355 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> - %357 = triton_gpu.convert_layout %334 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> - %358 = triton_gpu.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %356 = ttg.convert_layout %355 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %357 = ttg.convert_layout %334 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %358 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> tt.store %356, %357, %358 : tensor<1024x!tt.ptr, #blocked0a> %359 = tt.splat %arg14 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %360 = tt.addptr %359, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %361 = triton_gpu.convert_layout %360 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> - %362 = triton_gpu.convert_layout %339 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + %361 = ttg.convert_layout %360 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %362 = ttg.convert_layout %339 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> tt.store %361, %362 : tensor<1024x!tt.ptr, #blocked0> %363 = tt.splat %arg15 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> %364 = tt.addptr %363, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> - %365 = triton_gpu.convert_layout %364 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> - %366 = triton_gpu.convert_layout %343 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + %365 = ttg.convert_layout %364 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %366 = ttg.convert_layout %343 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> tt.store %365, %366 : tensor<1024x!tt.ptr, #blocked0> tt.return } @@ -900,9 +900,9 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg // A mnist model from torch inductor. // Check if topological sort is working correct and there's no unnecessary convert // CHECK-LABEL: mnist -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2> %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3> %c16_i32 = arith.constant 16 : i32 @@ -913,30 +913,30 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c16_i32 : i32 %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0> - %3 = triton_gpu.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<16x1xi32, #blocked1> -> tensor<16x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<16x1xi32, #blocked1> -> tensor<16x1xi32, #blocked2> %6 = tt.splat %1 : i32 -> tensor<16x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<16x1xi32, #blocked2> %8 = arith.cmpi "slt", %7, %cst_1 : tensor<16x1xi32, #blocked2> - %9 = triton_gpu.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %9 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %11 = arith.cmpi "slt", %10, %cst_0 : tensor<1x16xi32, #blocked3> %12 = arith.muli %7, %cst : tensor<16x1xi32, #blocked2> %13 = tt.broadcast %10 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %14 = triton_gpu.convert_layout %13 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked2> + %14 = ttg.convert_layout %13 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked2> %15 = tt.broadcast %12 : tensor<16x1xi32, #blocked2> -> tensor<16x16xi32, #blocked2> %16 = arith.addi %14, %15 : tensor<16x16xi32, #blocked2> %17 = tt.splat %arg0 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> %18 = tt.addptr %17, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> %19 = tt.broadcast %11 : tensor<1x16xi1, #blocked3> -> tensor<16x16xi1, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<16x16xi1, #blocked3> -> tensor<16x16xi1, #blocked2> + %20 = ttg.convert_layout %19 : tensor<16x16xi1, #blocked3> -> tensor<16x16xi1, #blocked2> %21 = tt.broadcast %8 : tensor<16x1xi1, #blocked2> -> tensor<16x16xi1, #blocked2> %22 = arith.andi %20, %21 : tensor<16x16xi1, #blocked2> - %23 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %24 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %23 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %24 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %25 = tt.load %23, %24 : tensor<16x16x!tt.ptr, #blocked4> - %26 = triton_gpu.convert_layout %25 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %26 = ttg.convert_layout %25 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %27 = arith.cmpf "olt", %cst_2, %26 : tensor<16x16xf32, #blocked2> %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2> %29 = arith.select %28, %26, %cst_2 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> @@ -944,17 +944,17 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! ^bb0(%arg4: f32, %arg5: f32): %max = arith.maximumf %arg4, %arg5 : f32 tt.reduce.return %max : f32 - }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %31 = triton_gpu.convert_layout %30 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> - %32 = triton_gpu.convert_layout %31 : tensor<16xf32, #blocked0> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> - %34 = triton_gpu.convert_layout %33 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %31 = ttg.convert_layout %30 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %32 = ttg.convert_layout %31 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %34 = ttg.convert_layout %33 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> %35 = arith.sitofp %cst_4 : tensor<16x16xi32, #blocked2> to tensor<16x16xf32, #blocked2> %36 = arith.addf %35, %cst_3 : tensor<16x16xf32, #blocked2> - %37 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %38 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %37 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %38 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %39 = tt.load %37, %38 : tensor<16x16x!tt.ptr, #blocked4> - %40 = triton_gpu.convert_layout %39 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %40 = ttg.convert_layout %39 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %41 = tt.broadcast %34 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> %42 = arith.subf %40, %41 : tensor<16x16xf32, #blocked2> %43 = math.exp %42 : tensor<16x16xf32, #blocked2> @@ -964,24 +964,24 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! ^bb0(%arg4: f32, %arg5: f32): %add = arith.addf %arg4, %arg5 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %47 = triton_gpu.convert_layout %46 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> - %48 = triton_gpu.convert_layout %47 : tensor<16xf32, #blocked0> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> - %50 = triton_gpu.convert_layout %49 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> - %51 = triton_gpu.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %52 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %47 = ttg.convert_layout %46 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %48 = ttg.convert_layout %47 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %50 = ttg.convert_layout %49 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + %51 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %52 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> %53 = tt.load %51, %52 : tensor<16x16x!tt.ptr, #blocked4> - %54 = triton_gpu.convert_layout %53 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %54 = ttg.convert_layout %53 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> %55 = arith.subf %54, %41 : tensor<16x16xf32, #blocked2> %56 = math.log %50 : tensor<16x1xf32, #blocked2> %57 = tt.broadcast %56 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> %58 = arith.subf %55, %57 : tensor<16x16xf32, #blocked2> %59 = tt.splat %arg1 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> %60 = tt.addptr %59, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> - %61 = triton_gpu.convert_layout %60 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> - %62 = triton_gpu.convert_layout %58 : tensor<16x16xf32, #blocked2> -> tensor<16x16xf32, #blocked4> - %63 = triton_gpu.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %61 = ttg.convert_layout %60 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %62 = ttg.convert_layout %58 : tensor<16x16xf32, #blocked2> -> tensor<16x16xf32, #blocked4> + %63 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> tt.store %61, %62, %63 : tensor<16x16x!tt.ptr, #blocked4> tt.return } @@ -989,15 +989,15 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! // ----- -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> // cmpf and cmpi have different operands and result types // CHECK-LABEL: cmp -module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { %c64 = arith.constant 64 : index %c2048 = arith.constant 2048 : index @@ -1014,14 +1014,14 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c64_i32 : i32 %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> - %3 = triton_gpu.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<64x1xi32, #blocked1> -> tensor<64x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<64x1xi32, #blocked1> -> tensor<64x1xi32, #blocked2> %6 = tt.splat %1 : i32 -> tensor<64x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<64x1xi32, #blocked2> %8 = arith.cmpi "slt", %7, %cst_5 : tensor<64x1xi32, #blocked2> - %9 = triton_gpu.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x64xi32, #blocked3> + %9 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x64xi32, #blocked3> %11 = arith.remsi %7, %cst_4 : tensor<64x1xi32, #blocked2> %12 = arith.divsi %7, %cst_4 : tensor<64x1xi32, #blocked2> %13 = arith.sitofp %cst_3 : tensor<64x64xi32, #blocked2> to tensor<64x64xf32, #blocked2> @@ -1042,24 +1042,24 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> - %49 = triton_gpu.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %16 : tensor<64x64xi32, #blocked2> %51 = tt.addptr %17, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> - %53 = triton_gpu.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> %54 = arith.andi %53, %18 : tensor<64x64xi1, #blocked2> - %55 = triton_gpu.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %56 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> %60 = arith.addi %49, %20 : tensor<64x64xi32, #blocked2> %61 = arith.addi %60, %23 : tensor<64x64xi32, #blocked2> %62 = tt.addptr %24, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %63 = triton_gpu.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %64 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> - %66 = triton_gpu.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> @@ -1074,11 +1074,11 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt ^bb0(%arg8: f32, %arg9: f32): %add = arith.addf %arg8, %arg9 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %27 = triton_gpu.convert_layout %26 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64xf32, #blocked0> - %28 = triton_gpu.convert_layout %27 : tensor<64xf32, #blocked0> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %29 = tt.expand_dims %28 {axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> - %30 = triton_gpu.convert_layout %29 : tensor<64x1xf32, #blocked1> -> tensor<64x1xf32, #blocked2> + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %27 = ttg.convert_layout %26 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64xf32, #blocked0> + %28 = ttg.convert_layout %27 : tensor<64xf32, #blocked0> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %29 = tt.expand_dims %28 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> + %30 = ttg.convert_layout %29 : tensor<64x1xf32, #blocked1> -> tensor<64x1xf32, #blocked2> %31 = arith.muli %7, %cst_4 : tensor<64x1xi32, #blocked2> %32 = tt.broadcast %31 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> %33 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> @@ -1098,24 +1098,24 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> - %49 = triton_gpu.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %32 : tensor<64x64xi32, #blocked2> %51 = tt.addptr %33, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> - %53 = triton_gpu.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> %54 = arith.andi %53, %34 : tensor<64x64xi1, #blocked2> - %55 = triton_gpu.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %56 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> %60 = arith.addi %49, %36 : tensor<64x64xi32, #blocked2> %61 = arith.addi %60, %39 : tensor<64x64xi32, #blocked2> %62 = tt.addptr %40, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %63 = triton_gpu.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %64 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> - %66 = triton_gpu.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> @@ -1124,15 +1124,15 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %72 = math.exp %71 : tensor<64x64xf32, #blocked2> %73 = arith.divf %72, %41 : tensor<64x64xf32, #blocked2> %74 = tt.addptr %42, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %75 = triton_gpu.convert_layout %74 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> - %76 = triton_gpu.convert_layout %73 : tensor<64x64xf32, #blocked2> -> tensor<64x64xf32, #blocked5> - %77 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %75 = ttg.convert_layout %74 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %76 = ttg.convert_layout %73 : tensor<64x64xf32, #blocked2> -> tensor<64x64xf32, #blocked5> + %77 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> tt.store %75, %76, %77 : tensor<64x64x!tt.ptr, #blocked5> %78 = tt.addptr %43, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> %79 = arith.truncf %73 : tensor<64x64xf32, #blocked2> to tensor<64x64xf16, #blocked2> - %80 = triton_gpu.convert_layout %78 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> - %81 = triton_gpu.convert_layout %79 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #blocked4> - %82 = triton_gpu.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %80 = ttg.convert_layout %78 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %81 = ttg.convert_layout %79 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #blocked4> + %82 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> tt.store %80, %81, %82 : tensor<64x64x!tt.ptr, #blocked4> } tt.return @@ -1143,9 +1143,9 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt // Just make sure it doesn't crash on non-tensor types. // CHECK-LABEL: if_no_tensor -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c-1_i64 = arith.constant -1 : i64 %cst = arith.constant 0.000000e+00 : f32 %c-1_i32 = arith.constant -1 : i32 @@ -1173,35 +1173,35 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % // Check if the SimplifyReduceCvt rewriter pattern doesn't hang. // CHECK-LABEL: reduce_cvt -// CHECK-NOT: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) { %cst = arith.constant dense<0> : tensor<1x2xi32, #blocked> %cst_0 = arith.constant dense<2> : tensor<1x2xi32, #blocked> %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1> - %1 = triton_gpu.convert_layout %0 : tensor<2xi32, #blocked1> -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2xi32, #blocked> + %1 = ttg.convert_layout %0 : tensor<2xi32, #blocked1> -> tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2xi32, #blocked> %3 = arith.cmpi "slt", %2, %cst_0 : tensor<1x2xi32, #blocked> %4 = "tt.reduce" (%cst) ({ ^bb0(%arg3: i32, %arg4: i32): %add = arith.addi %arg3, %arg4 : i32 tt.reduce.return %add : i32 - }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = triton_gpu.convert_layout %4 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> - %6 = triton_gpu.convert_layout %5 : tensor<1xi32, #blocked1> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> - %8 = triton_gpu.convert_layout %7 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = ttg.convert_layout %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %6 = ttg.convert_layout %5 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %8 = ttg.convert_layout %7 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> %9 = tt.splat %arg0 : !tt.ptr -> tensor<1x2x!tt.ptr, #blocked> %10 = tt.addptr %9, %2 : tensor<1x2x!tt.ptr, #blocked>, tensor<1x2xi32, #blocked> %11 = tt.broadcast %8 : tensor<1x1xi32, #blocked> -> tensor<1x2xi32, #blocked> %12 = arith.extsi %11 : tensor<1x2xi32, #blocked> to tensor<1x2xi64, #blocked> - %13 = triton_gpu.convert_layout %10 : tensor<1x2x!tt.ptr, #blocked> -> tensor<1x2x!tt.ptr, #blocked3> - %14 = triton_gpu.convert_layout %12 : tensor<1x2xi64, #blocked> -> tensor<1x2xi64, #blocked3> - %15 = triton_gpu.convert_layout %3 : tensor<1x2xi1, #blocked> -> tensor<1x2xi1, #blocked3> + %13 = ttg.convert_layout %10 : tensor<1x2x!tt.ptr, #blocked> -> tensor<1x2x!tt.ptr, #blocked3> + %14 = ttg.convert_layout %12 : tensor<1x2xi64, #blocked> -> tensor<1x2xi64, #blocked3> + %15 = ttg.convert_layout %3 : tensor<1x2xi1, #blocked> -> tensor<1x2xi1, #blocked3> tt.store %13, %14, %15 : tensor<1x2x!tt.ptr, #blocked3> tt.return } @@ -1211,19 +1211,19 @@ module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: reduce_cvt2 // Match the reduction -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.reduce // CHECK-SAME: axis = 1 -// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #{{.*}}}>> -// CHECK: triton_gpu.convert_layout +// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #{{.*}}}>> +// CHECK: ttg.convert_layout // CHECK: tt.expand_dims -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked> %c3136_i32 = arith.constant 3136 : index @@ -1237,15 +1237,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked> %0 = tt.get_program_id x : i32 %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<1xi32, #blocked1> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> - %4 = triton_gpu.convert_layout %3 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked> %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked> %7 = arith.cmpi "slt", %6, %cst_5 : tensor<1x1xi32, #blocked> %8 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %9 = triton_gpu.convert_layout %8 : tensor<256xi32, #blocked1> -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %9 = ttg.convert_layout %8 : tensor<256xi32, #blocked1> -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> %11 = arith.muli %6, %cst_2 : tensor<1x1xi32, #blocked> %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked> %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked> @@ -1262,11 +1262,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %50 = arith.addi %48, %49 : tensor<1x256xi32, #blocked> %51 = tt.addptr %13, %50 : tensor<1x256x!tt.ptr, #blocked>, tensor<1x256xi32, #blocked> %52 = arith.andi %45, %14 : tensor<1x256xi1, #blocked> - %53 = triton_gpu.convert_layout %51 : tensor<1x256x!tt.ptr, #blocked> -> tensor<1x256x!tt.ptr, #blocked3> - %54 = triton_gpu.convert_layout %52 : tensor<1x256xi1, #blocked> -> tensor<1x256xi1, #blocked3> - %55 = triton_gpu.convert_layout %cst : tensor<1x256xf32, #blocked> -> tensor<1x256xf32, #blocked3> + %53 = ttg.convert_layout %51 : tensor<1x256x!tt.ptr, #blocked> -> tensor<1x256x!tt.ptr, #blocked3> + %54 = ttg.convert_layout %52 : tensor<1x256xi1, #blocked> -> tensor<1x256xi1, #blocked3> + %55 = ttg.convert_layout %cst : tensor<1x256xf32, #blocked> -> tensor<1x256xf32, #blocked3> %56 = tt.load %53, %54, %55 : tensor<1x256x!tt.ptr, #blocked3> - %57 = triton_gpu.convert_layout %56 : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #blocked> + %57 = ttg.convert_layout %56 : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #blocked> %58 = arith.addf %arg6, %57 : tensor<1x256xf32, #blocked> %59 = arith.select %52, %58, %arg6 : tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked> scf.yield %59 : tensor<1x256xf32, #blocked> @@ -1276,17 +1276,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %add = arith.addf %arg7, %arg8 : f32 tt.reduce.return %add : f32 - }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = triton_gpu.convert_layout %16 : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> - %18 = triton_gpu.convert_layout %17 : tensor<1xf32, #blocked1> -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %19 = tt.expand_dims %18 {axis = 1 : i32} : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xf32, #blocked2> - %20 = triton_gpu.convert_layout %19 : tensor<1x1xf32, #blocked2> -> tensor<1x1xf32, #blocked> + }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.convert_layout %16 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.expand_dims %18 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xf32, #blocked2> + %20 = ttg.convert_layout %19 : tensor<1x1xf32, #blocked2> -> tensor<1x1xf32, #blocked> %21 = arith.divf %20, %cst_0 : tensor<1x1xf32, #blocked> %22 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> %23 = tt.addptr %22, %6 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> - %24 = triton_gpu.convert_layout %23 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> - %25 = triton_gpu.convert_layout %21 : tensor<1x1xf32, #blocked> -> tensor<1x1xf32, #blocked> - %26 = triton_gpu.convert_layout %7 : tensor<1x1xi1, #blocked> -> tensor<1x1xi1, #blocked> + %24 = ttg.convert_layout %23 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> + %25 = ttg.convert_layout %21 : tensor<1x1xf32, #blocked> -> tensor<1x1xf32, #blocked> + %26 = ttg.convert_layout %7 : tensor<1x1xi1, #blocked> -> tensor<1x1xi1, #blocked> tt.store %24, %25, %26 : tensor<1x1x!tt.ptr, #blocked> tt.return } @@ -1296,12 +1296,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // Ensure that RematerializeForward doesn't apply when a convert has multiple uses // CHECK-LABEL: loop_convert_multi_uses -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @loop_convert_multi_uses(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<16xf32, #blocked> %c1_i32 = arith.constant 1 : i32 @@ -1322,16 +1322,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %8 = arith.muli %2, %arg3 : i32 %9 = arith.muli %3, %arg4 : i32 %10 = arith.addi %8, %9 : i32 - %11 = triton_gpu.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> - %13 = triton_gpu.convert_layout %12 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %11 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %13 = ttg.convert_layout %12 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> %14 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> %15 = arith.muli %13, %14 : tensor<16x1xi32, #blocked1> - %16 = triton_gpu.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %16 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %18 = tt.broadcast %15 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> %19 = tt.broadcast %17 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> %21 = arith.addi %18, %20 : tensor<16x16xi32, #blocked1> %22 = tt.splat %arg2 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked1> %23 = arith.cmpi "slt", %13, %cst_3 : tensor<16x1xi32, #blocked1> @@ -1352,26 +1352,26 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %62 = tt.splat %61 : i32 -> tensor<16x16xi32, #blocked1> %63 = arith.addi %62, %21 : tensor<16x16xi32, #blocked1> %64 = tt.addptr %22, %63 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %65 = triton_gpu.convert_layout %64 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> - %66 = triton_gpu.convert_layout %24 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> - %67 = triton_gpu.convert_layout %25 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %65 = ttg.convert_layout %64 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %66 = ttg.convert_layout %24 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + %67 = ttg.convert_layout %25 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> %68 = tt.load %65, %66, %67 : tensor<16x16x!tt.ptr, #blocked4> - %69 = triton_gpu.convert_layout %68 : tensor<16x16xf16, #blocked4> -> tensor<16x16xf16, #blocked1> + %69 = ttg.convert_layout %68 : tensor<16x16xf16, #blocked4> -> tensor<16x16xf16, #blocked1> %70 = arith.addi %28, %arg17 : i32 %71 = tt.splat %70 : i32 -> tensor<16xi32, #blocked> %72 = arith.addi %71, %7 : tensor<16xi32, #blocked> %73 = tt.addptr %29, %72 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - %74 = triton_gpu.convert_layout %73 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> - %75 = triton_gpu.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> - %76 = triton_gpu.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %74 = ttg.convert_layout %73 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %75 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %76 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> %77 = tt.load %74, %75, %76 : tensor<16x!tt.ptr, #blocked> %78 = arith.addi %33, %arg17 : i32 %79 = tt.splat %78 : i32 -> tensor<16xi32, #blocked> %80 = arith.addi %79, %7 : tensor<16xi32, #blocked> %81 = tt.addptr %34, %80 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - %82 = triton_gpu.convert_layout %81 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> - %83 = triton_gpu.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> - %84 = triton_gpu.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %82 = ttg.convert_layout %81 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %83 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %84 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> %85 = tt.load %82, %83, %84 : tensor<16x!tt.ptr, #blocked> %86 = arith.cmpf "ogt", %arg20, %85 : tensor<16xf32, #blocked> %87 = arith.select %86, %arg20, %85 : tensor<16xi1, #blocked>, tensor<16xf32, #blocked> @@ -1385,14 +1385,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %95 = arith.divf %91, %94 : tensor<16xf32, #blocked> %96 = arith.divf %arg19, %94 : tensor<16xf32, #blocked> %97 = arith.mulf %96, %89 : tensor<16xf32, #blocked> - %98 = triton_gpu.convert_layout %97 : tensor<16xf32, #blocked> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> - %100 = triton_gpu.convert_layout %99 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %98 = ttg.convert_layout %97 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %100 = ttg.convert_layout %99 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> %101 = tt.broadcast %100 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> %102 = arith.mulf %arg18, %101 : tensor<16x16xf32, #blocked1> - %103 = triton_gpu.convert_layout %95 : tensor<16xf32, #blocked> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> - %105 = triton_gpu.convert_layout %104 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %103 = ttg.convert_layout %95 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %105 = ttg.convert_layout %104 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> %106 = tt.broadcast %105 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> %107 = arith.extf %69 : tensor<16x16xf16, #blocked1> to tensor<16x16xf32, #blocked1> %108 = arith.mulf %107, %106 : tensor<16x16xf32, #blocked1> @@ -1402,16 +1402,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %36 = arith.muli %2, %arg14 : i32 %37 = arith.muli %3, %arg15 : i32 %38 = arith.addi %36, %37 : i32 - %39 = triton_gpu.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> - %41 = triton_gpu.convert_layout %40 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %39 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %41 = ttg.convert_layout %40 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> %42 = tt.splat %arg16 : i32 -> tensor<16x1xi32, #blocked1> %43 = arith.muli %41, %42 : tensor<16x1xi32, #blocked1> - %44 = triton_gpu.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %45 = tt.expand_dims %44 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %44 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %45 = tt.expand_dims %44 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> %46 = tt.broadcast %43 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> %47 = tt.broadcast %45 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> - %48 = triton_gpu.convert_layout %47 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %48 = ttg.convert_layout %47 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> %49 = arith.addi %46, %48 : tensor<16x16xi32, #blocked1> %50 = tt.splat %38 : i32 -> tensor<16x16xi32, #blocked1> %51 = arith.addi %50, %49 : tensor<16x16xi32, #blocked1> @@ -1420,9 +1420,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %54 = arith.cmpi "slt", %41, %cst_3 : tensor<16x1xi32, #blocked1> %55 = tt.broadcast %54 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> %56 = arith.truncf %35#0 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> - %57 = triton_gpu.convert_layout %53 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> - %58 = triton_gpu.convert_layout %56 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> - %59 = triton_gpu.convert_layout %55 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + %57 = ttg.convert_layout %53 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %58 = ttg.convert_layout %56 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %59 = ttg.convert_layout %55 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> tt.store %57, %58, %59 : tensor<16x16x!tt.ptr, #blocked4> tt.return } @@ -1432,15 +1432,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // Check if MoveConvertOutOfLoop hangs because of adding additional conversions // CHECK-LABEL: @loop_print -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @loop_print(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c32_i32 = arith.constant 32 : i32 %c31_i32 = arith.constant 31 : i32 @@ -1450,25 +1450,25 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %cst_0 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> %cst_1 = arith.constant 0.000000e+00 : f32 %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2> - %1 = triton_gpu.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %3 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> %4 = arith.muli %2, %3 : tensor<128x1xi32, #blocked1> %5 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked2> - %6 = triton_gpu.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %6 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> %8 = tt.broadcast %4 : tensor<128x1xi32, #blocked1> -> tensor<128x32xi32, #blocked1> %9 = tt.broadcast %7 : tensor<1x32xi32, #blocked3> -> tensor<128x32xi32, #blocked3> - %10 = triton_gpu.convert_layout %9 : tensor<128x32xi32, #blocked3> -> tensor<128x32xi32, #blocked1> + %10 = ttg.convert_layout %9 : tensor<128x32xi32, #blocked3> -> tensor<128x32xi32, #blocked1> %11 = arith.addi %8, %10 : tensor<128x32xi32, #blocked1> - %12 = triton_gpu.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> - %14 = triton_gpu.convert_layout %13 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked> - %15 = triton_gpu.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %12 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %14 = ttg.convert_layout %13 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked> + %15 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> %17 = tt.broadcast %14 : tensor<32x1xi32, #blocked> -> tensor<32x128xi32, #blocked> %18 = tt.broadcast %16 : tensor<1x128xi32, #blocked3> -> tensor<32x128xi32, #blocked3> - %19 = triton_gpu.convert_layout %18 : tensor<32x128xi32, #blocked3> -> tensor<32x128xi32, #blocked> + %19 = ttg.convert_layout %18 : tensor<32x128xi32, #blocked3> -> tensor<32x128xi32, #blocked> %20 = arith.addi %17, %19 : tensor<32x128xi32, #blocked> %21 = arith.addi %arg5, %c31_i32 : i32 %22 = arith.divsi %21, %c32_i32 : i32 @@ -1477,19 +1477,19 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %25:3 = scf.for %arg7 = %c0_i32 to %22 step %c1_i32 iter_args(%arg8 = %cst_1, %arg9 = %11, %arg10 = %20) -> (f32, tensor<128x32xi32, #blocked1>, tensor<32x128xi32, #blocked>) : i32 { tt.print "a_offsets: " { hex = false, isSigned = array } : %arg9 : tensor<128x32xi32, #blocked1> %27 = tt.addptr %23, %arg9 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %28 = triton_gpu.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> + %28 = ttg.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> %29 = tt.load %28 : tensor<128x32x!tt.ptr, #blocked4> - %30 = triton_gpu.convert_layout %29 : tensor<128x32xf16, #blocked4> -> tensor<128x32xf16, #blocked1> + %30 = ttg.convert_layout %29 : tensor<128x32xf16, #blocked4> -> tensor<128x32xf16, #blocked1> %31 = tt.addptr %24, %arg10 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %32 = triton_gpu.convert_layout %31 : tensor<32x128x!tt.ptr, #blocked> -> tensor<32x128x!tt.ptr, #blocked5> + %32 = ttg.convert_layout %31 : tensor<32x128x!tt.ptr, #blocked> -> tensor<32x128x!tt.ptr, #blocked5> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked5> - %34 = triton_gpu.convert_layout %33 : tensor<32x128xf16, #blocked5> -> tensor<32x128xf16, #blocked> + %34 = ttg.convert_layout %33 : tensor<32x128xf16, #blocked5> -> tensor<32x128xf16, #blocked> %35 = "tt.reduce"(%30) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 tt.reduce.return %46 : f16 - }) : (tensor<128x32xf16, #blocked1>) -> tensor<32xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %36 = triton_gpu.convert_layout %35 : tensor<32xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<32xf16, #blocked2> + }) : (tensor<128x32xf16, #blocked1>) -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = ttg.convert_layout %35 : tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<32xf16, #blocked2> %37 = "tt.reduce"(%36) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 @@ -1499,8 +1499,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 tt.reduce.return %46 : f16 - }) : (tensor<32x128xf16, #blocked>) -> tensor<128xf16, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %39 = triton_gpu.convert_layout %38 : tensor<128xf16, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<128xf16, #blocked2> + }) : (tensor<32x128xf16, #blocked>) -> tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> + %39 = ttg.convert_layout %38 : tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<128xf16, #blocked2> %40 = "tt.reduce"(%39) <{axis = 0 : i32}> ({ ^bb0(%arg11: f16, %arg12: f16): %46 = arith.addf %arg11, %arg12 : f16 @@ -1525,50 +1525,51 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: reduce_cvt3 // CHECK: tt.dot // CHECK-NEXT: tt.reduce -// CHECK: triton_gpu.convert_layout -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { tt.func public @reduce_cvt3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked> %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked1> - %1 = triton_gpu.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xi32, #blocked2> - %3 = triton_gpu.convert_layout %2 : tensor<32x1xi32, #blocked2> -> tensor<32x1xi32, #blocked> + %1 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<32x1xi32, #blocked2> -> tensor<32x1xi32, #blocked> %4 = arith.muli %3, %cst_0 : tensor<32x1xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %7 = triton_gpu.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %7 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked3> -> tensor<32x32xi32, #blocked3> - %11 = triton_gpu.convert_layout %10 : tensor<32x32xi32, #blocked3> -> tensor<32x32xi32, #blocked> + %11 = ttg.convert_layout %10 : tensor<32x32xi32, #blocked3> -> tensor<32x32xi32, #blocked> %12 = tt.addptr %9, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %13 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %14 = tt.addptr %13, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> %16 = tt.addptr %15, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %17 = triton_gpu.convert_layout %12 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %17 = ttg.convert_layout %12 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> %18 = tt.load %17 : tensor<32x32x!tt.ptr, #blocked4> - %19 = triton_gpu.convert_layout %18 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> - %20 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %19 = ttg.convert_layout %18 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %20 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> %21 = tt.load %20 : tensor<32x32x!tt.ptr, #blocked4> - %22 = triton_gpu.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> - %23 = triton_gpu.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<32x32xf16, #shared> - %24 = tt.trans %23 {order=array} : !tt.memdesc<32x32xf16, #shared> -> !tt.memdesc<32x32xf16, #shared1> - %25 = triton_gpu.local_load %24 : !tt.memdesc<32x32xf16, #shared1> -> tensor<32x32xf16, #blocked> - %26 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> - %27 = triton_gpu.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> - %28 = triton_gpu.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> - %29 = tt.dot %26, %27, %28 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> - %30 = triton_gpu.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> + %22 = ttg.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %23 = ttg.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem> + %24 = ttg.memdesc_trans %23 {order=array} : !ttg.memdesc<32x32xf16, #shared, #smem> -> !ttg.memdesc<32x32xf16, #shared1, #smem> + %25 = ttg.local_load %24 : !ttg.memdesc<32x32xf16, #shared1, #smem> -> tensor<32x32xf16, #blocked> + %26 = ttg.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> + %27 = ttg.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> + %28 = ttg.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %30 = ttg.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): %37 = arith.cmpf "oeq", %arg3, %arg5 : f32 @@ -1579,12 +1580,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %42 = arith.select %41, %arg3, %arg5 : f32 %43 = arith.select %41, %arg4, %arg6 : i32 tt.reduce.return %42, %43 : f32, i32 - }) : (tensor<32x32xf32, #blocked>, tensor<32x32xi32, #blocked>) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) - %32 = triton_gpu.convert_layout %31#1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #blocked1> + }) : (tensor<32x32xf32, #blocked>, tensor<32x32xi32, #blocked>) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + %32 = ttg.convert_layout %31#1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #blocked1> %33 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #blocked1> %34 = tt.addptr %33, %0 : tensor<32x!tt.ptr, #blocked1>, tensor<32xi32, #blocked1> - %35 = triton_gpu.convert_layout %34 : tensor<32x!tt.ptr, #blocked1> -> tensor<32x!tt.ptr, #blocked1> - %36 = triton_gpu.convert_layout %32 : tensor<32xi32, #blocked1> -> tensor<32xi32, #blocked1> + %35 = ttg.convert_layout %34 : tensor<32x!tt.ptr, #blocked1> -> tensor<32x!tt.ptr, #blocked1> + %36 = ttg.convert_layout %32 : tensor<32xi32, #blocked1> -> tensor<32xi32, #blocked1> tt.store %35, %36 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -1594,20 +1595,20 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // Check that we don't have extra convert for flash attention IR. -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked3a = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}> -#blocked4a = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1]}> -#blocked6a = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked6 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked7 = #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2]}> -#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}> -#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}> +#blocked4a = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1]}> +#blocked6a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked7 = #ttg.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2]}> +#blocked8 = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}> +#blocked9 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @attention_fw(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %c0_i64 = arith.constant 0 : i64 %c64_i64 = arith.constant 64 : i64 @@ -1641,58 +1642,58 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> %20 = arith.extsi %19 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> %21 = arith.addi %18, %20 : tensor<128xi64, #blocked3a> - %22 = triton_gpu.convert_layout %21 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> - %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %22 = ttg.convert_layout %21 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> %24 = tt.splat %6 : i64 -> tensor<128x1xi64, #blocked4a> %25 = arith.muli %23, %24 : tensor<128x1xi64, #blocked4a> %26 = tt.broadcast %25 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %27 = triton_gpu.convert_layout %26 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %27 = ttg.convert_layout %26 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %28 = tt.addptr %17, %27 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %29 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %30 = arith.extsi %29 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %31 = triton_gpu.convert_layout %30 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> - %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %31 = ttg.convert_layout %30 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> %33 = tt.broadcast %32 : tensor<1x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %34 = triton_gpu.convert_layout %33 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %34 = ttg.convert_layout %33 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %35 = tt.addptr %28, %34 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %36 = tt.load %35 : tensor<128x64x!tt.ptr, #blocked3> - %37 = triton_gpu.convert_layout %36 : tensor<128x64xf16, #blocked3> -> tensor<128x64xf16, #blocked2> + %37 = ttg.convert_layout %36 : tensor<128x64xf16, #blocked3> -> tensor<128x64xf16, #blocked2> %38 = tt.splat %16 : f32 -> tensor<128x64xf32, #blocked2> %39 = arith.extf %37 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %40 = arith.mulf %39, %38 : tensor<128x64xf32, #blocked2> %41 = arith.truncf %40 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: scf.for -// CHECK-NOT: triton_gpu.convert_layout -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout // CHECK: tt.dot -// CHECK-NOT: triton_gpu.convert_layout -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK: triton_gpu.convert_layout %{{.*}} #triton_gpu.dot_op -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout // CHECK: tt.dot // CHECK: scf.yield %42:5 = scf.for %arg22 = %c0_i32 to %9 step %c64_i32 iter_args(%arg23 = %cst_2, %arg24 = %cst_1, %arg25 = %cst_0, %arg26 = %c0_i64, %arg27 = %c0_i64) -> (tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64) : i32 { %78 = tt.splat %8 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked6> %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> %80 = arith.extsi %79 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> - %81 = triton_gpu.convert_layout %80 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %82 = tt.expand_dims %81 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<64x1xi64, #blocked6> + %81 = ttg.convert_layout %80 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> -> tensor<64x1xi64, #blocked6> %83 = tt.broadcast %82 : tensor<64x1xi64, #blocked6> -> tensor<64x64xi64, #blocked6> - %84 = triton_gpu.convert_layout %83 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %84 = ttg.convert_layout %83 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> %85 = tt.addptr %78, %84 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> %86 = tt.splat %arg26 : i64 -> tensor<64xi64, #blocked6a> %87 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> %88 = arith.extsi %87 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> %89 = arith.addi %86, %88 : tensor<64xi64, #blocked6a> - %90 = triton_gpu.convert_layout %89 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> - %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %90 = ttg.convert_layout %89 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> %92 = tt.splat %10 : i64 -> tensor<1x64xi64, #blocked6> %93 = arith.muli %91, %92 : tensor<1x64xi64, #blocked6> %94 = tt.broadcast %93 : tensor<1x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> - %95 = triton_gpu.convert_layout %94 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %95 = ttg.convert_layout %94 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> %96 = tt.addptr %85, %95 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> %97 = tt.load %96 : tensor<64x64x!tt.ptr, #blocked6> %98 = tt.splat %11 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked3> @@ -1700,69 +1701,69 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %100 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %101 = arith.extsi %100 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> %102 = arith.addi %99, %101 : tensor<64xi64, #blocked3a> - %103 = triton_gpu.convert_layout %102 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi64, #blocked3> + %103 = ttg.convert_layout %102 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi64, #blocked3> %105 = tt.splat %12 : i64 -> tensor<64x1xi64, #blocked3> %106 = arith.muli %104, %105 : tensor<64x1xi64, #blocked3> %107 = tt.broadcast %106 : tensor<64x1xi64, #blocked3> -> tensor<64x64xi64, #blocked3> - %108 = triton_gpu.convert_layout %107 : tensor<64x64xi64, #blocked3> -> tensor<64x64xi64, #blocked3> + %108 = ttg.convert_layout %107 : tensor<64x64xi64, #blocked3> -> tensor<64x64xi64, #blocked3> %109 = tt.addptr %98, %108 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> %110 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %111 = arith.extsi %110 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %112 = triton_gpu.convert_layout %111 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> - %113 = tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %112 = ttg.convert_layout %111 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %113 = tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> %114 = tt.broadcast %113 : tensor<1x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked4a> - %115 = triton_gpu.convert_layout %114 : tensor<64x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked3> + %115 = ttg.convert_layout %114 : tensor<64x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked3> %116 = tt.addptr %109, %115 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> %117 = tt.load %116 : tensor<64x64x!tt.ptr, #blocked3> - %118 = triton_gpu.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %119 = triton_gpu.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %120 = tt.dot %118, %119, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> - %121 = triton_gpu.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> + %118 = ttg.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %119 = ttg.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %120 = tt.dot %118, %119, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %121 = ttg.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): %153 = arith.maximumf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 - }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %124 = triton_gpu.convert_layout %123 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %124 = ttg.convert_layout %123 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> %125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1> %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> - %128 = triton_gpu.convert_layout %125 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %129 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %130 = triton_gpu.convert_layout %129 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %128 = ttg.convert_layout %125 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %129 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %130 = ttg.convert_layout %129 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> %131 = tt.broadcast %130 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %132 = arith.subf %122, %131 : tensor<128x64xf32, #blocked2> %133 = tt.extern_elementwise %132 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> %134 = arith.mulf %arg24, %cst_1 : tensor<128xf32, #blocked1> %135 = arith.addf %134, %127 : tensor<128xf32, #blocked1> - %136 = triton_gpu.convert_layout %135 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %137 = tt.expand_dims %136 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %138 = triton_gpu.convert_layout %137 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %136 = ttg.convert_layout %135 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %137 = tt.expand_dims %136 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %138 = ttg.convert_layout %137 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> %139 = tt.broadcast %138 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %140 = arith.mulf %arg23, %139 : tensor<128x64xf32, #blocked2> %141 = arith.truncf %133 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> - %142 = triton_gpu.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %143 = triton_gpu.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %144 = triton_gpu.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> - %145 = tt.dot %142, %143, %144 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> - %146 = triton_gpu.convert_layout %145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> + %142 = ttg.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %143 = ttg.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %144 = ttg.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %146 = ttg.convert_layout %145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): %153 = arith.addf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 - }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %149 = triton_gpu.convert_layout %148 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %149 = ttg.convert_layout %148 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> %150 = arith.addf %147, %149 : tensor<128xf32, #blocked1> %151 = arith.addi %arg26, %c64_i64 : i64 %152 = arith.addi %arg27, %c64_i64 : i64 scf.yield %146, %150, %125, %151, %152 : tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64 } - %43 = triton_gpu.convert_layout %42#1 : tensor<128xf32, #blocked1> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> - %45 = triton_gpu.convert_layout %44 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %43 = ttg.convert_layout %42#1 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %45 = ttg.convert_layout %44 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> %46 = tt.broadcast %45 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> %47 = arith.divf %42#0, %46 : tensor<128x64xf32, #blocked2> %48 = arith.muli %1, %arg20 : i32 @@ -1776,25 +1777,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %55 = arith.extsi %arg17 : i32 to i64 %56 = arith.extsi %5 : i32 to i64 %57 = arith.truncf %47 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> - %58 = triton_gpu.convert_layout %57 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #blocked3> + %58 = ttg.convert_layout %57 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #blocked3> %59 = tt.splat %54 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked3> %60 = tt.splat %56 : i64 -> tensor<128xi64, #blocked3a> %61 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> %62 = arith.extsi %61 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> %63 = arith.addi %60, %62 : tensor<128xi64, #blocked3a> - %64 = triton_gpu.convert_layout %63 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> - %65 = tt.expand_dims %64 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %64 = ttg.convert_layout %63 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> %66 = tt.splat %55 : i64 -> tensor<128x1xi64, #blocked4a> %67 = arith.muli %65, %66 : tensor<128x1xi64, #blocked4a> %68 = tt.broadcast %67 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> - %69 = triton_gpu.convert_layout %68 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %69 = ttg.convert_layout %68 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> %70 = tt.addptr %59, %69 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> %71 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> %72 = arith.extsi %71 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> - %73 = triton_gpu.convert_layout %72 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> - %74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %73 = ttg.convert_layout %72 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> %75 = tt.broadcast %74 : tensor<1x64xi64, #blocked6> -> tensor<128x64xi64, #blocked6> - %76 = triton_gpu.convert_layout %75 : tensor<128x64xi64, #blocked6> -> tensor<128x64xi64, #blocked3> + %76 = ttg.convert_layout %75 : tensor<128x64xi64, #blocked6> -> tensor<128x64xi64, #blocked3> %77 = tt.addptr %70, %76 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> tt.store %77, %58 : tensor<128x64x!tt.ptr, #blocked3> tt.return @@ -1803,37 +1804,37 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK-LABEL: axis_mismatch -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { -tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> { // CHECK: %[[R:.+]] = "tt.reduce"(%0) <{axis = 1 : i32}> -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] // CHECK: tt.return %[[C]] %0 = tt.splat %arg0 : f32 -> tensor<1x16xf32, #blocked> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg9: f32, %arg10: f32): %60 = arith.addf %arg9, %arg10 : f32 tt.reduce.return %60 : f32 - }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = triton_gpu.convert_layout %1 : tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> - %3 = triton_gpu.convert_layout %2 : tensor<1xf32, #blocked1> -> tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - tt.return %3: tensor<1xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = ttg.convert_layout %1 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %3 = ttg.convert_layout %2 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return %3: tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: reduce_to_scalar -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i32) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1> %3:2 = "tt.reduce"(%1, %2) <{axis = 0 : i32}> ({ ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): @@ -1852,9 +1853,9 @@ tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i3 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: whileop // CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> // CHECK: %[[W:.+]] = scf.while (%[[I:.+]] = %[[L]], %{{.*}} = %{{.*}}) : (tensor<1024xf32, #blocked>, i1) -> tensor<1024xf32, #blocked> { @@ -1867,17 +1868,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK: tt.store %{{.*}}, %[[W]] : tensor<1024x!tt.ptr, #blocked> tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> %2 = scf.while (%arg0 = %1, %arg1 = %cond) : (tensor<1024xf32, #blocked1>, i1) -> (tensor<1024xf32, #blocked1>) { scf.condition(%arg1) %arg0 : tensor<1024xf32, #blocked1> } do { ^bb0(%arg0: tensor<1024xf32, #blocked1>): - %4 = triton_gpu.convert_layout %arg0 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + %4 = ttg.convert_layout %arg0 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> %5 = arith.addf %4, %4 : tensor<1024xf32, #blocked> - %6 = triton_gpu.convert_layout %5 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %6 = ttg.convert_layout %5 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> scf.yield %6, %cond : tensor<1024xf32, #blocked1>, i1 } - %3 = triton_gpu.convert_layout %2 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> tt.store %ptr, %3 : tensor<1024x!tt.ptr, #blocked> tt.return } @@ -1898,7 +1899,7 @@ tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { // Check that we don't transform this loop into `yield %x` on the incorrect // theory that the yield is dead unless %x = %y. -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL @yield_outside_loop1 tt.func public @yield_outside_loop1(%arg0: i32, %arg1: i32) -> (i32) { @@ -1939,16 +1940,17 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) { // Check that we handle corner cases when hoisting conversions on top of extf because conversion operations on a smaller type are faster. // For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice. // In this case we want to make sure we don't replace other uses of extf source. -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK: [[$BLOCKED:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -// CHECK: [[$MMA:#.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: [[$BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: [[$MMA:#.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> // CHECK-LABEL: @hoist_convert_above_extf_and_remat tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr) attributes {noinline = false} { @@ -1958,24 +1960,24 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %c64_i32 = arith.constant 64 : i32 %c256_i32 = arith.constant 256 : i32 %c0_i32 = arith.constant 0 : i32 - %cst_2 = arith.constant dense<1.000000e-03> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %cst_3 = arith.constant dense<2.560000e+02> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %cst_2 = arith.constant dense<1.000000e-03> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %cst_3 = arith.constant dense<2.560000e+02> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #blocked3> %c32_i32 = arith.constant 32 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked> %5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked> %6 = arith.muli %5, %cst : tensor<32x1xi32, #blocked> - %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> %11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> %14 = arith.muli %13, %cst_1 : tensor<256x1xi32, #blocked> %15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> %16 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> @@ -1993,29 +1995,29 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %67 = tt.load %66 : tensor<32x64x!tt.ptr, #blocked> %68 = tt.addptr %17, %65 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> %69 = tt.load %68 : tensor<256x64x!tt.ptr, #blocked> - %70 = triton_gpu.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared> - %71 = tt.trans %70 {order=array} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1> - %72 = triton_gpu.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> - %73 = triton_gpu.local_load %71 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> - %74 = triton_gpu.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> - %75 = triton_gpu.convert_layout %72 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %76 = triton_gpu.convert_layout %73 : tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %77 = tt.dot %75, %76, %74 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> - %78 = triton_gpu.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> + %70 = ttg.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem> + %71 = ttg.memdesc_trans %70 {order=array} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %72 = ttg.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %73 = ttg.local_load %71 : !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %74 = ttg.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> + %75 = ttg.convert_layout %72 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %76 = ttg.convert_layout %73 : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %77 = tt.dot %75, %76, %74 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> + %78 = ttg.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> scf.yield %78 : tensor<32x256xf32, #blocked3> } %19 = arith.truncf %18 : tensor<32x256xf32, #blocked3> to tensor<32x256xf16, #blocked3> - %20 = triton_gpu.convert_layout %19 : tensor<32x256xf16, #blocked3> -> tensor<32x256xf16, #blocked2> - %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> - %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<32x256xf16, #blocked3> -> tensor<32x256xf16, #blocked2> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %25 = tt.splat %arg2 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked2> %26 = tt.addptr %25, %23 : tensor<1x256x!tt.ptr, #blocked2>, tensor<1x256xi32, #blocked2> %27 = tt.load %26 : tensor<1x256x!tt.ptr, #blocked2> %28 = tt.broadcast %27 : tensor<1x256xf16, #blocked2> -> tensor<32x256xf16, #blocked2> %29 = arith.addf %20, %28 : tensor<32x256xf16, #blocked2> -// CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> +// CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> // CHECK: %[[B:.+]] = tt.broadcast %[[A]] // CHECK: %[[C:.+]] = arith.addf %[[B:.+]], {{.*}} // CHECK: arith.extf %[[C]] : tensor<32x256xf16, [[$MMA]]> to tensor<32x256xf32, [[$MMA]]> @@ -2024,28 +2026,28 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg7: f32, %arg8: f32): %58 = arith.addf %arg7, %arg8 : f32 tt.reduce.return %58 : f32 - }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %32 = arith.divf %31, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.divf %31, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> %33 = arith.mulf %30, %30 : tensor<32x256xf32, #blocked2> %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ ^bb0(%arg7: f32, %arg8: f32): %58 = arith.addf %arg7, %arg8 : f32 tt.reduce.return %58 : f32 - }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %35 = arith.divf %34, %cst_3 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %36 = arith.mulf %32, %32 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %37 = arith.subf %35, %36 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %38 = math.sqrt %37 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %39 = arith.addf %38, %cst_2 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %40 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> - %41 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %35 = arith.divf %34, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %36 = arith.mulf %32, %32 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %37 = arith.subf %35, %36 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %38 = math.sqrt %37 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %39 = arith.addf %38, %cst_2 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> + %41 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> %42 = tt.broadcast %40 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> %43 = arith.subf %30, %42 : tensor<32x256xf32, #blocked2> %44 = tt.broadcast %41 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> %45 = arith.divf %43, %44 : tensor<32x256xf32, #blocked2> %46 = arith.truncf %45 : tensor<32x256xf32, #blocked2> to tensor<32x256xf16, #blocked2> - %47 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %47 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %49 = arith.muli %48, %cst_0 : tensor<32x1xi32, #blocked1> %50 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1> %51 = arith.addi %50, %49 : tensor<32x1xi32, #blocked1> @@ -2054,7 +2056,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %54 = arith.addi %52, %53 : tensor<32x256xi32, #blocked1> %55 = tt.splat %arg5 : !tt.ptr -> tensor<32x256x!tt.ptr, #blocked1> %56 = tt.addptr %55, %54 : tensor<32x256x!tt.ptr, #blocked1>, tensor<32x256xi32, #blocked1> - %57 = triton_gpu.convert_layout %46 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #blocked1> + %57 = ttg.convert_layout %46 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #blocked1> tt.store %56, %57 : tensor<32x256x!tt.ptr, #blocked1> tt.return } @@ -2062,60 +2064,60 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @backward_reduce_multiple_results -// CHECK-NOT: triton_gpu.convert_layout +// CHECK-NOT: ttg.convert_layout // CHECK: tt.return - tt.func public @backward_reduce_multiple_results() -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> { + tt.func public @backward_reduce_multiple_results() -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> { %cst = arith.constant dense<0xFFF0000000000000> : tensor<1x32xf64, #blocked1> - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2> - %2 = triton_gpu.convert_layout %1 : tensor<1x32xi32, #blocked2> -> tensor<1x32xi32, #blocked1> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1x32xi32, #blocked2> -> tensor<1x32xi32, #blocked1> %3:2 = "tt.reduce"(%cst, %2) <{axis = 1 : i32}> ({ ^bb0(%arg0: f64, %arg1: i32, %arg2: f64, %arg3: i32): %5 = arith.addi %arg1, %arg3 : i32 %6 = arith.addf %arg0, %arg2 : f64 tt.reduce.return %6, %5 : f64, i32 - }) : (tensor<1x32xf64, #blocked1>, tensor<1x32xi32, #blocked1>) -> (tensor<1xf64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) - %4 = triton_gpu.convert_layout %3#1 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - tt.return %4 : tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<1x32xf64, #blocked1>, tensor<1x32xi32, #blocked1>) -> (tensor<1xf64, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + tt.return %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> } } // end module // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @reshape_propagate tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { - // CHECK-NOT: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> - %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> + %c = ttg.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @reshape_sink_convert tt.func public @reshape_sink_convert(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked2> { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: tt.reshape - // CHECK: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } @@ -2123,18 +2125,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @permuting_reshape_propagate tt.func public @permuting_reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf16, #blocked2> { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: arith.truncf - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> - %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = ttg.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> } @@ -2142,24 +2144,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: scan_propagation tt.func @scan_propagation(%arg: tensor<1024xi32, #slice1dim1>) -> tensor<1024xi32, #slice1dim1> { - %1 = triton_gpu.convert_layout %arg : tensor<1024xi32, #slice1dim1> -> tensor<1024xi32, #blocked2> + %1 = ttg.convert_layout %arg : tensor<1024xi32, #slice1dim1> -> tensor<1024xi32, #blocked2> %2 = "tt.scan" (%1) ({ ^bb0(%arg3: i32, %arg4: i32): %add = arith.addi %arg3, %arg4 : i32 tt.scan.return %add : i32 }) {axis = 1 : i32, reverse = false} : (tensor<1024xi32, #blocked2>) -> tensor<1024xi32, #blocked2> - %3 = triton_gpu.convert_layout %2 : tensor<1024xi32, #blocked2> -> tensor<1024xi32, #slice1dim1> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #blocked2> -> tensor<1024xi32, #slice1dim1> // don't allow non blocked layout to be propagated to scan - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.scan - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.return tt.return %3: tensor<1024xi32, #slice1dim1> } @@ -2167,22 +2169,22 @@ tt.func @scan_propagation(%arg: tensor<1024xi32, #slice1dim1>) -> tensor<1024xi3 // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fw_propagate_for_op tt.func public @fw_propagate_for_op(%arg0: tensor<1024x4xi32, #blocked>, %arg1: tensor<1024x4x!tt.ptr, #blocked1>) { %c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c1_i32 = arith.constant 1 : i32 - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: arith.muli // CHECK: scf.for // CHECK: scf.yield - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.store - %0 = triton_gpu.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> + %0 = ttg.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> %1 = arith.muli %0, %0 : tensor<1024x4xi32, #blocked1> %2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %1) -> (tensor<1024x4xi32, #blocked1>) : i32 { %3 = arith.addi %arg3, %arg3 : tensor<1024x4xi32, #blocked1> @@ -2195,16 +2197,16 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_through_if tt.func public @rematerialize_through_if(%arg0: i1, %arg1: f32) -> tensor<32xf32, #blocked> { // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: scf.if %arg0 -> (tensor<32xf32, #blocked>) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked1> %0 = tt.splat %arg1 : f32 -> tensor<32xf32, #blocked1> @@ -2215,30 +2217,30 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %2 = arith.addf %cst_0, %0 : tensor<32xf32, #blocked1> scf.yield %2 : tensor<32xf32, #blocked1> } - %4 = triton_gpu.convert_layout %3 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %4 = ttg.convert_layout %3 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %4 : tensor<32xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_if_inside_loop tt.func public @rematerialize_if_inside_loop() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) { // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: %[[for:[0-9]*]]:2 = scf.for {{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.if %{{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: tt.return %[[for]]#1, %[[for]]#0 %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> @@ -2251,25 +2253,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } else { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } scf.yield %3#0, %3#1 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: rematerialize_loop_arg tt.func public @rematerialize_loop_arg(%arg0: !tt.ptr) { - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 @@ -2278,14 +2280,14 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %cst_2 = arith.constant dense<128> : tensor<128x64xi32, #blocked> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %0) -> (tensor<128x64x!tt.ptr, #blocked>) - // CHECK-NOT: triton_gpu.convert_layout + // CHECK-NOT: ttg.convert_layout // CHECK: scf.yield %{{.*}} : tensor<128x64x!tt.ptr, #blocked> %1 = scf.for %arg1 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg2 = %0) -> (tensor<128x64x!tt.ptr, #blocked>) : i32 { %2 = tt.addptr %arg2, %cst_1 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %3 = triton_gpu.convert_layout %2 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + %3 = ttg.convert_layout %2 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> tt.store %3, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> %4 = tt.addptr %arg2, %cst_2 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %5 = triton_gpu.convert_layout %4 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + %5 = ttg.convert_layout %4 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> tt.store %5, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> scf.yield %2 : tensor<128x64x!tt.ptr, #blocked> } @@ -2296,50 +2298,50 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: assertop // CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> // CHECK: tt.assert %[[L]] tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> - %1 = triton_gpu.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> + %1 = ttg.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> tt.assert %1, "cond must be true " : tensor<1024xi1, #blocked1> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @warp_group_dot_wait_propagate tt.func public @warp_group_dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { - // CHECK-NOT: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = triton_nvidia_gpu.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> - %c = triton_gpu.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = ttng.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> + %c = ttg.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> tt.return %c : tensor<16x2xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2,4], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4,2], threadsPerWarp = [2,16], warpsPerCTA = [1,1], order = [0,1]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2,4], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4,2], threadsPerWarp = [2,16], warpsPerCTA = [1,1], order = [0,1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @trans_propagate tt.func public @trans_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<2x16xf32, #blocked2> { // CHECK: tt.trans - // CHECK: triton_gpu.convert_layout - %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> %b = tt.trans %a {order=array} : tensor<16x2xf32, #blocked1> -> tensor<2x16xf32, #blocked2> tt.return %b : tensor<2x16xf32, #blocked2> } @@ -2347,34 +2349,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // Verify that we don't hoist the convert on top of the broadcast. In general we should hoist the convert to reduce its cost // but because this would combine the 1st and 2nd convert and since the 1st convert is known to be a no-op this would // generate more expensive code. // CHECK-LABEL: @hoist_with_free_convert tt.func public @hoist_with_free_convert(%arg0: tensor<128x256xf32, #mma1>, %arg1: tensor<128x1xf32, #mma>) -> tensor<128x256xf32, #blocked> { - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.broadcast - // CHECK: triton_gpu.convert_layout + // CHECK: ttg.convert_layout // CHECK: tt.return - %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf32, #mma1> -> tensor<128x256xf32, #mma> + %0 = ttg.convert_layout %arg0 : tensor<128x256xf32, #mma1> -> tensor<128x256xf32, #mma> %1 = tt.broadcast %arg1 : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> %2 = arith.addf %0, %1 : tensor<128x256xf32, #mma> - %3 = triton_gpu.convert_layout %2 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> tt.return %3 : tensor<128x256xf32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @rematerialize_loop_arg tt.func public @rematerialize_loop_arg() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) { %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> @@ -2390,11 +2392,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.return %[[F]]#3, %[[F]]#1, %[[F]]#2 %1:3 = scf.for %arg0 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %cst) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) : i32 { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6, %4 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1, %1#2 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> } @@ -2402,22 +2404,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // Regression test: // Rematerialization of multiple loop-carried variables, where one is // rematerialized to the same layout by multiple users. // Previously this didn't interact correctly with the de-duplication mechanism. // CHECK-LABEL: @multi_rematerialize_loop_arg - tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr, %arg1: !tt.ptr) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr, %arg1: !tt.ptr) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) { %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 %c2048_i32 = arith.constant 2048 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %cst_1 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> @@ -2425,59 +2427,59 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) - // CHECK: scf.yield {{.*}} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> // CHECK: } - // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %6 = tt.load %2 : tensor<64x64x!tt.ptr, #blocked2> - %7 = triton_gpu.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %8 = triton_gpu.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %7 = ttg.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %8 = ttg.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> %10 = tt.load %3 : tensor<128x64x!tt.ptr, #blocked> %11 = tt.load %4 : tensor<128x64x!tt.ptr, #blocked> %12 = arith.cmpi eq, %10, %11 : tensor<128x64xi8, #blocked> - %13 = triton_gpu.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma> + %13 = ttg.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma> %14 = arith.select %13, %9, %cst_1 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> - %15 = triton_gpu.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %15 = ttg.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({ ^bb0(%arg6: f32, %arg7: f32): %34 = arith.maxnumf %arg6, %arg7 : f32 tt.reduce.return %34 : f32 - }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %19 = triton_gpu.convert_layout %18 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %20 = arith.select %18, %cst, %17 : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma> + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = ttg.convert_layout %18 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> + %20 = arith.select %18, %cst, %17 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma> %22 = tt.broadcast %21 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> %23 = arith.select %22, %cst_2, %14 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> - %24 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> - %25 = arith.mulf %arg4, %cst : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %26 = triton_gpu.convert_layout %25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %25 = arith.mulf %arg4, %cst : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %26 = ttg.convert_layout %25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> %28 = tt.broadcast %27 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma> %29 = arith.mulf %arg3, %28 : tensor<128x64xf32, #mma> - %30 = triton_gpu.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %31 = arith.mulf %arg4, %20 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %30 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %31 = arith.mulf %arg4, %20 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %32 = "tt.reduce"(%24) <{axis = 1 : i32}> ({ ^bb0(%arg6: f32, %arg7: f32): %34 = arith.addf %arg6, %arg7 : f32 tt.reduce.return %34 : f32 - }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %33 = arith.addf %31, %32 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %33 = arith.addf %31, %32 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - tt.return %5#1, %5#2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + tt.return %5#1, %5#2 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked7 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked7 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // Regression test: // The while loop use the result of the for loop as an argument. // When propagating the layout, we should only "forward" propagate the layout to the argument and the result of the while loop @@ -2495,25 +2497,25 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %74 = tt.load %1000 : tensor<256x64x!tt.ptr, #blocked2> %67:2 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg14 = %1001) -> (tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1>) : i32 { %76 = tt.load %arg14 : tensor<64x128x!tt.ptr, #blocked1> - %78 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> - %79 = triton_gpu.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> - %80 = triton_gpu.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> - %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> - %82 = triton_gpu.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> + %78 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> + %79 = ttg.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> + %80 = ttg.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> + %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> + %82 = ttg.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> scf.yield %82, %arg14 : tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1> } %68:2 = scf.while (%arg11 = %67#0, %arg12 = %c1_i32) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) { scf.condition(%c0_i1) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 } do { ^bb0(%arg11: tensor<256x128xf32, #blocked1>, %arg12: i32): - %80 = triton_gpu.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> + %80 = ttg.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> %81 = tt.load %80 : tensor<256x128x!tt.ptr, #blocked1> %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> %83 = arith.addi %arg12, %c1_i32 : i32 scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 } %69 = arith.truncf %68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> - %71 = triton_gpu.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> + %71 = ttg.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> tt.store %1002, %71 : tensor<256x128x!tt.ptr, #blocked1> tt.return } @@ -2524,32 +2526,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that backward rematerialization bails out when the same tensor requires two different layouts // CHECK-LABEL: double_remat -// CHECK: %[[res:.*]] = triton_gpu.convert_layout +// CHECK: %[[res:.*]] = ttg.convert_layout // CHECK-NEXT: tt.return %[[res]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { tt.func public @double_remat() -> tensor<1x256xi32, #blocked> attributes {noinline = false} { %cst = arith.constant dense<0> : tensor<1x256xi32, #blocked1> - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> - %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> %8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> - %9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> + %9 = ttg.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> tt.return %9 : tensor<1x256xi32, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @if_condition_not_dead_inside_loop // CHECK: scf.if // CHECK-NOT: convert_layout @@ -2565,44 +2567,44 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } else { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> - %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> } %119 = arith.cmpi eq, %arg10, %arg0 : i32 scf.yield %3#0, %3#1, %119 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1 } - %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @dot_wait tt.func public @dot_wait(%arg0: tensor<64x64xf32, #mma>, %arg1: tensor<64x128xf32, #mma1>) -> (tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1>) { - %0:2 = triton_nvidia_gpu.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + %0:2 = ttng.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> tt.return %0#0, %0#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> - // CHECK: %[[W:.+]]:2 = triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[W:.+]]:2 = ttng.warp_group_dot_wait // CHECK: tt.return %[[W]]#0, %[[W]]#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @split_propagation // CHECK-SAME: (%[[ARG:.+]]: tensor<128x64x2xf32 // CHECK: %[[S:.+]], %{{.+}} = tt.split %[[ARG]] - // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[S]] + // CHECK: %[[C:.+]] = ttg.convert_layout %[[S]] // CHECK: tt.return %[[C]] tt.func public @split_propagation(%arg0: tensor<128x64x2xf32, #blocked>) -> tensor<128x64xf32, #blocked1> { - %0 = triton_gpu.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> + %0 = ttg.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> %outLHS, %outRHS = tt.split %0 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked1> tt.return %outLHS : tensor<128x64xf32, #blocked1> } @@ -2610,14 +2612,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: matmul_add tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> @@ -2630,11 +2632,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> - %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + %t = ttg.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> // CHECK: %[[T0:.*]] = tt.dot // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> @@ -2644,8 +2646,85 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> } - // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + // CHECK: ttg.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> tt.return } } + +// ----- + +// Minimized reproducer for compiler crash during remove layouts conversions pass: +// If dot result transformed into tensor with shape smaller than one MFMA instruction size, it triggers various asserts. +// This is a smoke test that checks that compiler do not crash. +// +// CHECK-LABEL: small_tensor_mfma + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @small_tensor_mfma(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f32, %arg2: f32): + %3 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %3 : f32 + }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> + %5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked> + %6 = ttg.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1> + %addr = tt.splat %arg0 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> + %8 = ttg.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked> + tt.store %addr, %8 : tensor<32x16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: lift_convert_to_local_load + // CHECK-NOT: convert_layout + // CHECK: tt.return + tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> { + %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable> -> tensor<2x1x32x4x4xi8, #blocked> + %2 = tt.trans %1 {order = array} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1> + %3 = ttg.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2> + tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// TODO(jeff): Support indices -> dst layout propagation to remove both +// layout conversions here. +tt.func @propagate_layout_gather(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) -> tensor<1024x256xf32, #blocked2> { + // CHECK-LABEL: propagate_layout_gather + + // XCHECK-NOT: convert_layout + %0 = ttg.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked1> + %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked1>) -> tensor<1024x256xf32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<1024x256xf32, #blocked1> -> tensor<1024x256xf32, #blocked2> + tt.return %2 : tensor<1024x256xf32, #blocked2> +} + +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 82fc1ddf7b65..17fe2bfaa6ed 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -1,30 +1,27 @@ // RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s -#Cv2 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#Av2k1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> -#Bv2k1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> -#Av2k2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> -#Bv2k2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> -#Av2k4 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> -#Bv2k4 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> -#Cv1 = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [4, 1]}> -#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}> -#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}> -#ALR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> -#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> +#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> +#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> +#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> +#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> +#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK: tt.func @push_elementwise // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> // CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] -// CHECK-SAME: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> +// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> // CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> tt.func @push_elementwise( %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -34,8 +31,8 @@ tt.func @push_elementwise( %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota = triton_gpu.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> - %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -43,7 +40,7 @@ tt.func @push_elementwise( // CHECK: tt.func @succeeds_if_arg_is_not_convert_layout // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] // CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] // CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] // CHECK: %[[C:.*]] = tt.dot %[[AF16]] @@ -53,18 +50,18 @@ tt.func @succeeds_if_arg_is_not_convert_layout( %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> - %dotai8 = triton_gpu.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> + %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> - %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } // CHECK: tt.func @push_inline_asm_op // CHECK: %[[ALOAD:.*]] = tt.load %arg0 -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] // CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] // CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] // CHECK: %[[C:.*]] = tt.dot %[[AF16]] @@ -76,7 +73,7 @@ tt.func @push_inline_asm_op( %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> - %dota_cvt = triton_gpu.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -85,23 +82,23 @@ tt.func @push_inline_asm_op( // ----- -#blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> // CHECK: tt.func @push_convert_both_operands // CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> // CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @push_convert_both_operands( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -110,9 +107,9 @@ tt.func @push_convert_both_operands( %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> - %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = triton_gpu.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -120,25 +117,25 @@ tt.func @push_convert_both_operands( // ----- -#blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { -// CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -// CHECK: #[[MMA:.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> // CHECK: tt.func @update_kwidth_slice -// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> // CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> -// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @update_kwidth_slice( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -149,9 +146,9 @@ tt.func @update_kwidth_slice( %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> - %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %bl = triton_gpu.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -159,55 +156,149 @@ tt.func @update_kwidth_slice( // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> -tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ - %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !ttg.memdesc<64x64xf16, #shared, #smem>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %A = ttg.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> -tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ - %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> - %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<64x64xf8E5M2, #shared, #smem> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !ttg.memdesc<64x64xf8E5M2, #shared, #smem>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %A = ttg.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !ttg.memdesc<128x64xf8E5M2, #shared1, #smem> + %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf8E5M2, #shared1, #smem> * !ttg.memdesc<64x64xf8E5M2, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @a_impl -// CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} : tensor<128x128xi1, #triton_gpu.dot_op<{{.*}}>, tensor<128x128xf16, #triton_gpu.dot_op<{{.*}}> +// CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} : tensor<128x128xi1, #ttg.dot_op<{{.*}}>, tensor<128x128xf16, #ttg.dot_op<{{.*}}> tt.func @a_impl(%pa: tensor<128x128x!tt.ptr, #blocked>) -> tensor<128x128xf32, #mma> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_3 = arith.constant dense<5> : tensor<128x1xi32, #blocked> %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked> %tl = tt.load %pa : tensor<128x128x!tt.ptr, #blocked> - %tr = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %te = tt.expand_dims %tr {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %tr = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %te = tt.expand_dims %tr {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %tc = arith.cmpi slt, %te, %cst_3 : tensor<128x1xi32, #blocked> %tb = tt.broadcast %tc : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> %ts = arith.select %tb, %tl, %cst_4 : tensor<128x128xi1, #blocked>, tensor<128x128xf16, #blocked> - %conv = triton_gpu.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %td = tt.dot %cst_0, %conv, %cst : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %conv = ttg.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %td = tt.dot %cst_0, %conv, %cst : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> tt.return %td : tensor<128x128xf32, #mma> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise_chained +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> + %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> + %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> + %dota = ttg.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: mma_reorder_transpose +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttng.warp_group_dot + tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a = tt.trans %t {order = array} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: mmav2_reorder_transpose +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttg.local_load +// CHECK: tt.dot + tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a = tt.trans %t {order = array} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked> + %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index 9ed3646d92b2..5780bf672f9e 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -1,45 +1,47 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_like_fence tt.func public @matmul_like_fence(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked2>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> - %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> - // CHECK: triton_nvidia_gpu.fence_async_shared - %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + // CHECK: ttng.fence_async_shared + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fence_outside_loop tt.func public @fence_outside_loop(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c64_i32 = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> - %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> - // CHECK: triton_nvidia_gpu.fence_async_shared + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + // CHECK: ttng.fence_async_shared // CHECK: scf.for - // CHECK-NOT: triton_nvidia_gpu.fence_async_shared - // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: ttng.fence_async_shared + // CHECK: ttng.warp_group_dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/global_scratch_alloc.mlir b/test/TritonGPU/global_scratch_alloc.mlir new file mode 100644 index 000000000000..1c4d5bb2efc1 --- /dev/null +++ b/test/TritonGPU/global_scratch_alloc.mlir @@ -0,0 +1,34 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-global-scratch-memory-allocation | FileCheck %s + +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @test_alloc{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 + tt.func public @test_alloc() -> (!tt.ptr, !tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 + %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + tt.return %0, %1 : !tt.ptr, !tt.ptr + } +} + +// ----- + +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @helper1{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 128 : i32 + tt.func private @helper1() -> (!tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + tt.return %0 : !tt.ptr + } + +// CHECK: @test_function{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 + tt.func public @test_function() -> (!tt.ptr, !tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 + %1 = tt.call @helper1() : () -> (!tt.ptr) + tt.return %0, %1 : !tt.ptr, !tt.ptr + } +} diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index c8b3c2ef6b0b..8c90b013cc85 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -1,72 +1,78 @@ // RUN: triton-opt %s -split-input-file -verify-diagnostics -// expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}> +// expected-error@+2 {{ttg.dot_op opIdx paramenter can be 0 or 1, got: 2}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is not supported when the parent is a blocked layout}} -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #blocked, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is not supported when the parent is a blocked layout}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #blocked, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for MFMA parent}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma}> +// expected-error@+2 {{ttg.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma}> +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for MFMA parent}} +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mfma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}> + +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> + +// ----- +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} -#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> // ----- // expected-error@+1 {{major version must be in the [0, 3] range}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> // ----- // expected-error@+1 {{minor version must be 0}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> // ----- // expected-error@+1 {{(M, N) cases other than (32, 32) or (16, 16) unimplemented}} -#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index f9e265f3ee77..41ff5cc763a5 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -1,54 +1,86 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -tt.func public @subview_element_ty(%arg0: !tt.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{,}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> + tt.return +} + +// ----- + +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_memory_space(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{,}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared> -> !ttg.memdesc<8x16xf16> + tt.return +} + +// ----- + +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{element type}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !tt.memdesc<8x16xf32> -> !tt.memdesc<8x16xf16> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem> tt.return } // ----- -tt.func public @too_many_offsets(%arg0: !tt.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero, %zero] : !tt.memdesc<8x16xf32> -> !tt.memdesc + %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc tt.return } // ----- -tt.func public @too_few_offsets(%arg0: !tt.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = triton_gpu.memdesc_subview %arg0[%zero] : !tt.memdesc<8x16xf32> -> !tt.memdesc + %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc tt.return } // ----- -tt.func public @result_rank_too_large(%arg0: !tt.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result rank}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !tt.memdesc<8x16xf32> -> !tt.memdesc<3x8x16xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> tt.return } // ----- -tt.func public @result_dim_too_large(%arg0: !tt.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result shape}} - %a = triton_gpu.memdesc_subview %arg0[%zero, %zero] : !tt.memdesc<8x16xf32> -> !tt.memdesc<32xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<32xf32, #shared, #smem> tt.return } // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{element types of operands A and B must have same bit width}} %D = tt.dot %A, %B, %C : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> @@ -58,10 +90,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching encoding between A and B operands}} %D = tt.dot %A, %B, %C : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> @@ -71,10 +103,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32>) { // expected-error@+1 {{miss encoding of C operand}} %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32> @@ -84,10 +116,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 1 : i32} { +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching kWidth between A and B operands}} %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index b6610c0a663f..539ea317c20c 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -1,11 +1,12 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -20,39 +21,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for // CHECK: tt.dot // CHECK: tt.dot - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield - // CHECK: triton_gpu.async_wait {num = 0 : i32} + // CHECK: ttg.async_wait {num = 0 : i32} %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -61,14 +62,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 @@ -78,10 +80,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = arith.muli %0, %c64_i32 : i32 %2 = tt.get_program_id y : i32 %3 = tt.load %arg3 : !tt.ptr - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> @@ -92,10 +94,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %15 = arith.extsi %14 : i32 to i64 %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> @@ -105,8 +107,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> - %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> @@ -117,10 +119,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> - %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> @@ -139,15 +141,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> - %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> - %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem> -> !ttg.memdesc<64x32xf32, #shared1, #smem> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> - %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> scf.yield %79 : tensor<64x32xf32, #mma> } %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> @@ -155,7 +157,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> - %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> tt.return } @@ -163,34 +165,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @matmul_tma -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, #triton_gpu.shared_memory, mutable> -// CHECK-COUNT-3: triton_nvidia_gpu.init_barrier -// CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3xi64, #{{.+}}, #smem, mutable> +// CHECK-COUNT-3: ttng.init_barrier +// CHECK-COUNT-4: ttng.async_tma_copy_global_to_local // CHECK: scf.for -// CHECK: triton_nvidia_gpu.wait_barrier -// CHECK-NOT: triton_nvidia_gpu.wait_barrier -// CHECK-COUNT-2: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK-NOT: ttng.wait_barrier +// CHECK-COUNT-2: ttng.async_tma_copy_global_to_local // CHECK: scf.yield - tt.func public @matmul_tma(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x256xf32, #mma> { + tt.func public @matmul_tma(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<128x256xf32, #mma> { %c256_i32 = arith.constant 256 : i32 %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1> - %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> - %5 = triton_nvidia_gpu.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 } diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 7fa7812c5a0b..49f67bd076d4 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,11 +1,12 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s - -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -20,37 +21,37 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK: scf.for + // CHECK: tt.load // CHECK: tt.dot // CHECK: tt.dot - // CHECK: tt.load - // CHECK: triton_gpu.local_store + // CHECK: ttg.local_store // CHECK: scf.yield %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -60,14 +61,15 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // ----- // CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de -// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> %c64_i32 = arith.constant 64 : i32 @@ -77,10 +79,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %1 = arith.muli %0, %c64_i32 : i32 %2 = tt.get_program_id y : i32 %3 = tt.load %arg3 : !tt.ptr - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> @@ -91,10 +93,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %15 = arith.extsi %14 : i32 to i64 %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> @@ -104,8 +106,8 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> - %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> @@ -116,10 +118,10 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> - %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> @@ -138,15 +140,15 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> - %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> - %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem, mutable> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem, mutable> -> !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> - %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> scf.yield %79 : tensor<64x32xf32, #mma> } %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> @@ -154,7 +156,7 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> - %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> tt.return } @@ -165,15 +167,15 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // CHECK-LABEL: tt.func public @add_barrier_kernel // CHECK: tt.load // CHECK: scf.for +// CHECK: tt.load // CHECK: gpu.barrier // CHECK: tt.store -// CHECK: tt.load // CHECK: scf.yield // CHECK: gpu.barrier // CHECK: tt.store -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @add_barrier_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -201,16 +203,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] -// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0] -// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] +// CHECK-NOT: #ttg.shared<{{.*}} order = [2, 0, 1] +// CHECK: #ttg.shared<{{.*}} order = [2, 1, 0] +// CHECK-NOT: #ttg.shared<{{.*}} order = [2, 0, 1] // CHECK-LABEL: tt.func public @slowest_dim_is_batch -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked> %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2> @@ -222,9 +224,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> - %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> + %43 = ttg.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %44 = ttg.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr, #blocked1>, tensor<64x8x32xi32, #blocked1> scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1> @@ -233,3 +235,102 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced +// CHECK-LABEL: loop_with_dot_and_transpose +// CHECK: ttg.local_alloc {{.*}}, mutable> +// CHECK: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1201", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @loop_with_dot_and_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32, %arg4: tensor<32x32x!tt.ptr, #blocked1>, %arg5: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { + %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> + %3 = ttg.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %4 = ttg.memdesc_trans %3 {order = array} : !ttg.memdesc<32x32xf32, #shared, #smem> -> !ttg.memdesc<32x32xf32, #shared1, #smem> + %5 = ttg.local_load %4 : !ttg.memdesc<32x32xf32, #shared1, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %6 = ttg.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> + scf.yield %7 : tensor<32x32xf32, #blocked> + } + tt.store %arg5, %0 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Check that the stream pipeliner updates atomic op in the k-loop correctly +// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw +// CHECK: scf.for +// CHECK: tt.atomic_rmw fadd, acq_rel, gpu +// CHECK: tt.dot +// CHECK: scf.yield + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c31_i32 = arith.constant 31 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked> + %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %13 = tt.splat %arg2 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked> + %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked> + %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked> + %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked> + %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked> + %24 = arith.addi %arg3, %c31_i32 : i32 + %25 = arith.divsi %24, %c32_i32 : i32 + %26 = arith.muli %arg4, %c32_i32 : i32 + %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked> + %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { + %32 = tt.load %arg7 : tensor<32x32x!tt.ptr, #blocked> + %33 = tt.load %arg8 : tensor<32x32x!tt.ptr, #blocked> + %34 = ttg.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %35 = ttg.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %40 = ttg.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> + %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> + scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> + } + %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %30 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #mma> + %31 = ttg.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> + tt.store %30, %29, %31 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir similarity index 58% rename from test/TritonGPU/pipeline-hopper-remove-wait.mlir rename to test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir index 74fd2e05551b..0846a44f6c1f 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir @@ -1,21 +1,22 @@ -// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_dependent_dot tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 - %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %cst_4 = arith.constant 1.44269502 : f32 @@ -34,25 +35,25 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %9 = arith.extsi %3 : i32 to i64 %10 = arith.extsi %c0_i32 : i32 to i64 %11 = arith.muli %0, %c128_i32 : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> - %15 = tt.splat %11 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %16 = tt.splat %11 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %15 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %17 = tt.splat %11 : i32 -> tensor<128xi32, #blocked1> - %18 = arith.addi %15, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %19 = arith.addi %16, %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %18 = arith.addi %15, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = arith.addi %16, %13 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> %20 = arith.addi %17, %14 : tensor<128xi32, #blocked1> %21 = arith.mulf %arg3, %cst_4 : f32 %22 = tt.addptr %arg0, %2 : !tt.ptr, i32 - %23 = tt.expand_dims %18 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %24 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %23 = tt.expand_dims %18 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %24 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> %25 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> %26 = arith.muli %23, %25 : tensor<128x1xi32, #blocked> %27 = tt.splat %22 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> %32 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> %33 = tt.addptr %31, %32 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> @@ -63,23 +64,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %38 = arith.truncf %37 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> %39 = arith.addi %0, %c1_i32 : i32 %40 = arith.muli %39, %c128_i32 : i32 - %41:7 = scf.for %arg22 = %c0_i32 to %40 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %7, %arg28 = %9, %arg29 = %10) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64) : i32 { + %41:7 = scf.for %arg22 = %c0_i32 to %40 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %7, %arg28 = %9, %arg29 = %10) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64) : i32 { %69 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> - %70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %71 = arith.extsi %70 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %72 = tt.splat %arg26 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = arith.addi %71, %72 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %74 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> + %70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %71 = arith.extsi %70 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %72 = tt.splat %arg26 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %73 = arith.addi %71, %72 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %74 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> %75 = tt.broadcast %74 : tensor<128x1xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %76 = tt.splat %c1_i64 : i64 -> tensor<128x64xi64, #blocked2> %77 = arith.muli %75, %76 : tensor<128x64xi64, #blocked2> %78 = tt.broadcast %77 : tensor<128x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %79 = tt.addptr %69, %78 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> - %80 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %81 = arith.extsi %80 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %82 = tt.splat %arg27 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %83 = arith.addi %81, %82 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> + %80 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %81 = arith.extsi %80 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %82 = tt.splat %arg27 : i64 -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.addi %81, %82 : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> %85 = tt.broadcast %84 : tensor<1x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> %86 = tt.splat %5 : i64 -> tensor<128x64xi64, #blocked2> %87 = arith.muli %85, %86 : tensor<128x64xi64, #blocked2> @@ -87,43 +88,43 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %89 = tt.addptr %79, %88 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> %90 = tt.load %89 : tensor<128x64x!tt.ptr, #blocked2> %91 = tt.splat %arg2 : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked> - %92 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %93 = arith.extsi %92 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %94 = tt.splat %arg28 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %95 = arith.addi %93, %94 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi64, #blocked> + %92 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %93 = arith.extsi %92 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %94 = tt.splat %arg28 : i64 -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %95 = arith.addi %93, %94 : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi64, #blocked> %97 = tt.broadcast %96 : tensor<64x1xi64, #blocked> -> tensor<64x128xi64, #blocked> %98 = tt.splat %8 : i64 -> tensor<64x128xi64, #blocked> %99 = arith.muli %97, %98 : tensor<64x128xi64, #blocked> %100 = tt.broadcast %99 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %101 = tt.addptr %91, %100 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> - %102 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %103 = arith.extsi %102 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %104 = tt.splat %arg29 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %105 = arith.addi %103, %104 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %102 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %103 = arith.extsi %102 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %104 = tt.splat %arg29 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %105 = arith.addi %103, %104 : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> %107 = tt.broadcast %106 : tensor<1x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %108 = tt.splat %c1_i64 : i64 -> tensor<64x128xi64, #blocked> %109 = arith.muli %107, %108 : tensor<64x128xi64, #blocked> %110 = tt.broadcast %109 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %111 = tt.addptr %101, %110 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> - %113 = triton_gpu.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> - %114 = triton_gpu.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> - %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %113 = ttg.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %114 = ttg.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %115 = ttng.warp_group_dot %113, %114, %cst :!ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %117 = ttg.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + %118 = ttg.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: ttng.warp_group_dot + // CHECK: ttng.warp_group_dot_wait {{.*}}, {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> - %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %119 = ttng.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<128x128xf32, #mma1> + %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %121 = arith.addf %120, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 %123 = arith.addi %arg26, %122 : i64 %124 = arith.extsi %c64_i32 : i32 to i64 @@ -132,30 +133,30 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %127 = arith.addi %arg28, %126 : i64 %128 = arith.extsi %c0_i32 : i32 to i64 %129 = arith.addi %arg29, %128 : i64 - scf.yield %119, %121, %arg25, %123, %125, %127, %129 : tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64 + scf.yield %119, %121, %arg25, %123, %125, %127, %129 : tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64 } %42 = arith.addi %3, %11 : i32 %43 = arith.extsi %arg17 : i32 to i64 %44 = arith.extsi %42 : i32 to i64 %45 = arith.extsi %c0_i32 : i32 to i64 %46 = arith.truncf %41#0 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> - %47 = triton_gpu.convert_layout %46 : tensor<128x128xf16, #mma1> -> tensor<128x128xf16, #blocked> + %47 = ttg.convert_layout %46 : tensor<128x128xf16, #mma1> -> tensor<128x128xf16, #blocked> %48 = tt.splat %arg5 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> - %49 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %50 = arith.extsi %49 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %51 = tt.splat %44 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %52 = arith.addi %50, %51 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> + %49 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %50 = arith.extsi %49 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %51 = tt.splat %44 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = arith.addi %50, %51 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> %54 = tt.broadcast %53 : tensor<128x1xi64, #blocked> -> tensor<128x128xi64, #blocked> %55 = tt.splat %43 : i64 -> tensor<128x128xi64, #blocked> %56 = arith.muli %54, %55 : tensor<128x128xi64, #blocked> %57 = tt.broadcast %56 : tensor<128x128xi64, #blocked> -> tensor<128x128xi64, #blocked> %58 = tt.addptr %48, %57 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi64, #blocked> - %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %60 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %61 = tt.splat %45 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %62 = arith.addi %60, %61 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %63 = tt.expand_dims %62 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %60 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %61 = tt.splat %45 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %62 = arith.addi %60, %61 : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %63 = tt.expand_dims %62 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> %64 = tt.broadcast %63 : tensor<1x128xi64, #blocked> -> tensor<128x128xi64, #blocked> %65 = tt.splat %c1_i64 : i64 -> tensor<128x128xi64, #blocked> %66 = arith.muli %64, %65 : tensor<128x128xi64, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d391be688c23..138cebcf2a1b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -1,59 +1,60 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #smem, mutable, 2x128x32> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable, 2x32x128> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: %[[ASUB3:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[BSUB3]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -82,9 +83,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -98,61 +99,61 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK: scf.for -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK scf.yield -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -181,9 +182,9 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -198,39 +199,39 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) { @@ -251,7 +252,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> @@ -261,7 +262,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> @@ -275,48 +276,48 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // TODO: MCast is not supported yet //// 4 warps, TMA Load //// matmul: 128x32 @ 32x128 -> 128x128 -//#C = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1]}> -//#SA = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset=true}> -//#SB = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset=true}> -//#BA = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -//#BB = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> +//#C = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1]}> +//#SA = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset=true}> +//#SB = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset=true}> +//#BA = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +//#BB = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> //// C-HECK: func @matmul_loop //// C-HECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 //// C-HECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 //// C-HECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 //// C-HECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 -//// C-HECK: %[[MBARRIER_AB:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} -//// C-HECK: %[[EMPTY_BARRIER_B:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 2 : i32} -//// C-HECK: %[[ABUFFER:.*]] = triton_gpu.alloc -//// C-HECK: %[[MBARRIER_AB0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c0_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB0]] -//// C-HECK: %[[A0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] -//// C-HECK: %[[BBUFFER:.*]] = triton_gpu.alloc -//// C-HECK: %[[EMPTY_BARRIER_B0:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][%c0_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B0]], %true -//// C-HECK: %[[B0BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] -//// C-HECK: %[[MBARRIER_AB1:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][%c1_i32] -//// C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[MBARRIER_AB1]] -//// C-HECK: %[[A1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] -//// C-HECK: %[[B1BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] -//// C-HECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -//// C-HECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] +//// C-HECK: %[[MBARRIER_AB:.*]] = ttng.alloc_mbarrier {count = 1 : i32} +//// C-HECK: %[[EMPTY_BARRIER_B:.*]] = ttng.alloc_mbarrier {count = 2 : i32} +//// C-HECK: %[[ABUFFER:.*]] = ttg.alloc +//// C-HECK: %[[MBARRIER_AB0:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][%c0_i32] +//// C-HECK: ttng.mbarrier_arrive %[[MBARRIER_AB0]] +//// C-HECK: %[[A0BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[BBUFFER:.*]] = ttg.alloc +//// C-HECK: %[[EMPTY_BARRIER_B0:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][%c0_i32] +//// C-HECK: ttng.mbarrier_wait %[[EMPTY_BARRIER_B0]], %true +//// C-HECK: %[[B0BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[MBARRIER_AB0]] +//// C-HECK: %[[MBARRIER_AB1:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][%c1_i32] +//// C-HECK: ttng.mbarrier_arrive %[[MBARRIER_AB1]] +//// C-HECK: %[[A1BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[B1BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[MBARRIER_AB1]] +//// C-HECK: %[[A0:.*]] = ttg.extract_slice %[[A1BUFFER]][0, 0, 0] +//// C-HECK: %[[B0:.*]] = ttg.extract_slice %[[B1BUFFER]][0, 0, 0] //// C-HECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// // C-HECK: %[[MBARRIER_AB_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} -// // C-HECK: triton_nvidia_gpu.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} -// // C-HECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} -// // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] -// // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] -// // C-HECK: %[[NEXT_A_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] -// // C-HECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// // C-HECK: %[[EMPTY_BARRIER_B_ITER_WAIT:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] -// // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[EMPTY_BARRIER_B_ITER_WAIT]], {{.*}} -// // C-HECK: %[[NEXT_B_BUFFER:.*]] = triton_nvidia_gpu.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] -// // C-HECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: %[[MBARRIER_AB_ITER:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: ttng.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} +// // C-HECK: ttng.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// // C-HECK: ttng.warp_group_dot_wait {{.*}} +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: ttng.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] +// // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = ttng.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] +// // C-HECK: %[[NEXT_A_BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_A:.*]] = ttg.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] +// // C-HECK: %[[EMPTY_BARRIER_B_ITER_WAIT:.*]] = ttng.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] +// // C-HECK: ttng.mbarrier_wait %[[EMPTY_BARRIER_B_ITER_WAIT]], {{.*}} +// // C-HECK: %[[NEXT_B_BUFFER:.*]] = ttng.insert_slice_tma {{.*}}, {{.*}}, {{.*}}, %[[MBARRIER_AB_NEXT_ITER]] +// // C-HECK: %[[NEXT_B:.*]] = ttg.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] // // C-HECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} -//module attributes {"triton_gpu.num-ctas" = 2 : i32, "triton_gpu.num-warps" = 4 : i32} { +//module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32} { // tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // %A : !tt.ptr {tt.divisibility = 16 : i32}, // %B : !tt.ptr {tt.divisibility = 16 : i32}) -> (!tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C>) { @@ -333,9 +334,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA, #triton_gpu.shared_memory> -// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB, #triton_gpu.shared_memory> -// %c = triton_nvidia_gpu.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %sa = ttg.local_alloc %a : (tensor<128x32xf16, #BA>) -> !ttg.memdesc<128x32xf16, #SA, #smem> +// %sb = ttg.local_alloc %b : (tensor<32x128xf16, #BB>) -> !ttg.memdesc<32x128xf16, #SB, #smem> +// %c = ttng.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> @@ -348,13 +349,14 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_chained_single_load tt.func @dot_chained_single_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> @@ -370,36 +372,36 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %19, %20, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %23 = ttg.memdesc_trans %20 {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + %24 = ttg.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = ttng.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -419,35 +421,35 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for - // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 1 : i32} - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_commit_group + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group // CHECK: scf.if - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} // CHECK: arith.mulf // CHECK: scf.yield // CHECK: scf.yield - // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -465,13 +467,14 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_accumulator_escape tt.func @two_accumulator_escape(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> @@ -481,45 +484,45 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc - // CHECK: %[[ALLOC2:.+]] = triton_gpu.local_alloc + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + // CHECK: %[[ALLOC1:.+]] = ttg.local_alloc + // CHECK: %[[ALLOC2:.+]] = ttg.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: %[[TRANS:.+]] = tt.trans{{.*}} : !tt.memdesc - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} %[[TRANS]] - // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot{{.*}} + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: %[[TRANS:.+]] = ttg.memdesc_trans{{.*}} : !ttg.memdesc + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot{{.*}} %[[TRANS]] + // CHECK: ttng.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} // CHECK: scf.yield - // CHECK: %{{.*}}:2 = triton_nvidia_gpu.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + // CHECK: %{{.*}}:2 = ttng.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %21 = ttng.warp_group_dot %19, %20, %arg6 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + %25 = ttng.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -529,46 +532,47 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory // Make sure that if one of the load dot operand is not pipelined (and therefore not double buffered) we won't use // async dot. -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: no_wgmma_pipeline tt.func public @no_wgmma_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 - %cst_0 = arith.constant dense<512> : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %cst_1 = arith.constant dense<512> : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %cst_0 = arith.constant dense<512> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_1 = arith.constant dense<512> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %cst_2 = arith.constant dense<512> : tensor<128x1xi32, #blocked> %cst_3 = arith.constant dense<512> : tensor<128x1xi32, #blocked1> %cst_4 = arith.constant dense<512> : tensor<64x1xi32, #blocked1> %cst_5 = arith.constant dense<32768> : tensor<64x256xi32, #blocked1> %cst_6 = arith.constant dense<64> : tensor<128x64xi32, #blocked> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %1 = arith.remsi %0, %cst_0 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %3 = arith.remsi %2, %cst_1 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = arith.remsi %0, %cst_0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %3 = arith.remsi %2, %cst_1 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> %5 = arith.muli %4, %cst_2 : tensor<128x1xi32, #blocked> - %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> %8 = tt.broadcast %5 : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> %9 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> %10 = arith.addi %8, %9 : tensor<128x64xi32, #blocked> %11 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %12 = tt.addptr %11, %10 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> %15 = arith.muli %14, %cst_4 : tensor<64x1xi32, #blocked1> - %16 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %16 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %17 = tt.broadcast %15 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> %18 = tt.broadcast %16 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> %19 = arith.addi %17, %18 : tensor<64x256xi32, #blocked1> @@ -577,29 +581,29 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> - %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> - // CHECK: triton_gpu.local_alloc + %37 = ttg.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !ttg.memdesc<128x64xf8E5M2, #shared, #smem> + %38 = ttg.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> + // CHECK: ttg.local_alloc // CHECK: scf.for - // CHECK: triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait - %39 = triton_nvidia_gpu.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait + %39 = ttng.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E5M2, #shared, #smem> * !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> } %23 = arith.truncf %22#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %26 = arith.muli %25, %cst_3 : tensor<128x1xi32, #blocked1> %27 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %29 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %29 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> %30 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> %31 = tt.broadcast %29 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> %32 = tt.addptr %30, %31 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> %33 = tt.fp_to_fp %23 {rounding = 1 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf8E5M2, #mma> - %34 = triton_gpu.convert_layout %33 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked1> + %34 = ttg.convert_layout %33 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked1> tt.store %32, %34 : tensor<128x256x!tt.ptr, #blocked1> tt.return } @@ -608,23 +612,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // A dot can be properly async if all its uses follow a synchronous MMAv3 dot. -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: async_following_sync tt.func @async_following_sync(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { - %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst = arith.constant dense<64> : tensor<64x16xi32, #blocked> %c0_i32 = arith.constant 0 : i32 %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -643,49 +648,49 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait - // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: %[[DOT0:.+]] = ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait // CHECK-DAG-SAME: %[[DOT0]] // CHECK-DAG-SAME: %[[DOT1]] // CHECK-DAG-SAME: %[[PREV_DOT2]] // CHECK-SAME: {pendings = 0 : i32} - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot - // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: scf.yield %[[DOT2]] - // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. - %dot0 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %dot0 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %dot1 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = ttng.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -695,44 +700,48 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- // Test pipelining of experimental_descriptor_store -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store_pipeline - tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global - tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } tt.return } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_multiple_store_pipeline - tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 - // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> // CHECK: scf.for scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 %2 = arith.divsi %arg2, %arg4 : i32 - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] - // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} - // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] - // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared - // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] - tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr, tensor<1xf32, #blocked> - tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } tt.return } @@ -741,28 +750,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> -#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: _kernel_matmul_dependency - tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { - %cst = arith.constant dense<0> : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> %cst_0 = arith.constant 1.000000e+00 : f32 %c8_i32 = arith.constant 8 : i32 %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked1> - %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) : i32 { + %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) : i32 { %3 = arith.addi %arg7, %c8_i32 : i32 %4 = arith.cmpi eq, %3, %c8_i32 : i32 - %5:2 = scf.if %4 -> (i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) { + %5:2 = scf.if %4 -> (i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) { %21 = arith.addi %arg8, %c8_i32 : i32 - scf.yield %21, %arg5 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %21, %arg5 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } else { - scf.yield %arg8, %arg10 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %arg8, %arg10 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } %6 = arith.cmpi eq, %3, %c8_i32 : i32 %7 = scf.if %6 -> (f32) { @@ -771,16 +781,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %21 = tt.load %arg4 : !tt.ptr scf.yield %21 : f32 } - %8 = tt.splat %3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %9 = arith.addi %8, %0 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %8 = tt.splat %3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %9 = arith.addi %8, %0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1> %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> %13 = tt.load %arg0 : tensor<128x128x!tt.ptr, #blocked> - %14 = triton_gpu.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !tt.memdesc<128x128xf8E4M3FNUZ, #shared> + %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> %15 = tt.load %12 : tensor<128x128x!tt.ptr, #blocked1> - %16 = triton_gpu.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !tt.memdesc<128x128xf8E4M3FNUZ, #shared1> - %17 = triton_nvidia_gpu.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x128xf8E4M3FNUZ, #shared> * !tt.memdesc<128x128xf8E4M3FNUZ, #shared1> -> tensor<128x128xf32, #mma> + %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> + %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> -> tensor<128x128xf32, #mma> %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma> %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma> %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) { @@ -788,7 +798,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : } else { scf.yield %19 : tensor<128x128xf32, #mma> } - scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> } tt.return } @@ -797,13 +807,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // ----- // Pipeline the if ops at the beginning and the end of the loop -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: dot_prologue_epilogue // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { @@ -817,14 +828,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -849,9 +860,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> } %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> scf.yield %acc_zero : tensor<128x16xf32, #mma1> @@ -869,13 +880,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // Verify that uses of the ops scheduled in partucular place of the loop (like epilogue if) are correctly scheduled too. -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { @@ -890,14 +902,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -914,9 +926,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -931,3 +943,67 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %17#0 : tensor<128x16xf32, #mma1> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_lhs_registers + tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: ttg.local_load + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, + tensor<64x16x!tt.ptr, #blocked>) : i32 { + %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked1> + %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> + %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma> + %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x16xf32, #mma> + } +} diff --git a/test/TritonGPU/loop-pipeline-indirect-load.mlir b/test/TritonGPU/loop-pipeline-indirect-load.mlir index 74794b9496b3..af260c65c87d 100644 --- a/test/TritonGPU/loop-pipeline-indirect-load.mlir +++ b/test/TritonGPU/loop-pipeline-indirect-load.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=2 -tritongpu-pipeline=num-stages=2 | FileCheck %s // CHECK-LABEL: @indirect_load_two_stages // CHECK: scf.for // CHECK: tt.dot @@ -6,11 +6,11 @@ // CHECK: async_copy_global_to_local // CHECK: async_copy_global_to_local -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @indirect_load_two_stages(%arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} { %c32_i32 = arith.constant 32 : i32 %c16_i32 = arith.constant 16 : i32 @@ -22,68 +22,68 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %7 = tt.get_program_id x : i32 %8 = arith.muli %7, %c16_i32 : i32 - %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.splat %8 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %18 = arith.addi %15, %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> - %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> %34 = arith.extsi %arg12 : i32 to i64 %35 = arith.muli %2, %34 : i64 %36 = tt.addptr %arg2, %35 : !tt.ptr, i64 - %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> - %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %61 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> + %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> - %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %85 = arith.extsi %22 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> %107 = tt.splat %36 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked3> %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3> %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3> %101 = tt.splat %arg5 : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked1> %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 { - %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %161 = tt.load %160 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> + %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %161 = tt.load %160 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1> %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr, #blocked1>, tensor<16x32xi64, #blocked1> %183 = tt.load %182 : tensor<16x32x!tt.ptr, #blocked1> %197 = arith.extsi %arg28 : i32 to i64 - %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> + %198 = tt.splat %197 : i64 -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %199 = arith.addi %198, %85 : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3> %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3> %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3> %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi64, #blocked3> %209 = tt.load %204 : tensor<32x128x!tt.ptr, #blocked3> - %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> - %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> + %210 = ttg.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %211 = ttg.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> scf.yield %212 : tensor<16x128xf32, #blocked> } - %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> + %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3> %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3> %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3> %116 = arith.extsi %arg17 : i32 to i64 %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3> %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3> - %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3> %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3> %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3> %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3> %124 = tt.splat %arg7 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked3> %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr, #blocked3>, tensor<16x128xi64, #blocked3> - %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> + %128 = ttg.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> tt.store %125, %128 : tensor<16x128x!tt.ptr, #blocked3> tt.return } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 3d215a635da3..29d61e07a4e9 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,59 +1,61 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[ASUB1:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_A1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] -// CHECK-DAG: %[[BSUB1:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B1:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: %[[ASUB3:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[ASUB3]] -// CHECK: %[[BSUB3:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.async_copy_global_to_local {{.*}}, %[[BSUB3]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] // AMD-LABEL: tt.func @matmul_loop @@ -62,22 +64,22 @@ // AMD-DAG: %[[C0:.*]] = arith.constant 0 : index // AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index // AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) -// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}} -// AMD: %[[DOT_35:.*]] = tt.dot %[[LOCAL_LOAD_32]], %[[MULF_34]], %[[ARG8]] -// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// AMD: %[[ADDPTR_37:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_36]] -// AMD: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_37]] -// AMD: %[[ADDI_40:.*]] = arith.addi %[[ARG9]], %{{.*}} -// AMD: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} -// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_43]] -// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]] -// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]] +// AMD: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]] +// AMD: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG10]] +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]] +// AMD: %[[LOCAL_LOAD_39:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[MULF_40:.*]] = arith.mulf %[[LOCAL_LOAD_39]], %{{.*}} +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[MULF_40]], %[[ARG8]] +// AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] +// AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] // AMD: } // AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] // AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] @@ -86,8 +88,8 @@ // AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] // AMD: %[[DIVSI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]] // AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[DIVSI_26]], %{{.*}} -// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %{{.*}}#5 // AMD: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} // AMD: %[[IF_31:.*]] = scf.if %[[CMPI_27]] // AMD: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %{{.*}}#2 @@ -96,10 +98,38 @@ // AMD: scf.yield %{{.*}}#2 // AMD: } // AMD: %[[SELECT_32:.*]] = arith.select %[[CMPI_27]], %[[IF_31]], %{{.*}}#2 -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// Prefetch pipelining adds another stage in between global load and compute. +// This stage will local_store, then local_load, creating a prefetch from shared +// memory into a register buffer for compute. +// +// AMD_PREFETCH-LABEL: tt.func @matmul_loop +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -130,9 +160,9 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b__ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_ = triton_gpu.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -148,75 +178,79 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK: scf.for -// CHECK-DAG: %[[A0:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.local_load %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_A:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[NEXT_A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK scf.yield // AMD-LABEL: tt.func @matmul_loop_nested // AMD: scf.for -// AMD-COUNT-2: triton_gpu.local_alloc +// AMD-COUNT-2: ttg.local_alloc // AMD-COUNT-2: tt.load -// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: %[[FOR:.*]]:6 = scf.for -// AMD-COUNT-2: triton_gpu.local_load -// AMD: tt.dot // AMD-COUNT-2: tt.addptr -// AMD-COUNT-2: tt.load -// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] -// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: tt.load +// AMD: ttg.local_load +// AMD: tt.load +// AMD: ttg.local_load +// AMD: tt.dot +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: scf.yield -// AMD-COUNT-2: triton_gpu.local_load +// AMD-COUNT-2: ttg.local_load // AMD: %[[IF1:.*]] = scf.if // AMD: %[[DOT1:.*]] = tt.dot // AMD: scf.yield %[[DOT1]] // AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2 -// AMD-COUNT-2: triton_gpu.local_dealloc +// AMD-COUNT-2: ttg.local_dealloc // AMD: scf.yield %[[SEL1]] +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_nested + tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -246,9 +280,9 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -266,50 +300,66 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK-DAG: %[[B0:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.local_load %[[arg_b0]] +// CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] -// CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local // CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 // CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] -// CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] -// CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[NEXT_B:.*]] = ttg.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] // AMD-LABEL: tt.func @matmul_loop_single_pipeline // AMD: %[[LOAD_10:.*]] = tt.load %{{.*}} -// AMD: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] -// AMD: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// AMD: %[[CONVERT_LAYOUT_11:.*]] = ttg.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = ttg.local_alloc // AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] // AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] // AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) -// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] -// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] // AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} // AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] +// AMD: %[[LOCAL_LOAD_30:.*]] = ttg.local_load %[[ARG9]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] // AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} // AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] // AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_12]] + +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_single_pipeline +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -331,7 +381,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %b_mask = arith.constant dense : tensor<32x128xi1, #BL> %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> @@ -341,7 +391,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> @@ -350,98 +400,125 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, } // CHECK-LABEL: tt.func @indirect_bmm_scalar -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] // CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} // AMD-LABEL: tt.func @indirect_bmm_scalar -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc -// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] -// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] -// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] -// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] -// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] -// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] -// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] -// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] -// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] -// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] -// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] -// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] -// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] -// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] -// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG13:.*]] = %[[LOAD_15]], %[[ARG14:.*]] = %[[LOAD_21]]) -// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[ARG7]] +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_12]] +// AMD: %[[MEMDESC_SUBVIEW_13:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_13]] +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_17:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_16]] +// AMD: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_15]], %[[CMPI_11]] +// AMD: %[[MULI_19:.*]] = arith.muli %{{.*}}, %[[LOAD_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[MULI_19]] +// AMD: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[SPLAT_20]] +// AMD: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_23:.*]] = tt.load %[[ADDPTR_21]], %[[SPLAT_22]] +// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_14]], %[[ARG9:.*]] = %[[ADDPTR_15]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[LOAD_17]], %[[ARG12:.*]] = %[[LOAD_23]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_13]]) +// AMD: %[[ADDI_41:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} +// AMD: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[ARG11]], %[[MEMDESC_SUBVIEW_44]] +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_45]] // AMD: %[[ADDPTR_46:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_48:.*]] = tt.load %[[ADDPTR_46]] -// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[MULI_50:.*]] = arith.muli %{{.*}}, %[[LOAD_49]] -// AMD: %[[SPLAT_51:.*]] = tt.splat %[[MULI_50]] -// AMD: %[[ADDPTR_52:.*]] = tt.addptr %{{.*}}, %[[SPLAT_51]] -// AMD: %[[LOAD_53:.*]] = tt.load %[[ADDPTR_52]] -// AMD: %[[ADDI_54:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_55:.*]] = arith.cmpi slt, %[[ADDI_54]], %{{.*}} -// AMD: %[[SELECT_56:.*]] = arith.select %[[CMPI_55]], %[[ADDI_54]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_57:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_56]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_57]] -// AMD: %[[MEMDESC_SUBVIEW_58:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_56]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_58]] -// AMD: scf.yield %[[DOT_45]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_56]], %[[MEMDESC_SUBVIEW_57]], %[[MEMDESC_SUBVIEW_58]], %[[LOAD_48]], %[[LOAD_53]] -// AMD: } -// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[LOCAL_LOAD_29]], %{{.*}}#0 -// AMD: scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 -// AMD: } -// AMD: %[[ADDI_31:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} -// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_33]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_34]] -// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_33]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_35]] -// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_34]] -// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_35]] -// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] -// AMD: scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_36]] -// AMD: } -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] +// AMD: %[[LOCAL_LOAD_49:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[LOAD_50:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[MULI_51:.*]] = arith.muli %{{.*}}, %[[LOAD_50]] +// AMD: %[[SPLAT_52:.*]] = tt.splat %[[MULI_51]] +// AMD: %[[ADDPTR_53:.*]] = tt.addptr %{{.*}}, %[[SPLAT_52]] +// AMD: %[[LOAD_54:.*]] = tt.load %[[ADDPTR_53]] +// AMD: %[[LOCAL_LOAD_55:.*]] = ttg.local_load %[[ARG14]] +// AMD: %[[DOT_56:.*]] = tt.dot %[[LOCAL_LOAD_49]], %[[LOCAL_LOAD_55]], %[[ARG7]] +// AMD: scf.yield %[[DOT_56]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_43]], %[[LOAD_48]], %[[LOAD_54]], %[[MEMDESC_SUBVIEW_44]], %[[MEMDESC_SUBVIEW_45]] +// AMD: } +// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[ADDI_28:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// AMD: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %{{.*}}#4, %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %{{.*}}#5, %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[LOCAL_LOAD_33:.*]] = ttg.local_load %{{.*}}#6 +// AMD: %[[LOCAL_LOAD_34:.*]] = ttg.local_load %{{.*}}#7 +// AMD: %[[IF_35:.*]] = scf.if %[[CMPI_26]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_33]], %[[LOCAL_LOAD_34]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_35]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[LOCAL_LOAD_38:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_36]] +// AMD: } +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_1]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, @@ -462,8 +539,8 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 @@ -473,30 +550,32 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, } // CHECK-LABEL: tt.func @indirect_bmm_scalar_dist_one -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: scf.for %{{.*}} iter_args(%{{[^,]*}}, %{{[^,]*}}, %{{[^,]*}}, %[[IND_BUFFER_PREV:[^,]*]] = {{[^,]*}} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] // CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_PREV]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] // AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one // AMD-COUNT-4: tt.load // AMD: scf.for -// AMD: tt.dot // AMD: tt.load -// AMD: triton_gpu.local_store +// AMD: tt.dot +// AMD: ttg.local_store // AMD: scf.yield +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar_dist_one + tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -518,8 +597,8 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 @@ -529,30 +608,30 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, } // CHECK-LABEL: tt.func @indirect_bmm_vector -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.for // CHECK: tt.dot // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = triton_gpu.async_wait {{.*}} {num = 1 : i32} -// CHECK-DAG: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview -// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield // AMD-LABEL: tt.func @indirect_bmm_vector -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc // AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} // AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] @@ -560,40 +639,42 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} // AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] -// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} -// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] -// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] -// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] -// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] -// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] -// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] -// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] -// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_9]] +// AMD: %[[EXPAND_DIMS_11:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_12:.*]] = tt.broadcast %[[EXPAND_DIMS_11]] +// AMD: %[[MULI_13:.*]] = arith.muli %{{.*}}, %[[BROADCAST_12]] +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_15]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] // AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG13:.*]] = %[[LOAD_16]]) -// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] -// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] -// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} -// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] -// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] -// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] -// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] -// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] -// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} -// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_vector tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, @@ -616,8 +697,8 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -666,9 +747,9 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %114 = arith.cmpi "slt", %66, %113 : tensor<32x1xi32, #AL> %115 = tt.broadcast %114 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> %116 = tt.load %arg12, %115, %cst_0 : tensor<32x32x!tt.ptr, #AL> - %117 = triton_gpu.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> - %118 = triton_gpu.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %117 = ttg.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %118 = ttg.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -684,7 +765,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // COMMON-LABEL: tt.func @cross_iter_dep // TODO: enable pipelining with distance of 2 -// COMMON-NOT: triton_gpu.async_commit_group +// COMMON-NOT: ttg.async_commit_group // COMMON: scf.for // COMMON: scf.yield @@ -724,9 +805,9 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %148 = arith.cmpi "slt", %88, %147 : tensor<32x1xi32, #AL> %149 = tt.broadcast %148 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> %150 = tt.load %arg12, %149, %cst_1 : tensor<32x32x!tt.ptr, #AL> - %151 = triton_gpu.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> - %152 = triton_gpu.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %151 = ttg.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %152 = ttg.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -744,7 +825,6 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // COMMON: tt.expand_dims // COMMON: tt.expand_dims // COMMON: tt.expand_dims %arg5 -// COMMON-NEXT: tt.expand_dims %arg5 // COMMON: %[[PTR0:.*]] = tt.splat %arg6 // COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] // COMMON-NEXT: tt.load %[[PTR1]] @@ -753,10 +833,10 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { %23 = arith.constant 100 : index %c64 = arith.constant 64 : i64 - %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> %85 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> %86 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> %68 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> @@ -769,17 +849,17 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %90 = tt.splat %c64 : i64 -> tensor<32x128xi64, #BL> %92 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 %c0_index = arith.constant 0 : index - %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { + %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { %1750 = arith.subi %23, %arg19 : index %175 = arith.index_cast %1750 : index to i32 - %176 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %177 = tt.splat %175 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> - %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> - %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> - %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %176 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %177 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> + %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, #ttg.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> + %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> %184 = arith.extsi %182 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> %185 = arith.extsi %183 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> %186 = arith.muli %184, %85 : tensor<1x32xi64, #AL> @@ -790,17 +870,17 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %191 = tt.addptr %arg20, %189 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> %192 = tt.broadcast %180 : tensor<1x32xi1, #AL> -> tensor<128x32xi1, #AL> %193 = tt.load %191, %192 : tensor<128x32x!tt.ptr, #AL> - %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %196 = tt.load %195 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> + %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %196 = tt.load %195 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> %197 = tt.addptr %arg22, %c32_i32 : !tt.ptr, i32 %198 = tt.broadcast %181 : tensor<32x1xi1, #BL> -> tensor<32x128xi1, #BL> %199 = tt.load %arg24, %198, %88 : tensor<32x128x!tt.ptr, #BL> - %200 = triton_gpu.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> - %201 = triton_gpu.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> - %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %200 = ttg.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> + %201 = ttg.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> + %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> - scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> + scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> } tt.return %91#3 : tensor<128x128xf32, #C> } @@ -808,12 +888,13 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> @@ -828,16 +909,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -846,15 +927,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -864,22 +945,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- // CHECK-LABEL: nested_loops -// CHECK: triton_gpu.local_alloc +// CHECK: ttg.local_alloc // CHECK: scf.for -// CHECK-NOT: triton_gpu.local_alloc +// CHECK-NOT: ttg.local_alloc // CHECK: scf.for // CHECK: scf.yield -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.async_commit_group +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group // CHECK: scf.yield // AMD-LABEL: tt.func public @nested_loops // AMD: scf.for -// AMD: triton_gpu.local_alloc -// AMD-NOT: triton_gpu.local_alloc +// AMD: ttg.local_alloc +// AMD-NOT: ttg.local_alloc // AMD: scf.for // AMD: scf.yield // AMD-DIS: scf.yield @@ -900,9 +981,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // For CUDA, we pipeline the inner loop first then pipeline the outer // loop to prefetch the async copy after the inner loop. // For HIP, we only pipeline the inner loop for now. -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> @@ -910,9 +992,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c1_i32 = arith.constant 1 : i32 %c32_i32 = arith.constant 32 : i32 %c10_i32 = arith.constant 10 : i32 - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %3 = arith.muli %2, %cst_0 : tensor<32x1xi32, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> @@ -921,15 +1003,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> scf.for %arg4 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { %9 = arith.muli %arg4, %c32_i32 : i32 - %10 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %11 = tt.splat %9 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %12 = arith.addi %10, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %13 = arith.addi %11, %1 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %10 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %11 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %10, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = arith.addi %11, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %15 = tt.broadcast %14 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %16 = tt.addptr %6, %15 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %17 = tt.load %16 : tensor<32x32x!tt.ptr, #blocked> - %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %19 = arith.muli %18, %cst_0 : tensor<32x1xi32, #blocked> %20 = tt.addptr %7, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> %21 = tt.broadcast %20 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> @@ -937,17 +1019,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %23 = tt.broadcast %22 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { %24 = arith.muli %arg5, %c32_i32 : i32 - %25 = tt.splat %24 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %26 = arith.addi %25, %0 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %25 = tt.splat %24 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %25, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %28 = tt.broadcast %27 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %29 = tt.addptr %21, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %30 = tt.load %29 : tensor<32x32x!tt.ptr, #blocked> - %31 = triton_gpu.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %32 = triton_gpu.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %31 = ttg.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %32 = ttg.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %34 = tt.addptr %23, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %35 = triton_gpu.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %35 = ttg.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %35 : tensor<32x32x!tt.ptr, #blocked> } } @@ -957,92 +1039,92 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$SHARED_LAYOUT:shared.*]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // CHECK-LABEL: tt.func @indirect_load_shared_layout // CHECK: scf.for // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -// CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] // CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] -// CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] -// CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} -// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // AMD-LABEL: tt.func @indirect_load_shared_layout -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] -// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] -// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} -// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] -// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] -// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] -// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] -// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] -// AMD: } -// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[IF_25:.*]] = scf.if %[[CMPI_21]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[LOCAL_LOAD_24]], %{{.*}}#0 -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] // AMD: } -// AMD: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}#1, %{{.*}} -// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]] -// AMD: %[[EXPAND_DIMS_29:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32} -// AMD: %[[BROADCAST_30:.*]] = tt.broadcast %[[EXPAND_DIMS_29]] -// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[BROADCAST_30]] -// AMD: %[[ADDPTR_32:.*]] = tt.addptr %{{.*}}, %[[MULI_31]] -// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_33]] -// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} -// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_25]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_40]] -// AMD: } -// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} - -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_23:.*]] = tt.addptr %{{.*}}#1, %{{.*}} +// AMD: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_24]] +// AMD: %[[LOCAL_LOAD_26:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[EXPAND_DIMS_27:.*]] = tt.expand_dims %{{.*}}#5 {axis = 1 : i32} +// AMD: %[[BROADCAST_28:.*]] = tt.broadcast %[[EXPAND_DIMS_27]] +// AMD: %[[MULI_29:.*]] = arith.muli %{{.*}}, %[[BROADCAST_28]] +// AMD: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_29]] +// AMD: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_31]] +// AMD: %[[LOCAL_LOAD_33:.*]] = ttg.local_load %{{.*}}#6 +// AMD: %[[IF_34:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_26]], %[[LOCAL_LOAD_33]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_40]] +// AMD: } +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1064,8 +1146,8 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -1079,25 +1161,25 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // ----- // CHECK-LABEL: @kernel_yield_constant -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview // CHECK: scf.for -// CHECK: triton_gpu.async_copy_global_to_local -// CHECK: triton_gpu.memdesc_subview +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview // CHECK: tt.return // AMD-LABEL: @kernel_yield_constant // AMD: tt.load -// AMD: triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store // AMD: scf.for // AMD: tt.load -// AMD: triton_gpu.memdesc_subview -// AMD: triton_gpu.local_store +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store // AMD: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> @@ -1106,12 +1188,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %c32_i32 = arith.constant 32 : i32 %c31_i32 = arith.constant 31 : i32 - %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %0 = tt.get_program_id x : i32 - %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %12 = arith.addi %arg4, %c31_i32 : i32 %13 = arith.divsi %12, %c32_i32 : i32 - %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %22 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %34 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %42 = scf.for %arg7 = %c0_i32 to %13 step %c1_i32 iter_args(%arg8 = %cst) -> (tensor<32x32xf32, #mma>) : i32 { @@ -1124,9 +1206,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %49 = arith.cmpi slt, %14, %48 : tensor<32x1xi32, #blocked> %50 = tt.broadcast %49 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> %51 = tt.load %46, %50, %cst_0 : tensor<32x32x!tt.ptr, #blocked> - %52 = triton_gpu.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %54 = triton_gpu.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %52 = ttg.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %54 = ttg.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %54 : tensor<32x32x!tt.ptr, #blocked> scf.yield %cst1 : tensor<32x32xf32, #mma> } @@ -1140,16 +1222,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @add_kernel // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 -// CHECK: %[[ABUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] -// CHECK: %[[B0BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] -// CHECK: %[[A1BUFFER:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] -// CHECK: %[[B1BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] -// CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: %[[A0BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] +// CHECK: %[[B0BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] +// CHECK: %[[A1BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] +// CHECK: %[[B1BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for // AMD-LABEL: tt.func public @add_kernel @@ -1165,8 +1247,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] // AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] // AMD: scf.for -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1201,78 +1283,96 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @nested_loops // CHECK: tt.addptr %{{.*}}, {{.*}} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: %[[BUFFER_1:.*]] = triton_gpu.local_alloc -// CHECK: %[[SUBVIEW_1:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_1:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_1]] -// CHECK: %[[SUBVIEW_2:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_2:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_2]] +// CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc +// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] +// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] // CHECK: scf.for // CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] -// CHECK: %[[BUFFER_2:.*]] = triton_gpu.local_alloc %[[LOAD_1]] -// CHECK: %[[TRANS:.*]] = tt.trans %[[BUFFER_2]] -// CHECK: %[[LOCAL_LOAD_1:.*]] = triton_gpu.local_load %[[TRANS]] -// CHECK: triton_gpu.async_wait -// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[BUFFER_2:.*]] = ttg.local_alloc %[[LOAD_1]] +// CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[BUFFER_2]] +// CHECK: %[[LOCAL_LOAD_1:.*]] = ttg.local_load %[[TRANS]] +// CHECK: ttg.async_wait +// CHECK: ttg.memdesc_subview %[[BUFFER_1]] // CHECK: scf.for -// CHECK: %[[LOCAL_LOAD_2:.*]] = triton_gpu.local_load +// CHECK: %[[LOCAL_LOAD_2:.*]] = ttg.local_load // CHECK: %[[DOT:.*]] = tt.dot %[[LOCAL_LOAD_2]], %[[LOCAL_LOAD_1]] -// CHECK: %[[CONVERT_LAYOUT_3:.*]] = triton_gpu.convert_layout %[[DOT]] -// CHECK: %[[SUBVIEW_4:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_3:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] -// CHECK: triton_gpu.async_commit_group %[[ASYNC_COPY_3]] -// CHECK: triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[SUBVIEW_6:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_4:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask -// CHECK: %[[COMMIT_1:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_4]] -// CHECK: %[[SUBVIEW_7:.*]] = triton_gpu.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_5:.*]] = triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask -// CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] +// CHECK: %[[CONVERT_LAYOUT_3:.*]] = ttg.convert_layout %[[DOT]] +// CHECK: %[[SUBVIEW_4:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_3:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_3]] +// CHECK: ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[SUBVIEW_6:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_4:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask +// CHECK: %[[COMMIT_1:.*]] = ttg.async_commit_group %[[ASYNC_COPY_4]] +// CHECK: %[[SUBVIEW_7:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_5:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask +// CHECK: %[[COMMIT_2:.*]] = ttg.async_commit_group %[[ASYNC_COPY_5]] // CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] -// CHECK: triton_gpu.local_dealloc %[[BUFFER_1]] +// CHECK: ttg.local_dealloc %[[BUFFER_1]] // AMD-LABEL: tt.func public @nested_loops -// AMD-NOT: triton_gpu.local_alloc +// AMD-NOT: ttg.local_alloc // AMD: scf.for -// AMD: triton_gpu.local_alloc +// AMD: ttg.local_alloc // AMD: scf.for -// AMD: triton_gpu.local_load +// AMD: ttg.local_load // AMD: tt.dot -// AMD: triton_gpu.local_store +// AMD: ttg.local_store // AMD: scf.yield -// AMD: triton_gpu.local_dealloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +// AMD: ttg.local_dealloc + +// AMD_PREFETCH-LABEL: tt.func public @nested_loops +// AMD_PREFETCH-NOT: ttg.local_alloc +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: ttg.local_dealloc + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c0_i32 = arith.constant 0 : i32 %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> - %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> - %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> - %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> - %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %11 = ttg.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #smem> + %12 = ttg.memdesc_trans %11 {order = array} : !ttg.memdesc<16x16xf32, #shared, #smem> -> !ttg.memdesc<16x16xf32, #shared1, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<16x16xf32, #shared1, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> - %17 = triton_gpu.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + %15 = ttg.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %17 = ttg.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> tt.store %9, %17 : tensor<16x16x!tt.ptr, #blocked> } } @@ -1283,14 +1383,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // CHECK-LABEL: @int4_matmul_ampere -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> +#blocked5 = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { tt.func public @int4_matmul_ampere( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32} @@ -1308,14 +1408,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %cst_2 = arith.constant dense<4> : tensor<64x256xi8, #blocked> %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #mma> - %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> %38 = tt.broadcast %36 : tensor<1x128xi32, #blocked1> -> tensor<16x128xi32, #blocked1> %40 = tt.splat %arg0 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked1> %41 = tt.addptr %40, %38 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> - %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> %47 = tt.broadcast %43 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> %50 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> @@ -1323,9 +1423,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // Check that both loads in the loop are pipelined. // CHECK: scf.for // CHECK-NOT: tt.load - // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: ttg.async_copy_global_to_local // CHECK-NOT: tt.load - // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: ttg.async_copy_global_to_local // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { @@ -1339,9 +1439,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> - %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> + %88 = ttg.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %89 = ttg.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> %91 = tt.addptr %arg11, %cst_0 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> %92 = tt.addptr %arg12, %cst : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> scf.yield %90, %91, %92 : tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked> @@ -1356,16 +1456,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // This test triggered some failure in the verifier, so we only // included a simple check for the kernel name. // COMMON-LABEL: @load_convert_layout -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> - -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1390,8 +1490,8 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> - %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> - %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> @@ -1407,18 +1507,18 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 // This test captured some ICE in MatmulLoopPipeline pass, so we only // included a simple check for the kernel name. // COMMON-LABEL: @matmul_indirect_pipeline -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c0_i32 = arith.constant 0 : i32 - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %4 = tt.broadcast %2 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> %5 = tt.broadcast %3 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %6 = arith.addi %4, %5 : tensor<32x32xi32, #blocked> @@ -1427,20 +1527,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %9 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> %10 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> %11 = tt.addptr %10, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { - %15 = tt.load %13 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %17 = tt.load %16 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> + %15 = tt.load %13 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %17 = tt.load %16 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> %19 = tt.broadcast %18 : tensor<1x32xf32, #blocked> -> tensor<32x32xf32, #blocked> %20 = arith.addf %9, %19 : tensor<32x32xf32, #blocked> - %21 = triton_gpu.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %22 = triton_gpu.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %24 = triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %21 = ttg.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %22 = ttg.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> } {tt.num_stages = 3 : i32} tt.return @@ -1450,21 +1550,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // COMMON-LABEL: @dont_pipeline_128x1 -// COMMON-NOT: local_load{{.*}}128x1 -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// AMD-NOT: local_load{{.*}}128x1 +// CHECK: local_load{{.*}}128x1 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c128_i32 = arith.constant 128 : i32 %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 - %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> - %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> - %161 = triton_gpu.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> + %161 = ttg.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> @@ -1472,17 +1573,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : ^bb0(%arg33: f32, %arg34: f32): %207 = arith.maxnumf %arg33, %arg34 : f32 tt.reduce.return %207 : f32 - }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %202 = triton_gpu.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %202 = ttg.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> - %203 = arith.constant dense<0.> : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %203 = arith.constant dense<0.> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> - scf.yield %175 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + scf.yield %175 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> } tt.return } @@ -1493,18 +1594,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that the dependencies across ops of different nesting does not cause crash or // incorrect schedule that fails to pipeline. // COMMON-LABEL: @matmul_nested_ops -// COMMON: triton_gpu.local_load - -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> -#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}> -#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> -#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// COMMON: ttg.local_load + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, @@ -1531,7 +1632,7 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C>) { %cnd = arith.cmpi slt, %iv, %ext : index @@ -1542,7 +1643,7 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, scf.yield %a_ptr : tensor<128x32x!tt.ptr, #AL> } %a_ = tt.load %inc_a_ptr : tensor<128x32x!tt.ptr, #AL> - %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -1558,9 +1659,9 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // CHECK-LABEL: @masked_add_kernel // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> // CHECK: scf.for -// CHECK: %[[A:.*]] = triton_gpu.local_load +// CHECK: %[[A:.*]] = ttg.local_load // CHECK: arith.select {{.*}}, %[[A]], %[[CONSTANT]] -// CHECK: %[[B:.*]] = triton_gpu.local_load +// CHECK: %[[B:.*]] = ttg.local_load // CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] // AMD-LABEL: @masked_add_kernel @@ -1570,13 +1671,28 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: scf.for -// AMD: arith.select -// AMD: arith.addf // AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] - -#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// AMD: arith.addf +// AMD: arith.select +// AMD: scf.yield + +// AMD_PREFETCH-LABEL: @masked_add_kernel +// AMD_PREFETCH: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD_PREFETCH-COUNT-6: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: arith.addf +// AMD_PREFETCH: arith.select +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: tt.store + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 diff --git a/test/TritonGPU/loop-schedule.mlir b/test/TritonGPU/loop-schedule.mlir new file mode 100644 index 000000000000..afd4ec75db54 --- /dev/null +++ b/test/TritonGPU/loop-schedule.mlir @@ -0,0 +1,61 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#CLs0 = #ttg.slice<{parent=#C, dim=0}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABLE: @matmul_loop_load_acc +// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} +// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} +// CHECK: tt.load %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}, + %c_init: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> { + + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #C> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #CLs0> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32, #CLs0> -> tensor<1x128xi32, #C> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32, #C> -> tensor<128x128xi32, #C> + %c_ptr_init = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr, #C>, tensor<128x128xi32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c_off = arith.constant dense<4> : tensor<128x128xi32, #C> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c_ptr = %c_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128x!tt.ptr, #C>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %c_ = tt.load %c_ptr : tensor<128x128x!tt.ptr, #C> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_c_ptr = tt.addptr %c_ptr, %c_off : tensor<128x128x!tt.ptr, #C>, tensor<128x128xi32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_c_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128x!tt.ptr, #C>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} +} diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir new file mode 100644 index 000000000000..f8042feee9bf --- /dev/null +++ b/test/TritonGPU/matmul-loop-pipeline.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @softmax_kernel +tt.func public @softmax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = tt.get_num_programs x : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> + %3 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked> + // CHECK: [[MASK:%.*]] = arith.cmpi slt, {{.*}} tensor<128xi32, + %4 = arith.cmpi slt, %2, %3 : tensor<128xi32, #blocked> + // CHECK: scf.for + scf.for %arg6 = %0 to %arg4 step %1 : i32 { + %5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + // CHECK: [[RESULT:%.*]] = ttg.local_load + // CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst + %7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked> + %8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %9 = tt.addptr %8, %2 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + tt.store %9, %7, %4 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr, #blocked> + } {tt.num_stages = 2 : i32} + tt.return +} + +} diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir index 8676987ea0b0..50b90037e7ae 100644 --- a/test/TritonGPU/matmul.mlir +++ b/test/TritonGPU/matmul.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=target=cuda:80 -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -canonicalize -test-print-allocation 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=target=cuda:80 -tritongpu-remove-layout-conversions -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize -test-print-allocation 2>&1 | FileCheck %s // CHECK: offset = 0, size = 32768 // CHECK: offset = 32768, size = 32768 diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index 9184a5312020..0262bad35227 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -1,35 +1,61 @@ // RUN: triton-opt --split-input-file %s | FileCheck %s -// CHECK: #[[$WMMA_GEN1:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 1{{.*}}> -// CHECK: #[[$WMMA_GEN2:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 2{{.*}}> -#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +// CHECK: #[[$WMMA_GEN1:.*]] = #ttg.amd_wmma<{{.*}}version = 1{{.*}}> +// CHECK: #[[$WMMA_GEN2:.*]] = #ttg.amd_wmma<{{.*}}version = 2{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_layout tt.func @wmma_layout(%0: tensor<16x16xf16, #blocked>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> tt.return } // CHECK-LABEL: wmma_dot_op_layout - tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>, kWidth = 16}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> + tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>, kWidth = 16}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_gen2_layout tt.func @wmma_gen2_layout(%0: tensor<16x16xf16, #blocked>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> tt.return } // CHECK-LABEL: wmma_gen2_dot_op_layout - tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> + tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> + tt.return + } +} +// ----- + +#blocked= #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[$LINEAR:.*]] = #ttg.linear<{{.*}}> + +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @blocked_to_linear + tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { + // The layout is the basic layout generated by DecomposeScaledBlocked + %output = ttg.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> + tt.return + } +} + +// ----- + +#shared0 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: memdesc + // CHECK-SAME: !ttg.memdesc<1x64x16xf16, #{{.+}}> + tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0, #smem>) { tt.return } } diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 544299867190..25fa8fbcb596 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -10,11 +10,11 @@ // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @negative_zero_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -23,16 +23,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -40,11 +40,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -63,11 +63,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.addf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @positive_zero_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -76,16 +76,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -93,11 +93,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -112,11 +112,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf // CHECK: arith.addf // CHECK-NEXT: scf.yield -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] -#blocked3d = #triton_gpu.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#slice2d = #triton_gpu.slice<{dim = 2, parent = #blocked3d}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked3d = #ttg.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#slice2d = #ttg.slice<{dim = 2, parent = #blocked3d}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @slice_layout( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -125,16 +125,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> -> tensor<1x128xi32, #slice2d> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> -> tensor<1x128xi32, #slice2d> %31 = tt.broadcast %30 : tensor<1x128xi32, #slice2d> -> tensor<32x128xi32, #slice2d> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #slice2d>, tensor<32x128xi32, #slice2d> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #slice2d> @@ -142,11 +142,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -161,11 +161,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.addf // CHECK: arith.addf // CHECK-NEXT: scf.yield -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mma_layout( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -174,16 +174,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma> %31 = tt.broadcast %30 : tensor<1x128xi32, #mma> -> tensor<32x128xi32, #mma> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #mma>, tensor<32x128xi32, #mma> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #mma> @@ -191,11 +191,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.addf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -214,11 +214,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.maximumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @max_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -227,16 +227,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -244,11 +244,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -268,11 +268,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.maximumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.maximumf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @max_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -281,16 +281,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -298,11 +298,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -321,11 +321,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.minimumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @min_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -334,16 +334,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0x7F800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0x7F800000> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -351,11 +351,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.minimumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -375,11 +375,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.minimumf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.minimumf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @min_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -388,16 +388,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -405,11 +405,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.minimumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -428,11 +428,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.mulf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mul_reduce( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -441,16 +441,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -458,11 +458,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.mulf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.mulf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -482,11 +482,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-NEXT: scf.yield // CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> // CHECK: arith.mulf -// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] // CHECK: arith.mulf %[[CVT_OUTPUT]], %[[CST]] -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @mul_reduce_zero_int_accumulator( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -495,16 +495,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -512,11 +512,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.mulf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.mulf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -534,9 +534,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: arith.maximumf // CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] // CHECK-NEXT: scf.yield -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @remains_unchanged( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -545,16 +545,16 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %11: i32 {tt.divisibility = 16 : i32}, %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %c128_i32 = arith.constant 128 : i32 %1 = tt.get_program_id y : i32 %2 = tt.get_num_programs y : i32 - %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { %27 = arith.muli %arg3, %c128_i32 : i32 - %28 = tt.splat %27 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> @@ -563,11 +563,11 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : ^bb0(%arg5: f32, %arg6: f32): %36 = arith.maximumf %arg5, %arg6 : f32 tt.reduce.return %36 : f32 - }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - %26 = triton_gpu.convert_layout %19 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> tt.return } @@ -575,32 +575,32 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-DAG: #[[$BLOCK0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCK1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> -// CHECK-DAG: #[[$BLOCK2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK-DAG: #[[$BLOCK0:.+]] = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK2:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> // CHECK-LABEL: optimize_view_layout // CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> // CHECK: "tt.reduce"(%[[C]]) -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> { +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> { %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = arith.maximumf %arg1, %arg2 : f32 tt.reduce.return %2 : f32 - }) : (tensor<64x16xf32, #blocked1>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - tt.return %1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + }) : (tensor<64x16xf32, #blocked1>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + tt.return %1 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> -#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func public @reduce_for_arg(%arg: tensor<64x128xf32, #blocked>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c128_i32 = arith.constant 128 : i32 @@ -611,8 +611,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : ^bb0(%arg31: f32, %arg32: f32): %160 = arith.maxnumf %arg31, %arg32 : f32 tt.reduce.return %160 : f32 - }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %75 = triton_gpu.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> + }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %75 = ttg.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> %80 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> %81 = tt.addptr %80, %79 : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> diff --git a/test/TritonGPU/optimize_epilogue.mlir b/test/TritonGPU/optimize_epilogue.mlir index d990b14e8507..142ec762fb18 100644 --- a/test/TritonGPU/optimize_epilogue.mlir +++ b/test/TritonGPU/optimize_epilogue.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-optimize-epilogue | FileCheck --check-prefixes=GCN %s -#mfma = #triton_gpu.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // GCN-LABEL: mfma_epilogue_simple // CHECK-LABEL: mfma_epilogue_simple tt.func public @mfma_epilogue_simple(%data: tensor<64x64xf16, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { - // GCN: [[PTR:%[a-z0-9]+]] = triton_gpu.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> - %converted_data = triton_gpu.convert_layout %data : tensor<64x64xf16, #mfma> -> tensor<64x64xf16, #blocked> + %converted_data = ttg.convert_layout %data : tensor<64x64xf16, #mfma> -> tensor<64x64xf16, #blocked> tt.store %ptr, %converted_data : tensor<64x64x!tt.ptr, #blocked> tt.return } @@ -16,15 +16,15 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // ----- -#mfma = #triton_gpu.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> -#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { // GCN-LABEL: mfma_epilogue_chained_elementwise // CHECK-LABEL: mfma_epilogue_chained_elementwise tt.func public @mfma_epilogue_chained_elementwise(%data: tensor<64x64xf32, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { - // GCN: [[PTR:%[a-z0-9]+]] = triton_gpu.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> - %converted_data = triton_gpu.convert_layout %data : tensor<64x64xf32, #mfma> -> tensor<64x64xf32, #blocked> + %converted_data = ttg.convert_layout %data : tensor<64x64xf32, #mfma> -> tensor<64x64xf32, #blocked> %trunked = arith.truncf %converted_data : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> tt.store %ptr, %trunked : tensor<64x64x!tt.ptr, #blocked> tt.return diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir new file mode 100644 index 000000000000..9ff318b77983 --- /dev/null +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -0,0 +1,376 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-assign-latencies=num-stages=3 -canonicalize | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared2 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @default_stages +tt.func @default_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @small_load +// We should *not* assign latency to the load of b_ptr. +tt.func @small_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} + // CHECK-NOT: tt.latency + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @load_into_shared +tt.func @load_into_shared(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @load_into_shared_incompat_layout +tt.func @load_into_shared_incompat_layout(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // CHECK: tt.load + // CHECK-NOT: {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load +tt.func @indirect_load(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @mixed_loads +tt.func @mixed_loads(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @per_loop_stages +tt.func @per_loop_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> (tensor<128x128xf32, #C>, tensor<128x128xf32, #C>) { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop_cust_stages:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 4 : i32} + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop_cust_stages#2, %loop#2: tensor<128x128xf32, #C>, tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load_cust_stages +tt.func @indirect_load_cust_stages(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 5 : i32} + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load_few_stages +tt.func @indirect_load_few_stages(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load + // CHECK-NOT: tt.latency + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load + // CHECK-NOT: tt.latency + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 2 : i32} + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @non_dot_pipeline +tt.func @non_dot_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A> + } {tt.num_stages = 3 : i32} + tt.return %loop#1: tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @no_pipeline +tt.func @no_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A>) { + // CHECK: tt.load + // CHECK-NOT: tt.latency + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A> + } + tt.return %loop#1: tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use_cust_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 3 : i32} + tt.return %loop#2: tensor<128x128xf32, #C> +} + +} diff --git a/test/TritonGPU/pipeline-schedule-loop.mlir b/test/TritonGPU/pipeline-schedule-loop.mlir new file mode 100644 index 000000000000..bd66562d528b --- /dev/null +++ b/test/TritonGPU/pipeline-schedule-loop.mlir @@ -0,0 +1,337 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-schedule-loop -canonicalize | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @one_dep +tt.func @one_dep(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps +tt.func @parallel_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps_uneven1 +tt.func @parallel_deps_uneven1(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32} + %b = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps_uneven2 +tt.func @parallel_deps_uneven2(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32} + %a = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @direct_deps +tt.func @direct_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A>) { + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_next {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @dist1_deps +tt.func @dist1_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @prologue_if +tt.func @prologue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: scf.if + // CHECK: {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %a_ptr = scf.if %cnd -> tensor<128x32x!tt.ptr, #A> { + %a_ptr_ret = tt.addptr %a_ptr_init, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + scf.yield %a_ptr_ret : tensor<128x32x!tt.ptr, #A> + } else { + scf.yield %a_ptr_init : tensor<128x32x!tt.ptr, #A> + } + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @independent_epilogue_if +tt.func @independent_epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: scf.if + // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32} + scf.if %cnd { + tt.store %a_ptr_init, %init : tensor<128x32x!tt.ptr, #A> + } + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @independent_last_stage +tt.func @independent_last_stage(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %acc2 = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res2 = arith.addf %acc2, %init : tensor<128x32xf16, #A> + scf.yield %res, %res2 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @basic_pipeline +tt.func @basic_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @unpipelined_load +tt.func @unpipelined_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // load below should be in the same stage as tt.dot (not pipelined) + // CHECK: tt.load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // addptr below should be scheduled to the last stage + // CHECK: tt.addptr {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @epilogue_if +tt.func @epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>, + %c_ptr_store : tensor<128x128x!tt.ptr, #C>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: scf.if + // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32} + scf.if %cnd { + tt.store %c_ptr_store, %c : tensor<128x128x!tt.ptr, #C> + } + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: arith.mulf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load +tt.func @indirect_load(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + %a_off = tt.load %a_ind_ptr {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #AL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // addptr below scheduled by scheduleDependencies to the same stage as tt.load that is using it + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %a_ = tt.load %next_a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %b_ = tt.load %next_b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} +} diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 9fbc540b92a6..208516b3bfab 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -2,38 +2,38 @@ // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> - +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] -// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] -// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> @@ -48,24 +48,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !tt.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -75,20 +75,20 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // matmul: 128x16 @ 16x128 -> 128x128 // CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x16x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<16x128x!tt.ptr, #BL> @@ -103,24 +103,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<16x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !tt.memdesc<128x16xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !tt.memdesc<16x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !tt.memdesc<128x16xf8E5M2, #A>, !tt.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x16xf8E5M2, #A> -> tensor<128x16xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x16xf8E5M2, #A, #smem> -> tensor<128x16xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x16xf8E5M2, #A_OP> -> tensor<128x16xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !tt.memdesc<16x128xf16, #B> -> tensor<16x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<16x128xf16, #B, #smem> -> tensor<16x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x16xf16, #A_OP> * tensor<16x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x16x!tt.ptr, #AL>, tensor<128x16xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<16x128x!tt.ptr, #BL>, tensor<16x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !tt.memdesc<128x16xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !tt.memdesc<16x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !tt.memdesc<128x16xf8E5M2, #A>, !tt.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -132,10 +132,10 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK: scf.if // CHECK: tt.store // CHECK-NOT: scf.yield -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { tt.func @matmul_loop_yield_no_operand(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c32_i32 = arith.constant 32 : i32 @@ -155,19 +155,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %10 = arith.remsi %9, %2 : i32 %11 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> %12 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %14 = triton_gpu.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %14 = ttg.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %16 = arith.cmpi sgt, %10, %c0_i32 : i32 %17 = scf.if %16 -> (tensor<32x32xf32, #mma>) { - %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> scf.yield %21 : tensor<32x32xf32, #mma> } else { scf.yield %15 : tensor<32x32xf32, #mma> } %18 = tt.splat %arg5 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> %19 = arith.truncf %17 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> - %20 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> + %20 = ttg.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> tt.store %18, %20 : tensor<32x32x!tt.ptr, #blocked1> } tt.return @@ -176,37 +176,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> -#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> -#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed_amd // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 -// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] // CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] -// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] -// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] -// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] -// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] // CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] -// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] -// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] -// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] -// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] // CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] -module attributes { "triton_gpu.num-warps" = 4 : i32 } { +module attributes { "ttg.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> @@ -221,27 +222,25 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = triton_gpu.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = triton_gpu.local_load %b : !tt.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module - -// ----- diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index 9fca92c9b099..e293ab724847 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -1,14 +1,14 @@ // RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s -// CHECK: #[[$SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} +// CHECK: #[[$SHARED:.*]] = #ttg.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} // CHECK-LABEL: apply_swizzle -// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[$SHARED]], #triton_gpu.shared_memory> +// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #smem> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %0 = ttg.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } @@ -16,13 +16,13 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- // CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32 -// CHECK-NOT: triton_gpu.local_alloc -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.local_alloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } @@ -30,13 +30,13 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- // CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64 -// CHECK-NOT: triton_gpu.local_alloc -// CHECK: triton_gpu.convert_layout -// CHECK-NOT: triton_gpu.local_alloc -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.target" = "hip:gfx940", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "hip:gfx940", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) { - %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> tt.return } } diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index dff1e6b60f8c..700ed22be2b1 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -2,21 +2,22 @@ // check that we don't hoist convert_layout above its operand definition. // CHECK-LABEL: convert_cannot_hoist -// CHECK: %[[CVTS:.+]] = triton_gpu.local_alloc -// CHECK: triton_gpu.local_load %[[CVTS]] +// CHECK: %[[CVTS:.+]] = ttg.local_alloc +// CHECK: ttg.local_load %[[CVTS]] // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_hoist(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %10 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -25,21 +26,22 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // CHECK-LABEL: sink_convert_dealloc -// CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: %3 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: %3 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> - %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> - triton_gpu.async_wait {num = 0 : i32} - triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> - triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } @@ -48,24 +50,25 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- // CHECK-LABEL: sink_convert_idx_1 -// CHECK: triton_gpu.local_load %{{.*}} : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -// CHECK: triton_gpu.local_load %{{.*}} : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_idx_1(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = triton_gpu.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %BD = triton_gpu.local_load %BS : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS = triton_gpu.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %AD = triton_gpu.local_load %AS : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } @@ -75,28 +78,29 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // check that we don't sink convert_layout if it has multi users // CHECK-LABEL: convert_cannot_sink -// CHECK: triton_gpu.local_load %{{.*}} : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -// CHECK: triton_gpu.local_load %{{.*}} : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -// CHECK: triton_gpu.local_load %{{.*}} : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_sink(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = triton_gpu.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %BD = triton_gpu.local_load %BS : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A0 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS0 = triton_gpu.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %AD0 = triton_gpu.local_load %AS0 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %AS0 = ttg.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD0 = ttg.local_load %AS0 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %A1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS1 = triton_gpu.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> - %AD1 = triton_gpu.local_load %AS1 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %AS1 = ttg.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD1 = ttg.local_load %AS1 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> tt.return } } diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir new file mode 100644 index 000000000000..2e95f5024f41 --- /dev/null +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -0,0 +1,180 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s +// CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_8:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_11:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_15:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_20:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_3]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_22:.*]] = arith.divsi %[[VAL_21]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_4]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_24:.*]] = arith.divsi %[[VAL_23]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_24]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_26]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_22]], %[[VAL_27]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.minsi %[[VAL_28]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.remsi %[[VAL_20]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_27]], %[[VAL_30]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.remsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_33:.*]] = arith.divsi %[[VAL_32]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_34:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_37:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_38:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_37]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_31]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_40:.*]] = arith.muli %[[VAL_33]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_42:.*]] = arith.divsi %[[VAL_41]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_46]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_47]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_48]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_52:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_52]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_54:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_54]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_53]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_55:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_56]], 49152, %[[VAL_55]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_57:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_58:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_58]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_59:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_60:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_60]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_59]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_61:.*]]:5 = scf.for %[[VAL_62:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_63:.*]] = %[[VAL_19]], %[[VAL_64:.*]] = %[[VAL_13]], %[[VAL_65:.*]] = %[[VAL_15]], %[[VAL_66:.*]] = %[[VAL_8]], %[[VAL_67:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_68:.*]] = arith.subi %[[VAL_42]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_69:.*]] = arith.cmpi slt, %[[VAL_62]], %[[VAL_68]] : i32 +// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_66]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_71:.*]] = arith.cmpi slt, %[[VAL_70]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_71]], %[[VAL_70]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_73:.*]] = arith.xori %[[VAL_67]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_67]], %[[VAL_73]] : i32 +// CHECK: %[[VAL_75:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_72]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_75]], %[[VAL_74]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_76:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_77:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_78:.*]] = ttg.memdesc_trans %[[VAL_76]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_79:.*]] = ttng.warp_group_dot %[[VAL_77]], %[[VAL_78]], %[[VAL_63]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_80:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_79]], %[[VAL_77]], %[[VAL_78]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_64]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_65]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi slt, %[[VAL_82]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_82]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_85:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_84]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_85]], 49152, %[[VAL_69]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_87:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_87]]{{\[}}%[[VAL_39]], %[[VAL_81]]] %[[VAL_86]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_89:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_89]]{{\[}}%[[VAL_40]], %[[VAL_81]]] %[[VAL_88]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_81]], %[[VAL_84]], %[[VAL_72]], %[[VAL_74]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_90:.*]] = ttng.warp_group_dot_wait %[[VAL_91:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_92:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_93]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_94]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_95]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_96:.*]] = arith.truncf %[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]> +// CHECK: %[[VAL_97:.*]] = ttg.convert_layout %[[VAL_96]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.experimental_descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_97]] : !tt.tensordesc>, tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in new file mode 100644 index 000000000000..4ab61167755c --- /dev/null +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in @@ -0,0 +1,57 @@ +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/test/TritonGPU/tritongpu_ops.mlir b/test/TritonGPU/tritongpu_ops.mlir deleted file mode 100644 index d5c6a52e8eca..000000000000 --- a/test/TritonGPU/tritongpu_ops.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: triton-opt %s | triton-opt | FileCheck %s - -#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> - -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: memdesc - // CHECK-SAME: !tt.memdesc<1x64x16xf16, #{{.+}}> - tt.func @memdesc(%d : !tt.memdesc<1x64x16xf16, #shared0>) { - tt.return - } -} diff --git a/test/TritonGPU/verify-blocked-layout.mlir b/test/TritonGPU/verify-blocked-layout.mlir index ec39b26d10cd..3c1d016cd5e8 100644 --- a/test/TritonGPU/verify-blocked-layout.mlir +++ b/test/TritonGPU/verify-blocked-layout.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[16, 1], warpsPerCTA=[4, 1], @@ -10,9 +10,9 @@ CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{threads per warp}} @@ -23,7 +23,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 2], @@ -33,9 +33,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{warps per CTA}} @@ -46,7 +46,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -56,9 +56,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // expected-error @+1 {{CTAs per CGA}} @@ -69,7 +69,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -79,9 +79,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: !tt.ptr) { // Note it's a 3d tensor here, but #blocked is 2D. @@ -93,7 +93,7 @@ module attributes { // ----- -#blocked = #triton_gpu.blocked<{ +#blocked = #ttg.blocked<{ sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], @@ -103,9 +103,9 @@ module attributes { CTAOrder=[0, 1] }> module attributes { - "triton_gpu.num-warps" = 4 : i32, - "triton_gpu.num-ctas" = 2 : i32, - "triton_gpu.threads-per-warp" = 32 : i32 + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 } { tt.func public @fn(%arg0: tensor<8xf32, #blocked>) { // expected-error @+1 {{rank}} diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 358f53fd7cd6..a042e282b374 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -1,25 +1,27 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering --allocate-shared-memory -test-print-membar | FileCheck %s -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier tt.func @init_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: inval_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -28,18 +30,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: inval_barrier tt.func @inval_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.inval_barrier %alloc : !ttg.memdesc<1xi64, #shared0, #smem, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: barrier_expect // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -48,18 +51,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: barrier_expect tt.func @barrier_expect(%pred : i1) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #smem, mutable> tt.return } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier // CHECK: local_alloc // CHECK-NEXT: gpu.barrier @@ -68,9 +72,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: wait_barrier tt.func @wait_barrier(%phase : i32) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.wait_barrier %alloc, %phase : <1xi64, #shared0, #smem, mutable> tt.return } } @@ -78,10 +82,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @tma_load(%arg0: !tt.ptr, %arg1: i32) -> tensor<128x64xf16, #blocked0> { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { // CHECK-LABEL: tma_load // CHECK: local_dealloc // CHECK-NEXT: local_alloc @@ -89,28 +94,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } } // ----- -#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store -// CHECK: triton_gpu.local_alloc -// CHECK-NEXT: triton_gpu.local_dealloc +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.local_dealloc // CHECK-NEXT: gpu.barrier -// CHECK-NEXT: triton_gpu.local_alloc - tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { +// CHECK-NEXT: ttg.local_alloc + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked0> tt.return } } diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index 8b067a260bd6..dbde678e550d 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -1,31 +1,54 @@ // RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_load -// CHECK: triton_gpu.local_alloc : () -// CHECK: triton_gpu.local_alloc : () -// CHECK: triton_nvidia_gpu.init_barrier -// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local -// CHECK: triton_nvidia_gpu.wait_barrier -// CHECK: triton_nvidia_gpu.inval_barrier -// CHECK: triton_gpu.local_load - tt.func public @tma_load(%arg0: !tt.ptr, %arg1: i32) -> tensor<128x64xf16, #blocked> { - %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked> +// CHECK: ttg.local_alloc : () +// CHECK: ttg.local_alloc : () +// CHECK: ttng.init_barrier +// CHECK: ttng.tensor_desc_to_tma_ptr +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK: ttng.inval_barrier +// CHECK: ttg.local_load + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked> { + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> tt.return %l : tensor<128x64xf16, #blocked> } } // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store -// CHECK: triton_gpu.local_alloc -// CHECK: triton_nvidia_gpu.fence_async_shared {bCluster = false} -// CHECK: triton_nvidia_gpu.async_tma_copy_local_to_global - tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { - tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked> +// CHECK: ttg.local_alloc +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: ttng.tensor_desc_to_tma_ptr +// CHECK: ttng.async_tma_copy_local_to_global + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked> tt.return } } + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: make_tensor_descriptor + // CHECK: %0 = arith.extsi %arg2 : i32 to i64 + // CHECK: %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + // CHECK: %2 = arith.shrsi %0, %c4_i64 : i64 + // CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%2], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () + // CHECK: tt.experimental_tensormap_fenceproxy_acquire %1 : !tt.ptr + // CHECK: tt.reinterpret_tensor_descriptor %1 : !tt.ptr to !tt.tensordesc> + tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc> { + %c1_i64 = arith.constant 1 : i64 + %cst = arith.constant dense<32> : tensor<8x1xi32> + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = arith.extsi %arg2 : i32 to i64 + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr, !tt.tensordesc> + tt.return %1 : !tt.tensordesc> + } +} diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp index 772e0258bf78..e7245e75cbbf 100644 --- a/test/lib/Analysis/TestAllocation.cpp +++ b/test/lib/Analysis/TestAllocation.cpp @@ -5,21 +5,42 @@ using namespace mlir; namespace { +unsigned getScratchSize128(Operation *) { return 128; } + +enum class GetScratchSizeFunction { + None, + ValidConstant, +}; + struct TestAllocationPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); + TestAllocationPass() = default; + TestAllocationPass(const TestAllocationPass &other) + : PassWrapper>(other) {} + StringRef getArgument() const final { return "test-print-allocation"; } StringRef getDescription() const final { return "print the result of the allocation pass"; } + ModuleAllocation getModuleAllocation() { + switch (getScratchSizeFunction) { + case GetScratchSizeFunction::None: + return {getOperation()}; + case GetScratchSizeFunction::ValidConstant: + return {getOperation(), getScratchSize128}; + } + llvm_unreachable("Unhandled case"); + } + void runOnOperation() override { auto &os = llvm::errs(); ModuleOp moduleOp = getOperation(); // Convert to std::string can remove quotes from opName - ModuleAllocation moduleAllocation(moduleOp); + ModuleAllocation moduleAllocation = getModuleAllocation(); moduleOp.walk([&](triton::FuncOp funcOp) { auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); os << opName << "\n"; @@ -48,6 +69,15 @@ struct TestAllocationPass os << "size = " << allocation->getSharedMemorySize() << "\n"; }); } + + Option getScratchSizeFunction{ + *this, "get-scratch-size-function", + llvm::cl::desc("Custom scratch size function to use"), + llvm::cl::init(GetScratchSizeFunction::None), + llvm::cl::values( + clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"), + clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant", + "ValidConstant"))}; }; } // namespace diff --git a/test/lib/Instrumentation/GPUHello.cpp b/test/lib/Instrumentation/GPUHello.cpp index 3bee8ce90ced..5c71857c8f36 100644 --- a/test/lib/Instrumentation/GPUHello.cpp +++ b/test/lib/Instrumentation/GPUHello.cpp @@ -61,7 +61,7 @@ bool GpuHello::runOnModule(Module &module) { PassPluginLibraryInfo getPassPluginInfo() { const auto callback = [](PassBuilder &pb) { - pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto) { + pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) { mpm.addPass(GpuHello()); return true; }); diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 4053a8e7df56..fd1fb486dd3c 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -2,7 +2,7 @@ import sys -config.triton_obj_root = "@TRITON_BINARY_DIR@" +config.triton_obj_root = "@triton_BINARY_DIR@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" @@ -20,4 +20,4 @@ import lit.llvm lit.llvm.initialize(lit_config, config) # Let the main config do the real work -lit_config.load_config(config, "@TRITON_SOURCE_DIR@/test/lit.cfg.py") +lit_config.load_config(config, "@triton_SOURCE_DIR@/test/lit.cfg.py") diff --git a/third_party/amd/CMakeLists.txt b/third_party/amd/CMakeLists.txt index 8228c3d39111..a09bab8e1c4c 100644 --- a/third_party/amd/CMakeLists.txt +++ b/third_party/amd/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) + target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers) endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index a53a06dd4248..81b07f2e7d86 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -48,11 +48,23 @@ class HIPOptions: backend_name: str = 'hip' # The following option provides hints to the AMDGPU backend regarding instruction scheduling - # for all `tt.dot` operations in a kernel. The "default" variant preserves the default + # for all `tt.dot` operations in a kernel. The "none" variant preserves the default # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy. # The option is experimental and may change at any time regarding its semantics and/or may # be gone entirely anytime. - instruction_sched_variant: str = 'default' + # + # Current experimental scheduling variants: + # + # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's + # k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels". + # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's + # k-loop; i.e., "interleave DS and MFMA instructions for single wave small + # GEMM kernels.". + # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable + # Kernel library. Note, this variant requires the use of buffer load/store ops + # and a special software pipelining style - i.e., 1x LDS and 1x register + # prefetch buffers for each GEMM tile. + instruction_sched_variant: str = 'none' def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' @@ -127,7 +139,7 @@ def parse_options(self, opts) -> Any: if "supported_fp8_dtypes" not in opts: supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes) if self.target.arch in ('gfx940', 'gfx941', 'gfx942'): - supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'}) + supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'}) args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) if "enable_fp_fusion" not in opts: @@ -189,8 +201,8 @@ def make_ttir(mod, metadata, options): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_rewrite_tensor_pointer(pm) - passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) passes.common.add_licm(pm) @@ -215,13 +227,21 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) + + stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" + use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1" + + # The `local-prefetch` scheduling variant requires turning on buffer ops. + if options.instruction_sched_variant == "local-prefetch": + stream_prefetch = use_buffer_ops = True + if amd.has_matrix_core_feature(options.arch): assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. " "We used to trigger software pipelining with " "num_stages == 0. Now it will not happen anymore; " "please update to use num_stages == 2 for " "equivalent behavior in the past.") - amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages) + amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) @@ -229,7 +249,11 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_reduce_data_duplication(pm) if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) - amd.passes.ttgpuir.add_canonicalize_pointers(pm) + + if use_buffer_ops: + amd.passes.ttgpuir.add_canonicalize_pointers(pm) + passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -271,15 +295,11 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages, + options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) - # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block - # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR - # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration - # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need - # for conditional branching around memory accesses. - amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm) + amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ) pm.run(mod) # LLVM-IR (MLIR) -> LLVM-IR (LLVM) @@ -319,9 +339,12 @@ def make_llir(src, metadata, options): llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion) # Get some metadata - metadata["shared"] = src.get_int_attr("triton_gpu.shared") + metadata["shared"] = src.get_int_attr("ttg.shared") amd.cleanup_bitcode_metadata(llvm_mod) + # Disable inlining of print related functions, + # because inlining of these function could slow down compilation significantly + amd.disable_print_inline(llvm_mod) return str(llvm_mod) @staticmethod diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 6e1a368bf8cb..99e5509eca8d 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -220,7 +220,7 @@ def format_of(ty): "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", @@ -234,7 +234,8 @@ def format_of(ty): libhip_path = _get_path_to_hip_runtime_dylib() # generate glue code - params = [i for i in signature.keys() if i not in constants] + params = [f"&arg{i}" for i in signature.keys() if i not in constants] + params.append("&global_scratch") src = f""" #define __HIP_PLATFORM_AMD__ #include @@ -330,7 +331,8 @@ def format_of(ty): static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ // printf("_launch hip kernel\\n"); - void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + hipDeviceptr_t global_scratch = 0; + void *params[] = {{ {', '.join(params)} }}; if (gridX*gridY*gridZ > 0) {{ HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); }} diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h b/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h index 2a5cf48a0397..c4837ad64c4d 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h @@ -266,14 +266,14 @@ __device__ static inline int __mul24(int x, int y) { } __device__ static inline long long __mul64hi(long long int x, long long int y) { - ulong x0 = (ulong)x & 0xffffffffUL; - long x1 = x >> 32; - ulong y0 = (ulong)y & 0xffffffffUL; - long y1 = y >> 32; - ulong z0 = x0*y0; - long t = x1*y0 + (z0 >> 32); - long z1 = t & 0xffffffffL; - long z2 = t >> 32; + unsigned long long x0 = (unsigned long long)x & 0xffffffffUL; + long long x1 = x >> 32; + unsigned long long y0 = (unsigned long long)y & 0xffffffffUL; + long long y1 = y >> 32; + unsigned long long z0 = x0*y0; + long long t = x1*y0 + (z0 >> 32); + long long z1 = t & 0xffffffffL; + long long z2 = t >> 32; z1 = x0*y1 + z1; return x1*y1 + z2 + (z1 >> 32); } @@ -300,14 +300,14 @@ __device__ static inline int __umul24(unsigned int x, unsigned int y) { __device__ static inline unsigned long long __umul64hi(unsigned long long int x, unsigned long long int y) { - ulong x0 = x & 0xffffffffUL; - ulong x1 = x >> 32; - ulong y0 = y & 0xffffffffUL; - ulong y1 = y >> 32; - ulong z0 = x0*y0; - ulong t = x1*y0 + (z0 >> 32); - ulong z1 = t & 0xffffffffUL; - ulong z2 = t >> 32; + unsigned long long x0 = x & 0xffffffffUL; + unsigned long long x1 = x >> 32; + unsigned long long y0 = y & 0xffffffffUL; + unsigned long long y1 = y >> 32; + unsigned long long z0 = x0*y0; + unsigned long long t = x1*y0 + (z0 >> 32); + unsigned long long z1 = t & 0xffffffffUL; + unsigned long long z2 = t >> 32; z1 = x0*y1 + z1; return x1*y1 + z2 + (z1 >> 32); } @@ -322,11 +322,6 @@ __device__ static inline unsigned int __usad(unsigned int x, unsigned int y, uns return __ockl_sadd_u32(x, y, z); } -__device__ static inline unsigned int __lane_id() { - return __builtin_amdgcn_mbcnt_hi( - -1, __builtin_amdgcn_mbcnt_lo(-1, 0)); -} - __device__ static inline unsigned int __mbcnt_lo(unsigned int x, unsigned int y) {return __builtin_amdgcn_mbcnt_lo(x,y);}; @@ -339,6 +334,7 @@ HIP specific device functions #if !defined(__HIPCC_RTC__) #include "amd_warp_functions.h" +#include "amd_warp_sync_functions.h" #endif #define MASK1 0x00ff00ff @@ -687,34 +683,6 @@ void __named_sync() { __builtin_amdgcn_s_barrier(); } #endif // __HIP_DEVICE_COMPILE__ -// warp vote function __all __any __ballot -__device__ -inline -int __all(int predicate) { - return __ockl_wfall_i32(predicate); -} - -__device__ -inline -int __any(int predicate) { - return __ockl_wfany_i32(predicate); -} - -// XXX from llvm/include/llvm/IR/InstrTypes.h -#define ICMP_NE 33 - -__device__ -inline -unsigned long long int __ballot(int predicate) { - return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); -} - -__device__ -inline -unsigned long long int __ballot64(int predicate) { - return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); -} - // hip.amdgcn.bc - lanemask __device__ inline @@ -877,6 +845,10 @@ int __syncthreads_or(int predicate) #if (defined(__GFX10__) || defined(__GFX11__)) #define HW_ID_WGP_ID_SIZE 4 #define HW_ID_WGP_ID_OFFSET 10 + #if (defined(__AMDGCN_CUMODE__)) + #define HW_ID_CU_ID_SIZE 1 + #define HW_ID_CU_ID_OFFSET 8 + #endif #else #define HW_ID_CU_ID_SIZE 4 #define HW_ID_CU_ID_OFFSET 8 @@ -933,6 +905,10 @@ unsigned __smid(void) GETREG_IMMED(HW_ID_WGP_ID_SIZE - 1, HW_ID_WGP_ID_OFFSET, HW_ID)); unsigned sa_id = __builtin_amdgcn_s_getreg( GETREG_IMMED(HW_ID_SA_ID_SIZE - 1, HW_ID_SA_ID_OFFSET, HW_ID)); + #if (defined(__AMDGCN_CUMODE__)) + unsigned cu_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_CU_ID_SIZE - 1, HW_ID_CU_ID_OFFSET, HW_ID)); + #endif #else #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) unsigned xcc_id = __builtin_amdgcn_s_getreg( @@ -945,6 +921,9 @@ unsigned __smid(void) unsigned temp = se_id; temp = (temp << HW_ID_SA_ID_SIZE) | sa_id; temp = (temp << HW_ID_WGP_ID_SIZE) | wgp_id; + #if (defined(__AMDGCN_CUMODE__)) + temp = (temp << HW_ID_CU_ID_SIZE) | cu_id; + #endif return temp; //TODO : CU Mode impl #elif (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h b/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h index ef719f3713c6..d6e4d8186909 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h @@ -612,11 +612,17 @@ float atomicMin(float* addr, float val) { #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) return unsafeAtomicMin(addr, val); #else + typedef union u_hold { + float a; + unsigned int b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x80000000U == u.b; #if __has_builtin(__hip_atomic_load) && \ __has_builtin(__hip_atomic_compare_exchange_strong) float value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); bool done = false; - while (!done && value > val) { + while (!done && (value > val || (neg_zero && value == 0.0f))) { done = __hip_atomic_compare_exchange_strong(addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -625,7 +631,7 @@ float atomicMin(float* addr, float val) { unsigned int *uaddr = (unsigned int *)addr; unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); bool done = false; - while (!done && __uint_as_float(value) > val) { + while (!done && (__uint_as_float(value) > val || (neg_zero && __uint_as_float(value) == 0.0f))) { done = __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); } @@ -658,11 +664,17 @@ double atomicMin(double* addr, double val) { #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) return unsafeAtomicMin(addr, val); #else + typedef union u_hold { + double a; + unsigned long long b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x8000000000000000ULL == u.b; #if __has_builtin(__hip_atomic_load) && \ __has_builtin(__hip_atomic_compare_exchange_strong) double value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); bool done = false; - while (!done && value > val) { + while (!done && (value > val || (neg_zero && value == 0.0))) { done = __hip_atomic_compare_exchange_strong(addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -671,7 +683,8 @@ double atomicMin(double* addr, double val) { unsigned long long *uaddr = (unsigned long long *)addr; unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); bool done = false; - while (!done && __longlong_as_double(value) > val) { + while (!done && + (__longlong_as_double(value) > val || (neg_zero && __longlong_as_double(value) == 0.0))) { done = __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); } @@ -856,11 +869,17 @@ float atomicMax(float* addr, float val) { #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) return unsafeAtomicMax(addr, val); #else + typedef union u_hold { + float a; + unsigned int b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x80000000U == u.b; #if __has_builtin(__hip_atomic_load) && \ __has_builtin(__hip_atomic_compare_exchange_strong) float value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); bool done = false; - while (!done && value < val) { + while (!done && (value < val || (neg_zero && value == 0.0f))) { done = __hip_atomic_compare_exchange_strong(addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -869,7 +888,7 @@ float atomicMax(float* addr, float val) { unsigned int *uaddr = (unsigned int *)addr; unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); bool done = false; - while (!done && __uint_as_float(value) < val) { + while (!done && (__uint_as_float(value) < val || (neg_zero && __uint_as_float(value) == 0.0f))) { done = __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); } @@ -902,11 +921,17 @@ double atomicMax(double* addr, double val) { #if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) return unsafeAtomicMax(addr, val); #else + typedef union u_hold { + double a; + unsigned long long b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x8000000000000000ULL == u.b; #if __has_builtin(__hip_atomic_load) && \ __has_builtin(__hip_atomic_compare_exchange_strong) double value = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); bool done = false; - while (!done && value < val) { + while (!done && (value < val || (neg_zero && value == 0.0))) { done = __hip_atomic_compare_exchange_strong(addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -915,7 +940,8 @@ double atomicMax(double* addr, double val) { unsigned long long *uaddr = (unsigned long long *)addr; unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); bool done = false; - while (!done && __longlong_as_double(value) < val) { + while (!done && + (__longlong_as_double(value) < val || (neg_zero && __longlong_as_double(value) == 0.0))) { done = __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); } @@ -977,7 +1003,7 @@ unsigned int atomicDec(unsigned int* address, unsigned int val) #else return __builtin_amdgcn_atomic_dec32(address, val, __ATOMIC_RELAXED, "agent"); #endif // __gfx941__ - + } __device__ diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h b/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h index 204269a849c6..cfaa5412a3aa 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h @@ -1,7 +1,7 @@ /** * MIT License * - * Copyright (c) 2019 - 2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -81,6 +81,17 @@ * To use these functions, include the header file \p hip_bf16.h in your program. */ +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_RAW Bfloat16 Raw Struct + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_RAW Bfloat162 Raw Struct + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your program. + */ #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ @@ -93,13 +104,30 @@ #include "device_library_decls.h" // ocml conversion functions #include "math_fwd.h" // ocml device functions +#define __BF16_DEVICE__ __device__ #if defined(__HIPCC_RTC__) -#define __HOST_DEVICE__ __device__ static +#define __BF16_HOST_DEVICE__ __BF16_DEVICE__ #else #include #include #include -#define __HOST_DEVICE__ __host__ __device__ static inline +#define __BF16_HOST_DEVICE__ __host__ __BF16_DEVICE__ +#endif +#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline +#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline + +#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__) +// Enable with -mavx512vl -mavx512bf16 +#if defined(__MINGW64__) +#include +#else +#include +#endif +#define HIP_BF16_AVX512_OP 1 +static_assert(sizeof(__bf16) == sizeof(unsigned short), + "sizeof __bf16 should match sizeof unsigned short"); +#else +#define HIP_BF16_AVX512_OP 0 #endif #define HIPRT_ONE_BF16 __float2bfloat16(1.0f) @@ -118,72 +146,361 @@ static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); #endif static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes"); -/*! \brief Struct to represent a 16 bit brain floating point number. */ -struct __hip_bfloat16 { - unsigned short data; +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_RAW + * \brief represents raw bfloat16 type + */ +typedef struct __attribute__((aligned(2))) { + unsigned short x; +} __hip_bfloat16_raw; + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_RAW + * \brief represents raw bfloat16x2 vector type + */ +typedef struct __attribute__((aligned(4))) { + unsigned short x; + unsigned short y; +} __hip_bfloat162_raw; + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_STRUCT + * \ingroup HIP_INTRINSIC_BFLOAT16 + * \brief Struct to represent a 16 bit brain floating point number. + * @{ + */ +struct __attribute__((aligned(2))) __hip_bfloat16 { + private: + __BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) { +#if HIP_BF16_AVX512_OP + union { + unsigned short us; + __bf16 bf16; + } u = {val}; + return _mm_cvtsbh_ss(u.bf16); +#else + unsigned int uval = val << 16; + union { + unsigned int u32; + float fp32; + } u = {uval}; + return u.fp32; +#endif + } + __BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) { +#if HIP_BF16_AVX512_OP + union { + __bf16 bf16; + unsigned short us; + } u = {_mm_cvtness_sbh(f)}; + return u.us; +#else + union { + float fp32; + unsigned int u32; + } u = {f}; + if (~u.u32 & 0x7f800000) { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even + } else if (u.u32 & 0xffff) { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.u32 |= 0x10000; // Preserve signaling NaN + } + return static_cast(u.u32 >> 16); +#endif + } + + __BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) { + union { + float fp32; + unsigned int u32; + } u = {static_cast(d_in)}; + double d = u.fp32; + + // Round to odd + if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) { + u.u32--; + u.u32 |= 1; + } + + return float_2_bfloatraw(u.fp32); + } + + protected: + /*! \brief raw representation of bfloat16 */ + unsigned short __x; + + public: + // TODO: SWDEV-452411 + // Need to add constructor of __hip_bfloat16 from + // unsigned long long + // long long + // long + // unsigned long + // Casting directly to double might lead to double rounding. + + /*! \brief create __hip_bfloat16 from an unsigned int */ + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val) + : __x(double_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a int */ + __BF16_HOST_DEVICE__ __hip_bfloat16(int val) + : __x(double_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from an unsigned short */ + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val) + : __x(float_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a short */ + __BF16_HOST_DEVICE__ __hip_bfloat16(short val) + : __x(float_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a double */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x(double_2_bfloatraw(val)) {} + + /*! \brief create __hip_bfloat16 from a float */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x(float_2_bfloatraw(val)) {} + + /*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {} + + /*! \brief default constructor */ + __BF16_HOST_DEVICE__ __hip_bfloat16() = default; + + /*! \brief return a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const { return __hip_bfloat16_raw{__x}; } + + /*! \brief return a __hip_bfloat16_raw cv qualifier */ + __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const volatile { + return __hip_bfloat16_raw{__x}; + } + + /*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */ + __BF16_HOST_DEVICE__ operator bool() const { + auto val = bfloatraw_2_float(__x); + return val != 0.0f && val != -0.0f; + } + + /*! \brief return a casted char from underlying float val */ + __BF16_HOST_DEVICE__ operator char() const { return static_cast(bfloatraw_2_float(__x)); } + + /*! \brief return a float */ + __BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); } + + /*! \brief return a casted int casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator int() const { return static_cast(bfloatraw_2_float(__x)); } + + /*! \brief return a casted long casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator long() const { return static_cast(bfloatraw_2_float(__x)); } + + /*! \brief return a casted long long casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator long long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted short casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator short() const { return static_cast(bfloatraw_2_float(__x)); } + + /*! \brief return a casted signed char from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator signed char() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned char casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned char() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned int casted from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned int() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned long long from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned long long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned short from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned short() const { + return static_cast(bfloatraw_2_float(__x)); + } + + // TODO: SWDEV-452411 add operator which converts unsigned long long and long long to bfloat + + /*! \brief assign value from an unsigned int */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned int val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a int */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(int val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from an unsigned short */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned short val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a short int */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(short val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a double */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const double f) { + __x = float_2_bfloatraw(static_cast(f)); + return *this; + } + + /*! \brief assign value from a float */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const float f) { + __x = float_2_bfloatraw(f); + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw volatile */ + __BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=(const __hip_bfloat16_raw& hr) volatile { + __x = hr.x; + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw cv qualifier */ + __BF16_HOST_DEVICE__ volatile __hip_bfloat16& operator=( + const volatile __hip_bfloat16_raw& hr) volatile { + __x = hr.x; + return *this; + } }; +/**@}*/ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT + * \ingroup HIP_INTRINSIC_BFLOAT16 + * \brief Struct to represent a two 16 bit brain floating point number. + * @{ + */ +struct __attribute__((aligned(4))) __hip_bfloat162 { + public: + __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ + __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */ + + + public: + /*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r) + : x(__hip_bfloat16(__hip_bfloat16_raw{h2r.x})), + y(__hip_bfloat16(__hip_bfloat16_raw{h2r.y})) {} + + /*! \brief copy constructor of __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162& val) { + __hip_bfloat162_raw hr = val; + x = __hip_bfloat16_raw{hr.x}; + y = __hip_bfloat16_raw{hr.y}; + } + + /*! \brief create __hip_bfloat162 from two __hip_bfloat16 */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b) + : x(a), y(b) {} + + /*! \brief default constructor of __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162() = default; + + /*! \brief return a __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ operator __hip_bfloat162_raw() const { + __hip_bfloat16_raw l = x; + __hip_bfloat16_raw r = y; + return __hip_bfloat162_raw{l.x, r.x}; + } -/*! \brief Struct to represent two 16 bit brain floating point numbers. */ -struct __hip_bfloat162 { - __hip_bfloat16 x; - __hip_bfloat16 y; + /*! \brief return a float2 */ + __BF16_HOST_DEVICE__ operator float2() const { +#if HIP_BF16_AVX512_OP + union { + __hip_bfloat162_raw raw2; + __bf16 bf162[2]; + static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw)); + } u; + u.raw2 = *this; + __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0}; + __m128 pf32 = _mm_cvtpbh_ps(pbf16); + float2 ret(pf32[0], pf32[1]); +#else + float2 ret(x, y); +#endif + return ret; + } + + /*! \brief assign value from __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) { + x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y}); + return *this; + } + + /*! \brief assign value from __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) { + __hip_bfloat162_raw hr = src; + x = __hip_bfloat16(__hip_bfloat16_raw{hr.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{hr.y}); + return *this; + } }; +/**@}*/ /** * \ingroup HIP_INTRINSIC_BFLOAT16_CONV * \brief Converts bfloat16 to float */ -__HOST_DEVICE__ inline float __bfloat162float(__hip_bfloat16 a) { - unsigned int uval = 0; - uval = a.data << 16; - union { - unsigned int u32; - float fp32; - } u = {uval}; - return u.fp32; +__BF16_HOST_DEVICE_STATIC__ float __bfloat162float(__hip_bfloat16 a) { + float ret = a; + return ret; } /** * \ingroup HIP_INTRINSIC_BFLOAT16_CONV * \brief Converts float to bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) { - __hip_bfloat16 ret; - union { - float fp32; - unsigned int u32; - } u = {f}; - if (~u.u32 & 0x7f800000) { - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even - } else if (u.u32 & 0xffff) { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - u.u32 |= 0x10000; // Preserve signaling NaN - } - - ret.data = (u.u32 >> 16); +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) { + __hip_bfloat16 ret{f}; return ret; } @@ -191,43 +508,51 @@ __HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) { * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts and moves bfloat162 to float2 */ -__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) { - return float2{__bfloat162float(a.x), __bfloat162float(a.y)}; +__BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) { + float2 ret = a; + return ret; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Moves bfloat16 value to bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { - return __hip_bfloat162{a, a}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { + return __hip_bfloat162(a, a); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer */ -__HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } +__BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) { + short ret = h; + return ret; +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer */ -__HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } +__BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { + unsigned short ret = h; + return ret; +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Convert double to __hip_bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) { - return __float2bfloat16((float)a); +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __double2bfloat16(const double a) { + __hip_bfloat16 ret{a}; + return ret; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Convert float2 to __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)}; } @@ -235,97 +560,117 @@ __HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Combine two __hip_bfloat16 to __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __hip_bfloat162{a, b}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __hip_bfloat162(a, b); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Returns high 16 bits of __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat16(__hip_bfloat16_raw{hr.y}); +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Returns high 16 bits of __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { - return __hip_bfloat162{a.y, a.y}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(__hip_bfloat16_raw{hr.y}, __hip_bfloat16_raw{hr.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result */ -__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } +__BF16_HOST_DEVICE_STATIC__ float __high2float(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y})); +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Extracts high 16 bits from each and combines them */ -__HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, - const __hip_bfloat162 b) { - return __hip_bfloat162{a.y, b.y}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hip_bfloat162_raw{hr_a.y, hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Returns low 16 bits of __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat16(hr.x); +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Returns low 16 bits of __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { - return __hip_bfloat162{a.x, a.x}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(hr.x, hr.x); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result */ -__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } +__BF16_HOST_DEVICE_STATIC__ float __low2float(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x})); +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Swaps both halves */ -__HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { - return __hip_bfloat162{a.y, a.x}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(__hip_bfloat162_raw{hr.y, hr.x}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Extracts low 16 bits from each and combines them */ -__HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{a.x, b.x}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hip_bfloat162_raw{hr_a.x, hr_b.x}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Reinterprets short int into a bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) { - return __hip_bfloat16{(unsigned short)a}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) { + return __hip_bfloat16(a); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Reinterprets unsigned short int into a bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { - return __hip_bfloat16{a}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { + return __hip_bfloat16(a); } - /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Adds two bfloat16 values */ -__HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); } @@ -333,7 +678,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Subtracts two bfloat16 values */ -__HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); } @@ -341,7 +686,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Divides two bfloat16 values */ -__HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); } @@ -349,8 +694,8 @@ __HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Performs FMA of given bfloat16 values */ -__device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b, - const __hip_bfloat16 c) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b, + const __hip_bfloat16 c) { return __float2bfloat16( __ocml_fma_f32(__bfloat162float(a), __bfloat162float(b), __bfloat162float(c))); } @@ -359,7 +704,7 @@ __device__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b, * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Multiplies two bfloat16 values */ -__HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); } @@ -367,85 +712,110 @@ __HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Negate a bfloat16 value */ -__HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { - auto ret = a; - ret.data ^= 0x8000; - return ret; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + hr.x ^= 0x8000; + return __hip_bfloat16(hr); } /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Returns absolute of a bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { - auto ret = a; - ret.data &= 0x7FFF; - return ret; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + hr.x &= 0x7FFF; + return __hip_bfloat16(hr); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Divides bfloat162 values */ -__HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)), - __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.x}) / + __bfloat162float(__hip_bfloat16_raw{hr_b.x})), + __float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.y}) / + __bfloat162float(__hip_bfloat16_raw{hr_b.y}))); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Returns absolute of a bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { - return __hip_bfloat162{__habs(a.x), __habs(a.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162(__habs(__hip_bfloat16_raw{hr_a.x}), __habs(__hip_bfloat16_raw{hr_a.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Adds two bfloat162 values */ -__HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hadd(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hadd(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Performs FMA of given bfloat162 values */ -__device__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b, - const __hip_bfloat162 c) { - return __hip_bfloat162{__hfma(a.x, b.x, c.x), __hfma(a.y, b.y, c.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b, + const __hip_bfloat162 c) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + __hip_bfloat162_raw hr_c = c; + return __hip_bfloat162( + __hfma(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}, __hip_bfloat16_raw{hr_c.x}), + __hfma(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}, __hip_bfloat16_raw{hr_c.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Multiplies two bfloat162 values */ -__HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hmul(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmul(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Converts a bfloat162 into negative */ -__HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { - return __hip_bfloat162{__hneg(a.x), __hneg(a.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162(__hneg(__hip_bfloat16_raw{hr_a.x}), __hneg(__hip_bfloat16_raw{hr_a.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Subtracts two bfloat162 values */ -__HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hsub(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hsub(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to multiply two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16& l, + const __hip_bfloat16& r) { return __hmul(l, r); } @@ -453,7 +823,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bf * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to multiply-assign two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) { l = __hmul(l, r); return l; } @@ -462,13 +832,14 @@ __HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to unary+ on a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; } /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to add two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16& l, + const __hip_bfloat16& r) { return __hadd(l, r); } @@ -476,13 +847,14 @@ __HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bf * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to negate a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); } /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to subtract two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16& l, + const __hip_bfloat16& r) { return __hsub(l, r); } @@ -490,7 +862,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bf * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to post increment a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { auto ret = l; l = __hadd(l, HIPRT_ONE_BF16); return ret; @@ -500,7 +872,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to pre increment a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator++(__hip_bfloat16& l) { l = __hadd(l, HIPRT_ONE_BF16); return l; } @@ -509,7 +881,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) { * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to post decrement a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { auto ret = l; l = __hsub(l, HIPRT_ONE_BF16); return ret; @@ -519,7 +891,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to pre decrement a __hip_bfloat16 number */ -__HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator--(__hip_bfloat16& l) { l = __hsub(l, HIPRT_ONE_BF16); return l; } @@ -528,7 +900,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) { * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to add-assign two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) { l = __hadd(l, r); return l; } @@ -537,7 +909,7 @@ __HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to subtract-assign two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) { l = __hsub(l, r); return l; } @@ -546,7 +918,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to divide two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16& l, + const __hip_bfloat16& r) { return __hdiv(l, r); } @@ -554,7 +927,7 @@ __HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bf * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Operator to divide-assign two __hip_bfloat16 numbers */ -__HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) { l = __hdiv(l, r); return l; } @@ -563,7 +936,8 @@ __HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to multiply two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator*(const __hip_bfloat162& l, + const __hip_bfloat162& r) { return __hmul2(l, r); } @@ -571,7 +945,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to multiply-assign two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator*=(__hip_bfloat162& l, + const __hip_bfloat162& r) { l = __hmul2(l, r); return l; } @@ -580,13 +955,14 @@ __HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bflo * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to unary+ on a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to add two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator+(const __hip_bfloat162& l, + const __hip_bfloat162& r) { return __hadd2(l, r); } @@ -594,13 +970,16 @@ __HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to negate a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); } +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { + return __hneg2(l); +} /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to subtract two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator-(const __hip_bfloat162& l, + const __hip_bfloat162& r) { return __hsub2(l, r); } @@ -608,7 +987,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to post increment a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { auto ret = l; l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); return ret; @@ -618,7 +997,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to pre increment a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator++(__hip_bfloat162& l) { l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); return l; } @@ -627,7 +1006,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) { * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to post decrement a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { auto ret = l; l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); return ret; @@ -637,7 +1016,7 @@ __HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to pre decrement a __hip_bfloat162 number */ -__HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator--(__hip_bfloat162& l) { l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); return l; } @@ -646,7 +1025,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) { * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to add-assign two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator+=(__hip_bfloat162& l, + const __hip_bfloat162& r) { l = __hadd2(l, r); return l; } @@ -655,7 +1035,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bflo * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to subtract-assign two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator-=(__hip_bfloat162& l, + const __hip_bfloat162& r) { l = __hsub2(l, r); return l; } @@ -664,7 +1045,8 @@ __HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bflo * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to divide two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator/(const __hip_bfloat162& l, + const __hip_bfloat162& r) { return __h2div(l, r); } @@ -672,7 +1054,8 @@ __HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Operator to divide-assign two __hip_bfloat162 numbers */ -__HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator/=(__hip_bfloat162& l, + const __hip_bfloat162& r) { l = __h2div(l, r); return l; } @@ -681,7 +1064,7 @@ __HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bflo * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values */ -__HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) == __bfloat162float(b); } @@ -689,7 +1072,7 @@ __HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered equal */ -__HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) < __bfloat162float(b)) && !(__bfloat162float(a) > __bfloat162float(b)); } @@ -698,7 +1081,7 @@ __HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - greater than */ -__HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) > __bfloat162float(b); } @@ -706,7 +1089,7 @@ __HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered greater than */ -__HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) <= __bfloat162float(b)); } @@ -714,7 +1097,7 @@ __HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - greater than equal */ -__HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) >= __bfloat162float(b); } @@ -722,7 +1105,7 @@ __HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered greater than equal */ -__HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) < __bfloat162float(b)); } @@ -730,7 +1113,7 @@ __HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - not equal */ -__HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) != __bfloat162float(b); } @@ -738,7 +1121,7 @@ __HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered not equal */ -__HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) == __bfloat162float(b)); } @@ -746,7 +1129,7 @@ __HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - return max */ -__HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { #if __HIP_DEVICE_COMPILE__ return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); #else @@ -758,7 +1141,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - return min */ -__HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { #if __HIP_DEVICE_COMPILE__ return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); #else @@ -770,7 +1153,7 @@ __HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - less than operator */ -__HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) < __bfloat162float(b); } @@ -778,7 +1161,7 @@ __HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered less than */ -__HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) >= __bfloat162float(b)); } @@ -786,7 +1169,7 @@ __HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - less than equal */ -__HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { return __bfloat162float(a) <= __bfloat162float(b); } @@ -794,7 +1177,7 @@ __HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Compare two bfloat162 values - unordered less than equal */ -__HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { +__BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { return !(__bfloat162float(a) > __bfloat162float(b)); } @@ -802,208 +1185,282 @@ __HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Checks if number is inf */ -__HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) { - unsigned short sign = a.data & 0x8000U; -#if __HIP_DEVICE_COMPILE__ - int res = __ocml_isinf_f32(__bfloat162float(a)); -#else - int res = std::isinf(__bfloat162float(a)) ? 1 : 0; -#endif - return (res == 0) ? res : ((sign != 0U) ? -res : res); +__BF16_HOST_DEVICE_STATIC__ int __hisinf(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + return !(~hr.x & 0x7f80) && !(hr.x & 0x7f); } /** * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Checks if number is nan */ -__HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) { -#if __HIP_DEVICE_COMPILE__ - return __ocml_isnan_f32(__bfloat162float(a)); -#else - return std::isnan(__bfloat162float(a)); -#endif +__BF16_HOST_DEVICE_STATIC__ bool __hisnan(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + return !(~hr.x & 0x7f80) && +(hr.x & 0x7f); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Checks if two numbers are equal */ -__HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __heq(a.x, b.x) && __heq(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Checks if two numbers are equal - unordered */ -__HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hequ(a.x, b.x) && __hequ(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hequ(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hequ(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a >= b */ -__HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hge(a.x, b.x) && __hge(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a >= b - unordered */ -__HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hgeu(a.x, b.x) && __hgeu(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgeu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgeu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a > b */ -__HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hgt(a.x, b.x) && __hgt(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a > b - unordered */ -__HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hgtu(a.x, b.x) && __hgtu(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgtu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgtu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a <= b */ -__HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hle(a.x, b.x) && __hle(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a <= b - unordered */ -__HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hleu(a.x, b.x) && __hleu(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hleu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hleu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a < b */ -__HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hlt(a.x, b.x) && __hlt(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a < b - unordered */ -__HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hltu(a.x, b.x) && __hltu(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hltu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hltu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a != b */ -__HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hne(a.x, b.x) && __hne(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.x}), + __hip_bfloat16(__hip_bfloat16_raw{hr_b.x})) && + __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.y}), __hip_bfloat16(__hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a != b */ -__HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hneu(a.x, b.x) && __hneu(a.y, b.y); +__BF16_HOST_DEVICE_STATIC__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hneu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) || + __hneu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ONE_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { - return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162{{__hisnan(__hip_bfloat16_raw{hr_a.x}) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, + {__hisnan(__hip_bfloat16_raw{hr_a.y}) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0 */ -__HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Returns max of two elements */ -__HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hmax(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmax(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Returns min of two elements */ -__HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hmin(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmin(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Checks for not equal to */ -__HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, - {__hne(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hne(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hne(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; } /** * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __heq(l, r); } @@ -1011,7 +1468,7 @@ __HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a not equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __hne(l, r); } @@ -1019,7 +1476,7 @@ __HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a less than on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __hlt(l, r); } @@ -1027,7 +1484,7 @@ __HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __hle(l, r); } @@ -1035,7 +1492,7 @@ __HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a greater than on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __hgt(l, r); } @@ -1043,7 +1500,7 @@ __HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) { +__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) { return __hge(l, r); } @@ -1051,55 +1508,60 @@ __HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __heq(l.x, r.x) && __heq(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) { + float2 ret = __heq2(l, r); + return ret.x != 0.0f && ret.y != 0.0f; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Operator to perform a not equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __hne(l.x, r.x) || __hne(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) { + return !(l == r); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Operator to perform a less than on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __hlt(l.x, r.x) && __hlt(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) { + float2 fl = l, fr = r; + return fl.x < fr.x && fl.x < fr.y; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __hle(l.x, r.x) && __hle(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) { + float2 fl = l, fr = r; + return fl.x <= fr.x && fl.x <= fr.y; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_COMP * \brief Operator to perform a greater than on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __hgt(l.x, r.x) && __hgt(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) { + float2 fl = l, fr = r; + return fl.x > fr.x && fl.x > fr.y; } /** * \ingroup HIP_INTRINSIC_BFLOAT16_COMP * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers */ -__HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) { - return __hge(l.x, r.x) && __hge(l.y, r.y); +__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) { + float2 fl = l, fr = r; + return fl.x >= fr.x && fl.x >= fr.y; } /** * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate ceil of bfloat16 */ -__device__ __hip_bfloat16 hceil(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hceil(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_ceil_f32(__bfloat162float(h))); } @@ -1107,7 +1569,7 @@ __device__ __hip_bfloat16 hceil(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate cosine of bfloat16 */ -__device__ __hip_bfloat16 hcos(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hcos(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_cos_f32(__bfloat162float(h))); } @@ -1115,7 +1577,7 @@ __device__ __hip_bfloat16 hcos(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate exponential of bfloat16 */ -__device__ __hip_bfloat16 hexp(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_exp_f32(__bfloat162float(h))); } @@ -1123,7 +1585,7 @@ __device__ __hip_bfloat16 hexp(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate exponential 10 of bfloat16 */ -__device__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_exp10_f32(__bfloat162float(h))); } @@ -1131,7 +1593,7 @@ __device__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate exponential 2 of bfloat16 */ -__device__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_exp2_f32(__bfloat162float(h))); } @@ -1139,7 +1601,7 @@ __device__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate floor of bfloat16 */ -__device__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_floor_f32(__bfloat162float(h))); } @@ -1147,7 +1609,7 @@ __device__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate natural log of bfloat16 */ -__device__ __hip_bfloat16 hlog(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_log_f32(__bfloat162float(h))); } @@ -1155,7 +1617,7 @@ __device__ __hip_bfloat16 hlog(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate log 10 of bfloat16 */ -__device__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_log10_f32(__bfloat162float(h))); } @@ -1163,7 +1625,7 @@ __device__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate log 2 of bfloat16 */ -__device__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_log2_f32(__bfloat162float(h))); } @@ -1171,7 +1633,7 @@ __device__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate reciprocal */ -__device__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) { return __float2bfloat16(1.0f / (__bfloat162float(h))); } @@ -1179,7 +1641,7 @@ __device__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Round to nearest int */ -__device__ __hip_bfloat16 hrint(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrint(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_rint_f32(__bfloat162float(h))); } @@ -1187,7 +1649,7 @@ __device__ __hip_bfloat16 hrint(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Reciprocal square root */ -__device__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_rsqrt_f32(__bfloat162float(h))); } @@ -1195,7 +1657,7 @@ __device__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate sin of bfloat16 */ -__device__ __hip_bfloat16 hsin(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_sin_f32(__bfloat162float(h))); } @@ -1203,7 +1665,7 @@ __device__ __hip_bfloat16 hsin(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate sqrt of bfloat16 */ -__device__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h))); } @@ -1211,7 +1673,7 @@ __device__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate truncate of bfloat16 */ -__device__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) { +__BF16_DEVICE_STATIC__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) { return __float2bfloat16(__ocml_trunc_f32(__bfloat162float(h))); } @@ -1219,119 +1681,134 @@ __device__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) { * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate ceil of bfloat162 */ -__device__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) { - return __hip_bfloat162{hceil(h.x), hceil(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hceil(__hip_bfloat16_raw{hr.x}), hceil(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate cosine of bfloat162 */ -__device__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) { - return __hip_bfloat162{hcos(h.x), hcos(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hcos(__hip_bfloat16_raw{hr.x}), hcos(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate exponential of bfloat162 */ -__device__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) { - return __hip_bfloat162{hexp(h.x), hexp(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp(__hip_bfloat16_raw{hr.x}), hexp(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate exponential 10 of bfloat162 */ -__device__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) { - return __hip_bfloat162{hexp10(h.x), hexp10(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp10(__hip_bfloat16_raw{hr.x}), hexp10(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate exponential 2 of bfloat162 */ -__device__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) { - return __hip_bfloat162{hexp2(h.x), hexp2(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp2(__hip_bfloat16_raw{hr.x}), hexp2(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate floor of bfloat162 */ -__device__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) { - return __hip_bfloat162{hfloor(h.x), hfloor(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hfloor(__hip_bfloat16_raw{hr.x}), hfloor(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate natural log of bfloat162 */ -__device__ __hip_bfloat162 h2log(const __hip_bfloat162 h) { - return __hip_bfloat162{hlog(h.x), hlog(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog(__hip_bfloat16_raw{hr.x}), hlog(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate log 10 of bfloat162 */ -__device__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) { - return __hip_bfloat162{hlog10(h.x), hlog10(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog10(__hip_bfloat16_raw{hr.x}), hlog10(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate log 2 of bfloat162 */ -__device__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) { - return __hip_bfloat162{hlog2(h.x), hlog2(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog2(__hip_bfloat16_raw{hr.x}), hlog2(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate vector reciprocal */ -__device__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) { - return __hip_bfloat162{hrcp(h.x), hrcp(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrcp(__hip_bfloat16_raw{hr.x}), hrcp(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate vector round to nearest int */ -__device__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) { - return __hip_bfloat162{hrint(h.x), hrint(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrint(__hip_bfloat16_raw{hr.x}), hrint(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate vector reciprocal square root */ -__device__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) { - return __hip_bfloat162{hrsqrt(h.x), hrsqrt(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrsqrt(__hip_bfloat16_raw{hr.x}), hrsqrt(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate sin of bfloat162 */ -__device__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) { - return __hip_bfloat162{hsin(h.x), hsin(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hsin(__hip_bfloat16_raw{hr.x}), hsin(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate sqrt of bfloat162 */ -__device__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) { - return __hip_bfloat162{hsqrt(h.x), hsqrt(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hsqrt(__hip_bfloat16_raw{hr.x}), hsqrt(__hip_bfloat16_raw{hr.y})); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Calculate truncate of bfloat162 */ -__device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { - return __hip_bfloat162{htrunc(h.x), htrunc(h.y)}; +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(htrunc(__hip_bfloat16_raw{hr.x}), htrunc(__hip_bfloat16_raw{hr.y})); } #endif diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h b/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h index 8b1a0c067db1..c01039a7e1cc 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h @@ -216,12 +216,18 @@ class thread_block : public thread_group { if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) { __hip_assert(false && "invalid tile size"); } + + auto block_size = size(); + auto rank = thread_rank(); + auto partitions = (block_size + tile_size - 1) / tile_size; + auto tail = (partitions * tile_size) - block_size; + auto partition_size = tile_size - tail * (rank >= (partitions - 1) * tile_size); + thread_group tiledGroup = thread_group(internal::cg_tiled_group, partition_size); - thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size); tiledGroup.coalesced_info.tiled_info.size = tile_size; tiledGroup.coalesced_info.tiled_info.is_tiled = true; - tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size; - tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size; + tiledGroup.coalesced_info.tiled_info.meta_group_rank = rank / tile_size; + tiledGroup.coalesced_info.tiled_info.meta_group_size = partitions; return tiledGroup; } diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h b/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h new file mode 100644 index 000000000000..e54c70241701 --- /dev/null +++ b/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h @@ -0,0 +1,1391 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * \file + * \brief amd_hip_fp8.h header, for AMD fp8 data types + */ + +#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ +#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ + +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__ +#define HIP_FP8_CVT_FAST_PATH 1 +#else +#define HIP_FP8_CVT_FAST_PATH 0 +#endif + +#if !defined(__HIPCC_RTC__) +#include +#include + +#include "host_defines.h" // __hip_internal:: +#include "amd_hip_vector_types.h" // float2 etc +#include "amd_hip_fp16.h" // __half_raw +#include "amd_hip_bf16.h" // bf16 +#include "math_fwd.h" // ocml device functions +#endif // !defined(__HIPCC_RTC__) + +#if defined(__HIPCC_RTC__) +#define __FP8_HOST_DEVICE__ __device__ +#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static +#else +#define __FP8_HOST_DEVICE__ __host__ __device__ +#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline +#endif // __HIPCC_RTC__ + +#if !defined(__HIPCC_RTC__) +static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); +#endif +static_assert(sizeof(unsigned char) == 1); +static_assert(sizeof(unsigned short int) == 2); +static_assert(sizeof(unsigned int) == 4); + +/** + * \brief Describes FP8 interpretation + */ +enum __hip_fp8_interpretation_t { + __HIP_E4M3_FNUZ = 0, /**< Standard FP8 */ + __HIP_E5M2_FNUZ = 1, /**< BF8 */ +}; + +/** + * \brief Describes saturation behavior + */ +enum __hip_saturation_t { + __HIP_NOSAT = 0, /**< No saturation */ + __HIP_SATFINITE = 1, /**< Saturate to finite */ +}; + +/** \typedef __hip_fp8_storage_t + * + * \brief type to store single fp8 number + */ +typedef unsigned char __hip_fp8_storage_t; + + +/** \typedef __hip_fp8x2_storage_t + * + * \brief type to store two fp8 numbers + */ +typedef unsigned short int __hip_fp8x2_storage_t; + + +/** \typedef __hip_fp8x4_storage_t + * + * \brief type to store four fp8 numbers + */ +typedef unsigned int __hip_fp8x4_storage_t; + +namespace internal { +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 +// This has been modified to add double types conversion as well +template +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false, + bool stoch = false, + unsigned int rng = 0) { + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8"); + + const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); + unsigned long long x; + + if (sizeof(T) == 8) + x = reinterpret_cast(_x); + else if (sizeof(T) == 4) + x = reinterpret_cast(_x); + else + x = reinterpret_cast(_x); + + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + + if (sizeof(T) == 8) { + head = x & 0xFFF0000000000000ull; + mantissa = x & 0xFFFFFFFFFFFFFull; + exponent = (head >> 52) & 0x7FF; + sign = head >> 63; + bias = 1023; + } else if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 8) { + if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) return 0x80; + } else if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) return 0x80; + } else { + if ((x & 0x7C00) == 0x7C00) return 0x80; + } + } else { + if (sizeof(T) == 8) { + if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) + return signed_inf + (mantissa != 0 ? 1 : 0); + } else if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0); + } else { + if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In +this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. +For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 +actual exponent is -7, it is actually larger due to the implict 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) == + (1ull << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift +right as shift right could rip off some residual part and make something not midpoint look like +midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint, but +after shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) + mantissa >>= exponent_diff; + else if (exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1; + bool odd = + mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1ull << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1ull << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 +// This has been modified to handle double types as well +template +__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we) { + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, "only half, float and double are supported"); + + constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52); + + T fInf, fNegInf, fNaN, fNeg0; + if (is_half) { + const unsigned short int ihInf = 0x7C00; + const unsigned short int ihNegInf = 0xFC00; + const unsigned short int ihNaN = 0x7C01; + const unsigned short int ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else if (is_float) { + const unsigned int ifInf = 0x7F800000; + const unsigned int ifNegInf = 0xFF800000; + const unsigned int ifNaN = 0x7F800001; + const unsigned int ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } else if (is_double) { + const unsigned long long ifInf = 0x7FF0000000000000ull; + const unsigned long long ifNegInf = 0xFFF0000000000000ull; + const unsigned long long ifNaN = 0x7FF0000000000001ull; + const unsigned long long ifNeg0 = 0x8000000000000000ull; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) return fNaN; + } else { + if (x == 0x80) return fNeg0; + if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + + typename __hip_internal::conditional< + sizeof(T) == 2, unsigned short int, + typename __hip_internal::conditional::type>::type retval; + + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { +#if __HIP_DEVICE_COMPILE__ + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __clz(mantissa) - (32 - wm); +#else + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); +#endif + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else if (sizeof(T) == 4) + retval = (sign << 31) | (exponent << 23) | mantissa; + else + retval = (sign << 63) | (static_cast(exponent) << 52) | mantissa; + return reinterpret_cast(retval); +} + +#if HIP_FP8_CVT_FAST_PATH +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 +template +static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate, + __hip_fp8_interpretation_t interpret, + unsigned int rng = 0) { + __hip_fp8_storage_t i8data; + union { + float fval; + unsigned int i32val; + unsigned char i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if (saturate) { + if (interpret == __HIP_E4M3_FNUZ) { + if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } else { + if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if (stochastic_rounding) { + ival = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } else { // RNE CVT + ival = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} + +static __device__ __hip_fp8x2_storage_t +cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) { + union { + static_assert(sizeof(float2) == sizeof(unsigned int[2])); + static_assert(sizeof(float2) == sizeof(unsigned short[4])); + float2 fval; + unsigned int i32val[2]; + unsigned short i16val[4]; + } f2val; + + f2val.fval = v; + + if (saturate) { /// propagate NAN/INF, no clipping + if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) { + f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0); + } + if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) { + f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0); + } + } + + f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false); + + return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]); +} + +static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v, + __hip_fp8_interpretation_t interpret) { + union { + unsigned int i32val; + unsigned char i8val[4]; + } val; + val.i8val[0] = v; + + float fval = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0) + : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); + return fval; +} + +static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v, + __hip_fp8_interpretation_t interpret) { + union { + unsigned int i32val; + unsigned short i16val[2]; + } val; + val.i16val[0] = v; + + auto f2 = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false) + : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false); + return float2{f2[0], f2[1]}; +} +#endif // HIP_FP8_CVT_FAST_PATH + +/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned. +Inf are not supported. This gives us one additional number to represent. +NaN are represented by 1-0000-000 or 1-00000-00 */ +__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) { + return static_cast(a) == 0x80; +} +} // namespace internal + +/** + * \brief convert float to @p __hip_fp8_storage_t + * + * \param f float number + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( + const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f8_from_f32(f, sat == __HIP_SATFINITE, type); +#else // HIP_FP8_CVT_FAST_PATH + int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); +#endif // HIP_FP8_CVT_FAST_PATH +} + +/** + * \brief convert float2 to @p __hip_fp8x2_storage_t + * + * \param f2 float2 number + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( + const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type); +#else + return static_cast<__hip_fp8x2_storage_t>( + static_cast(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 | + static_cast(__hip_cvt_float_to_fp8(f2.x, sat, type))); +#endif +} + +/** + * \brief convert double to @p __hip_fp8_storage_t + * + * \param d double val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( + const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { + int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); +} + +/** + * \brief convert double2 to @p __hip_fp8x2_storage_t + * + * \param d2 double2 val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2( + const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { + return static_cast<__hip_fp8x2_storage_t>( + static_cast(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 | + static_cast(__hip_cvt_double_to_fp8(d2.x, sat, type))); +} + +/** + * \brief convert __hip_bfloat16_raw to @p __hip_fp8_storage_t + * + * \param hr __hip_bfloat16_raw val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + float fval = __hip_bfloat16(hr); + return __hip_cvt_float_to_fp8(fval, sat, type); +} + +/** + * \brief convert double2 to @p __hip_fp8x2_storage_t + * + * \param hr __hip_bfloat162_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + float2 f2 = __hip_bfloat162(hr); + return __hip_cvt_float2_to_fp8x2(f2, sat, type); +} + +/** + * \brief convert @p __hip_fp8_storage_t to __half_raw + * + * \param x __hip_fp8_storage_t val + * \param type interpretation of fp8 + * \return __half_raw + */ +__FP8_HOST_DEVICE_STATIC__ __half_raw +__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) { + unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)}; +} + +/** + * \brief convert @p __hip_fp8x2_storage_t to __half2_raw + * + * \param x __hip_fp8x2_storage_t val + * \param type interpretation of fp8 + * \return __half2_raw + */ +__FP8_HOST_DEVICE_STATIC__ __half2_raw +__hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) { + __half2 ret(static_cast<__half>( + __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)), + static_cast<__half>( + __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type))); + return static_cast<__half2_raw>(ret); +} + +/** + * \brief convert __half_raw to @p __hip_fp8_storage_t + * + * \param x __half_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8( + const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { + return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type); +} + +/** + * \brief convert __half2_raw to @p __hip_fp8x2_storage_t + * + * \param x __half2_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( + const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { + return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type); +} + +/** + * \brief struct representing single fp8 number with e4m3 interpretation + * + */ +struct __hip_fp8_e4m3_fnuz { + __hip_fp8_storage_t __x; //! raw storage of fp8 number + constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; + constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ; + constexpr static unsigned int __we = 4; + constexpr static unsigned int __wm = 3; + + // TODO: SWDEV-452411 + // Add cast from unsigned long long, long long to fp8 + + /*! create fp8 e4m3 from long */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from short int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from unsigned long */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from unsigned int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from unsigned short */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from double */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f) + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + + /*! create fp8 e4m3 from float */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f) + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + + /*! create fp8 e4m3 from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f) + : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from __half */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f) + : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, + __default_interpret)) {} + + /*! default construct fp8 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default; + + /*! convert fp8 e4m3 to __half */ + __FP8_HOST_DEVICE__ operator __half() const { + return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); + } + + /*! convert fp8 e4m3 to __hip_bfloat16 */ + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + float f = *this; + return __hip_bfloat16(f); + } + + /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ + __FP8_HOST_DEVICE__ operator bool() const { + // it can be 0x00 (+0.0) since 0x80 will be nan + return !(static_cast(__x) == 0); + } + + /*! convert fp8 e4m3 to char, clamp number to CHAR_MIN/CHAR_MAX if its out of range */ + __FP8_HOST_DEVICE__ operator char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + auto fval = internal::cast_from_f8(__x, __wm, __we); + auto llval = static_cast(fval); + if (llval <= CHAR_MIN) { + return CHAR_MIN; + } else if (llval >= CHAR_MAX) { + return CHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to double */ + __FP8_HOST_DEVICE__ operator double() const { + return internal::cast_from_f8(__x, __wm, __we); + } + + /*! convert fp8 e4m3 to float */ + __FP8_HOST_DEVICE__ operator float() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32_from_f8(__x, __default_interpret); +#else + return internal::cast_from_f8(__x, __wm, __we); +#endif + } + + /*! convert fp8 e4m3 to int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to short int, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SHRT_MIN) { + return SHRT_MIN; + } else if (llval >= SHRT_MAX) { + return SHRT_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to signed char, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator signed char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SCHAR_MIN) { + return SCHAR_MIN; + } else if (llval >= SCHAR_MAX) { + return SCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned char, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } else if (llval >= UCHAR_MAX) { + return UCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long long int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned short, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } +}; + +/** + * \brief struct representing two fp8 numbers with e4m3 interpretation + * + */ +struct __hip_fp8x2_e4m3_fnuz { + __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ; + static constexpr unsigned int __we = 4; + static constexpr unsigned int __wm = 3; + + /*! create fp8x2 e4m3 type from double2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val) + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e4m3 type from float2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val) + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e4m3 type from __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val) + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e4m3 type from __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val) + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! Default construct of fp8x2 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default; + + /*! convert fp8x2 e4m3 to __half2 */ + __FP8_HOST_DEVICE__ operator __half2() const { + return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); + } + + /*! convert fp8x2 e4m3 to float2 */ + __FP8_HOST_DEVICE__ operator float2() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); +#else + return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), + __wm, __we), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), + __wm, __we)); +#endif + } +}; + +/** + * \brief struct representing four fp8 numbers with e4m3 interpretation + * + */ +struct __hip_fp8x4_e4m3_fnuz { + __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ; + static constexpr unsigned int __we = 4; + static constexpr unsigned int __wm = 3; + + /*! create fp8x4 e4m3 type from double4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val) + : __x{reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_double_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))} {} + + /*! create fp8x4 e4m3 type from float4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val) + : __x{reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_float_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))} {} + + /*! create fp8x4 e4m3 type from two __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast( + __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | + reinterpret_cast( + __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! create fp8x4 e4m3 type from two __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! Default construct fp8x4 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default; + + /*! convert fp8x4 e4m3 to float4 */ + __FP8_HOST_DEVICE__ operator float4() const { + auto x = __x; // bypass const + auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E + auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); +#if HIP_FP8_CVT_FAST_PATH + float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret); + float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret); +#else + float2 high = float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we)); + float2 low = float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we)); +#endif + return float4(low.x, low.y, high.x, high.y); + } +}; + +/** + * \brief struct representing one fp8 number with e5m2 interpretation + * + */ +struct __hip_fp8_e5m2_fnuz { + __hip_fp8_storage_t __x; //! raw storage of one fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + + // TODO: SWDEV-452411 + // Add cast from unsigned long long, long long to fp8 + + /*! create fp8 e5m2 type from long */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from short int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from unsigned long */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from unsigned int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from unsigned short */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from double */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f) + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + + /*! create fp8 e5m2 type from float */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f) + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + + /*! create fp8 e5m2 type from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f) + : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f) + : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, + __default_interpret)) {} + + /*! default construct fp8 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default; + + /*! convert fp8 e5m2 to float */ + __FP8_HOST_DEVICE__ operator float() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32_from_f8(__x, __default_interpret); +#else + return internal::cast_from_f8(__x, __wm, __we); +#endif + } + + /*! convert fp8 e5m2 to __half */ + __FP8_HOST_DEVICE__ operator __half() const { + return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); + } + + /*! convert fp8 e5m2 to __hip_bfloat16 */ + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + float f = *this; + return __hip_bfloat16(f); + } + + /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ + __FP8_HOST_DEVICE__ operator bool() const { + // it can be 0x00 (+0.0) since 0x80 will be nan + return !(static_cast(__x) == 0); + } + + /*! convert fp8 e5m2 to char, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= CHAR_MIN) { + return CHAR_MIN; + } else if (llval >= CHAR_MAX) { + return CHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to double */ + __FP8_HOST_DEVICE__ operator double() const { + return internal::cast_from_f8(__x, __wm, __we); + } + + /*! convert fp8 e5m2 to int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to short, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SHRT_MIN) { + return SHRT_MIN; + } else if (llval >= SHRT_MAX) { + return SHRT_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to signed char, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator signed char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SCHAR_MIN) { + return SCHAR_MIN; + } else if (llval >= SCHAR_MAX) { + return SCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned char, clamp out of bound values, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } else if (llval >= UCHAR_MAX) { + return UCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned short, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } +}; + +/** + * \brief struct representing two fp8 numbers with e5m2 interpretation + * + */ +struct __hip_fp8x2_e5m2_fnuz { + __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + /*! create fp8x2 e5m2 type from double2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val) + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e5m2 type from float2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val) + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e5m2 type from __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val) + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! create fp8x2 e5m2 type from __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val) + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + + /*! default construct fp8x2 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default; + + /*! convert fp8x2 e5m2 to __half2 */ + __FP8_HOST_DEVICE__ operator __half2() const { + return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); + } + + /*! convert fp8x2 e5m2 to float2 */ + __FP8_HOST_DEVICE__ operator float2() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); +#else + return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), + __wm, __we), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), + __wm, __we)); +#endif + } +}; + +/** + * \brief struct representing four fp8 numbers with e5m2 interpretation + * + */ +struct __hip_fp8x4_e5m2_fnuz { + __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + /*! create fp8x4 e5m2 type from double4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_double_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))) {} + + /*! create fp8x4 e5m2 type from float4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_float_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))) {} + + /*! create fp8x4 e5m2 type from two __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast( + __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | + reinterpret_cast( + __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! create fp8x4 e5m2 type from two __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>( + static_cast(reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /* default construct fp8x4 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default; + + /*! convert fp8x4 e5m2 to float4 */ + __FP8_HOST_DEVICE__ operator float4() const { + auto x = __x; // bypass const + auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E + auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); +#if HIP_FP8_CVT_FAST_PATH + float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret); + float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret); +#else + float2 high = float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we)); + float2 low = float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we)); +#endif + return float4(low.x, low.y, high.x, high.y); + } +}; + +#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h b/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h index e5b6dc3a359c..740e37a6db47 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h @@ -50,7 +50,7 @@ typedef enum hipGLDeviceList { typedef unsigned int GLuint; /** GLenum as uint.*/ typedef unsigned int GLenum; -/* +/** * @} */ @@ -99,10 +99,10 @@ hipError_t hipGraphicsGLRegisterBuffer(hipGraphicsResource** resource, GLuint bu */ hipError_t hipGraphicsGLRegisterImage(hipGraphicsResource** resource, GLuint image, GLenum target, unsigned int flags); -/* +/** * @} */ #if defined(__cplusplus) } #endif /* __cplusplus */ -#endif /* HIP_INCLUDE_AMD_HIP_GL_INTEROP_H */ +#endif /* HIP_INCLUDE_AMD_HIP_GL_INTEROP_H */ \ No newline at end of file diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h b/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h index 559ab20b3399..98f8896cd91d 100644 --- a/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h +++ b/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h @@ -75,6 +75,50 @@ __device__ static inline int __hip_move_dpp_N(int src) { static constexpr int warpSize = __AMDGCN_WAVEFRONT_SIZE; +// warp vote function __all __any __ballot +__device__ +inline +int __all(int predicate) { + return __ockl_wfall_i32(predicate); +} + +__device__ +inline +int __any(int predicate) { + return __ockl_wfany_i32(predicate); +} + +// XXX from llvm/include/llvm/IR/InstrTypes.h +#define ICMP_NE 33 + +__device__ +inline +unsigned long long int __ballot(int predicate) { + return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); +} + +__device__ +inline +unsigned long long int __ballot64(int predicate) { + return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); +} + +// See amd_warp_sync_functions.h for an explanation of this preprocessor flag. +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS +// Since threads in a wave do not make independent progress, __activemask() +// always returns the exact active mask, i.e, all active threads in the wave. +__device__ +inline +unsigned long long __activemask() { + return __ballot(true); +} +#endif // HIP_ENABLE_WARP_SYNC_BUILTINS + +__device__ static inline unsigned int __lane_id() { + return __builtin_amdgcn_mbcnt_hi( + -1, __builtin_amdgcn_mbcnt_lo(-1, 0)); +} + __device__ inline int __shfl(int var, int src_lane, int width = warpSize) { diff --git a/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h b/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h new file mode 100644 index 000000000000..8ef0b2e1d73e --- /dev/null +++ b/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h @@ -0,0 +1,288 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +// Warp sync builtins (with explicit mask argument) introduced in ROCm 6.2 as a +// preview to allow end-users to adapt to the new interface involving 64-bit +// masks. These are disabled by default, and can be enabled by setting the macro +// below. The builtins will be enabled unconditionally in ROCm 6.3. +// +// This arrangement also applies to the __activemask() builtin defined in +// amd_warp_functions.h. +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS + +#if !defined(__HIPCC_RTC__) +#include "amd_warp_functions.h" +#include "hip_assert.h" +#endif + +template +__device__ inline +T __hip_readfirstlane(T val) { + // In theory, behaviour is undefined when reading from a union member other + // than the member that was last assigned to, but it works in practice because + // we rely on the compiler to do the reasonable thing. + union { + unsigned long long l; + T d; + } u; + u.d = val; + // NOTE: The builtin returns int, so we first cast it to unsigned int and only + // then extend it to 64 bits. + unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l); + unsigned long long upper = + (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32); + u.l = (upper << 32) | lower; + return u.d; +} + +// When compiling for wave32 mode, ignore the upper half of the 64-bit mask. +#define __hip_adjust_mask_for_wave32(MASK) \ + do { \ + if (warpSize == 32) MASK &= 0xFFFFFFFF; \ + } while (0) + +// We use a macro to expand each builtin into a waterfall that implements the +// mask semantics: +// +// 1. The mask argument may be divergent. +// 2. Each active thread must have its own bit set in its own mask value. +// 3. For a given mask value, all threads that are mentioned in the mask must +// execute the same static instance of the builtin with the same mask. +// 4. The union of all mask values supplied at a static instance must be equal +// to the activemask at the program point. +// +// Thus, the mask argument partitions the set of currently active threads in the +// wave into disjoint subsets that cover all active threads. +// +// Implementation notes: +// --------------------- +// +// We implement this as a waterfall loop that executes the builtin for each +// subset separately. The return value is a divergent value across the active +// threads. The value for inactive threads is defined by each builtin +// separately. +// +// As long as every mask value is non-zero, we don't need to check if a lane +// specifies itself in the mask; that is done by the later assertion where all +// chosen lanes must be in the chosen mask. + +#define __hip_check_mask(MASK) \ + do { \ + __hip_assert(MASK && "mask must be non-zero"); \ + bool done = false; \ + while (__any(!done)) { \ + if (!done) { \ + auto chosen_mask = __hip_readfirstlane(MASK); \ + if (MASK == chosen_mask) { \ + __hip_assert(MASK == __ballot(true) && \ + "all threads specified in the mask" \ + " must execute the same operation with the same mask"); \ + done = true; \ + } \ + } \ + } \ + } while(0) + +#define __hip_do_sync(RETVAL, FUNC, MASK, ...) \ + do { \ + __hip_assert(MASK && "mask must be non-zero"); \ + bool done = false; \ + while (__any(!done)) { \ + if (!done) { \ + auto chosen_mask = __hip_readfirstlane(MASK); \ + if (MASK == chosen_mask) { \ + __hip_assert(MASK == __ballot(true) && \ + "all threads specified in the mask" \ + " must execute the same operation with the same mask"); \ + RETVAL = FUNC(__VA_ARGS__); \ + done = true; \ + } \ + } \ + } \ + } while(0) + +// __all_sync, __any_sync, __ballot_sync + +template +__device__ inline +unsigned long long __ballot_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __ballot(predicate) & mask; +} + +template +__device__ inline +int __all_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + return __ballot_sync(mask, predicate) == mask; +} + +template +__device__ inline +int __any_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + return __ballot_sync(mask, predicate) != 0; +} + +// __match_any, __match_all and sync variants + +template +__device__ inline +unsigned long long __match_any(T value) { + static_assert( + (__hip_internal::is_integral::value || __hip_internal::is_floating_point::value) && + (sizeof(T) == 4 || sizeof(T) == 8), + "T can be int, unsigned int, long, unsigned long, long long, unsigned " + "long long, float or double."); + bool done = false; + unsigned long long retval = 0; + + while (__any(!done)) { + if (!done) { + T chosen = __hip_readfirstlane(value); + if (chosen == value) { + retval = __activemask(); + done = true; + } + } + } + + return retval; +} + +template +__device__ inline +unsigned long long __match_any_sync(MaskT mask, T value) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __match_any(value) & mask; +} + +template +__device__ inline +unsigned long long __match_all(T value, int* pred) { + static_assert( + (__hip_internal::is_integral::value || __hip_internal::is_floating_point::value) && + (sizeof(T) == 4 || sizeof(T) == 8), + "T can be int, unsigned int, long, unsigned long, long long, unsigned " + "long long, float or double."); + T first = __hip_readfirstlane(value); + if (__all(first == value)) { + *pred = true; + return __activemask(); + } else { + *pred = false; + return 0; + } +} + +template +__device__ inline +unsigned long long __match_all_sync(MaskT mask, T value, int* pred) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + MaskT retval = 0; + __hip_adjust_mask_for_wave32(mask); + __hip_do_sync(retval, __match_all, mask, value, pred); + return retval; +} + +// various variants of shfl + +template +__device__ inline +T __shfl_sync(MaskT mask, T var, int srcLane, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl(var, srcLane, width); +} + +template +__device__ inline +T __shfl_up_sync(MaskT mask, T var, unsigned int delta, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_up(var, delta, width); +} + +template +__device__ inline +T __shfl_down_sync(MaskT mask, T var, unsigned int delta, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_down(var, delta, width); +} + +template +__device__ inline +T __shfl_xor_sync(MaskT mask, T var, int laneMask, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_xor(var, laneMask, width); +} + +#undef __hip_do_sync +#undef __hip_check_mask +#undef __hip_adjust_mask_for_wave32 + +#endif // HIP_ENABLE_WARP_SYNC_BUILTINS diff --git a/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp b/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp index 2152d519ebe1..768c62e09857 100644 --- a/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp +++ b/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp @@ -1,5 +1,5 @@ /* - Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + Copyright (c) 2023 - 2024 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,46 @@ #include +// Define some version macros for the API table. Use similar naming conventions to HSA-runtime +// (MAJOR and STEP versions). Three groups at this time: +// +// (A) HIP_API_TABLE_* defines for versioning for API table structure +// (B) HIP_RUNTIME_API_TABLE_* defines for versioning the HipDispatchTable struct +// (C) HIP_COMPILER_API_TABLE_* defines for versioning the HipCompilerDispatchTable struct +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IMPORTANT !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// 1. When new functions are added to the API table, always add the new function pointer to the +// end of the table and increment the dispatch table's step version number. NEVER re-arrange +// the order of the member variables in a dispatch table. This will break the ABI. +// 2. In dire circumstances, if the type of an existing member variable in a dispatch +// table has be changed because a data type has been changed/removed, increment the dispatch +// table's major version number. If the function pointer type can no longer be declared, DO +// NOT REMOVE IT! Make the function pointer type void* and have it always be set to a nullptr. +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// The major version number should (ideally) never need to be incremented. +// - Increment the HIP_API_TABLE_MAJOR_VERSION for fundamental changes to the API table structs. +// - Increment the HIP_RUNTIME_API_TABLE_MAJOR_VERSION for fundamental changes to the +// HipDispatchTable struct, such as a *change* to type/name an existing member variable. DO NOT +// REMOVE IT. +// - Increment the HIP_COMPILER_API_TABLE_MAJOR_VERSION for fundamental changes to the +// HipCompilerDispatchTable struct, such as a *change* to type/name an existing member variable. +// DO NOT REMOVE IT. +#define HIP_API_TABLE_MAJOR_VERSION 0 +#define HIP_COMPILER_API_TABLE_MAJOR_VERSION 0 +#define HIP_RUNTIME_API_TABLE_MAJOR_VERSION 0 + +// The step version number should be changed whenever the size of the API table struct(s) change. +// - Increment the HIP_API_TABLE_STEP_VERSION when/if new API table structs are added +// - Increment the HIP_RUNTIME_API_TABLE_STEP_VERSION when new runtime API functions are added +// - Increment the HIP_COMPILER_API_TABLE_STEP_VERSION when new compiler API functions are added +// - Reset any of the *_STEP_VERSION defines to zero if the corresponding *_MAJOR_VERSION increases +#define HIP_API_TABLE_STEP_VERSION 0 +#define HIP_COMPILER_API_TABLE_STEP_VERSION 0 +#define HIP_RUNTIME_API_TABLE_STEP_VERSION 3 + // HIP API interface typedef hipError_t (*t___hipPopCallConfiguration)(dim3* gridDim, dim3* blockDim, size_t* sharedMem, hipStream_t* stream); @@ -255,6 +295,7 @@ typedef hipError_t (*t_hipGraphAddMemsetNode)(hipGraphNode_t* pGraphNode, hipGra const hipGraphNode_t* pDependencies, size_t numDependencies, const hipMemsetParams* pMemsetParams); + typedef hipError_t (*t_hipGraphChildGraphNodeGetGraph)(hipGraphNode_t node, hipGraph_t* pGraph); typedef hipError_t (*t_hipGraphClone)(hipGraph_t* pGraphClone, hipGraph_t originalGraph); typedef hipError_t (*t_hipGraphCreate)(hipGraph_t* pGraph, unsigned int flags); @@ -866,28 +907,68 @@ typedef hipError_t (*t_hipHccModuleLaunchKernel)(hipFunction_t f, uint32_t globa void** extra, hipEvent_t startEvent, hipEvent_t stopEvent); typedef int (*t_hipGetStreamDeviceId)(hipStream_t stream); - typedef hipError_t (*t_hipDrvGraphAddMemsetNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, const hipGraphNode_t* dependencies, size_t numDependencies, const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx); -typedef hipError_t (*t_hipGraphAddExternalSemaphoresWaitNode)(hipGraphNode_t* pGraphNode, hipGraph_t graph, - const hipGraphNode_t* pDependencies, size_t numDependencies, +typedef hipError_t (*t_hipGraphAddExternalSemaphoresWaitNode)(hipGraphNode_t* pGraphNode, + hipGraph_t graph, const hipGraphNode_t* pDependencies, + size_t numDependencies, const hipExternalSemaphoreWaitNodeParams* nodeParams); -typedef hipError_t (*t_hipGraphAddExternalSemaphoresSignalNode)(hipGraphNode_t* pGraphNode, hipGraph_t graph, - const hipGraphNode_t* pDependencies, size_t numDependencies, +typedef hipError_t (*t_hipGraphAddExternalSemaphoresSignalNode)(hipGraphNode_t* pGraphNode, + hipGraph_t graph, const hipGraphNode_t* pDependencies, + size_t numDependencies, const hipExternalSemaphoreSignalNodeParams* nodeParams); typedef hipError_t (*t_hipGraphExternalSemaphoresSignalNodeSetParams)(hipGraphNode_t hNode, - const hipExternalSemaphoreSignalNodeParams* nodeParams); + const hipExternalSemaphoreSignalNodeParams* nodeParams); typedef hipError_t (*t_hipGraphExternalSemaphoresWaitNodeSetParams)(hipGraphNode_t hNode, - const hipExternalSemaphoreWaitNodeParams* nodeParams); + const hipExternalSemaphoreWaitNodeParams* nodeParams); typedef hipError_t (*t_hipGraphExternalSemaphoresSignalNodeGetParams)(hipGraphNode_t hNode, - hipExternalSemaphoreSignalNodeParams* params_out); + hipExternalSemaphoreSignalNodeParams* params_out); typedef hipError_t (*t_hipGraphExternalSemaphoresWaitNodeGetParams)(hipGraphNode_t hNode, - hipExternalSemaphoreWaitNodeParams* params_out); -typedef hipError_t (*t_hipGraphExecExternalSemaphoresSignalNodeSetParams)(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, - const hipExternalSemaphoreSignalNodeParams* nodeParams); -typedef hipError_t (*t_hipGraphExecExternalSemaphoresWaitNodeSetParams)(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, - const hipExternalSemaphoreWaitNodeParams* nodeParams); + hipExternalSemaphoreWaitNodeParams* params_out); +typedef hipError_t (*t_hipGraphExecExternalSemaphoresSignalNodeSetParams)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + const hipExternalSemaphoreSignalNodeParams* nodeParams); +typedef hipError_t (*t_hipGraphExecExternalSemaphoresWaitNodeSetParams)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + const hipExternalSemaphoreWaitNodeParams* nodeParams); +typedef hipError_t (*t_hipGraphAddNode)(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipGraphNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphInstantiateWithParams)(hipGraphExec_t* pGraphExec, hipGraph_t graph, + hipGraphInstantiateParams* instantiateParams); +typedef hipError_t (*t_hipExtGetLastError)(); +typedef hipError_t (*t_hipTexRefGetBorderColor)(float* pBorderColor, + const textureReference* texRef); +typedef hipError_t (*t_hipTexRefGetArray)(hipArray_t* pArray, const textureReference* texRef); + +typedef hipError_t (*t_hipTexRefGetBorderColor)(float* pBorderColor, + const textureReference* texRef); +typedef hipError_t (*t_hipTexRefGetArray)(hipArray_t* pArray, const textureReference* texRef); +typedef hipError_t (*t_hipGetProcAddress)(const char* symbol, void** pfn, int hipVersion, uint64_t flags, + hipDriverProcAddressQueryResult* symbolStatus); +typedef hipError_t (*t_hipStreamBeginCaptureToGraph)(hipStream_t stream, hipGraph_t graph, + const hipGraphNode_t* dependencies, + const hipGraphEdgeData* dependencyData, + size_t numDependencies, + hipStreamCaptureMode mode); +typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t* functionPtr, const void* symbolPtr); +typedef hipError_t (*t_hipSetValidDevices)(int* device_arr, int len); +typedef hipError_t (*t_hipMemcpyAtoD)(hipDeviceptr_t dstDevice, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount); +typedef hipError_t (*t_hipMemcpyDtoA)(hipArray_t dstArray, size_t dstOffset, + hipDeviceptr_t srcDevice, size_t ByteCount); +typedef hipError_t (*t_hipMemcpyAtoA)(hipArray_t dstArray, size_t dstOffset, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount); +typedef hipError_t (*t_hipMemcpyAtoHAsync)(void* dstHost, hipArray_t srcArray, size_t srcOffset, + size_t ByteCount, hipStream_t stream); +typedef hipError_t (*t_hipMemcpyHtoAAsync)(hipArray_t dstArray, size_t dstOffset, + const void* srcHost, size_t ByteCount, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpy2DArrayToArray)(hipArray_t dst, size_t wOffsetDst, + size_t hOffsetDst, hipArray_const_t src, + size_t wOffsetSrc, size_t hOffsetSrc, size_t width, + size_t height, hipMemcpyKind kind); // HIP Compiler dispatch table struct HipCompilerDispatchTable { @@ -1347,4 +1428,19 @@ struct HipDispatchTable { t_hipGraphExternalSemaphoresWaitNodeGetParams hipGraphExternalSemaphoresWaitNodeGetParams_fn; t_hipGraphExecExternalSemaphoresSignalNodeSetParams hipGraphExecExternalSemaphoresSignalNodeSetParams_fn; t_hipGraphExecExternalSemaphoresWaitNodeSetParams hipGraphExecExternalSemaphoresWaitNodeSetParams_fn; + t_hipGraphAddNode hipGraphAddNode_fn; + t_hipGraphInstantiateWithParams hipGraphInstantiateWithParams_fn; + t_hipExtGetLastError hipExtGetLastError_fn; + t_hipTexRefGetBorderColor hipTexRefGetBorderColor_fn; + t_hipTexRefGetArray hipTexRefGetArray_fn; + t_hipGetProcAddress hipGetProcAddress_fn; + t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn; + t_hipGetFuncBySymbol hipGetFuncBySymbol_fn; + t_hipSetValidDevices hipSetValidDevices_fn; + t_hipMemcpyAtoD hipMemcpyAtoD_fn; + t_hipMemcpyDtoA hipMemcpyDtoA_fn; + t_hipMemcpyAtoA hipMemcpyAtoA_fn; + t_hipMemcpyAtoHAsync hipMemcpyAtoHAsync_fn; + t_hipMemcpyHtoAAsync hipMemcpyHtoAAsync_fn; + t_hipMemcpy2DArrayToArray hipMemcpy2DArrayToArray_fn; }; diff --git a/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h b/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h index 3c9c09f2cee8..992c198d0894 100644 --- a/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h +++ b/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h @@ -385,8 +385,8 @@ enum hip_api_id_t { HIP_API_ID_hipChooseDeviceR0600 = 365, HIP_API_ID_hipDrvGraphAddMemcpyNode = 366, HIP_API_ID_hipDrvGraphAddMemsetNode = 367, - HIP_API_ID_hipDrvGraphMemcpyNodeGetParams = 368, - HIP_API_ID_hipDrvGraphMemcpyNodeSetParams = 369, + HIP_API_ID_RESERVED_368 = 368, + HIP_API_ID_RESERVED_369 = 369, HIP_API_ID_hipGetDevicePropertiesR0600 = 370, HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode = 371, HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode = 372, @@ -397,7 +397,27 @@ enum hip_api_id_t { HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams = 377, HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams = 378, HIP_API_ID_hipExtGetLastError = 379, - HIP_API_ID_LAST = 379, + HIP_API_ID_hipGraphAddNode = 380, + HIP_API_ID_hipGetProcAddress = 381, + HIP_API_ID_RESERVED_382 = 382, + HIP_API_ID_RESERVED_383 = 383, + HIP_API_ID_hipGraphInstantiateWithParams = 384, + HIP_API_ID_RESERVED_385 = 385, + HIP_API_ID_RESERVED_386 = 386, + HIP_API_ID_RESERVED_387 = 387, + HIP_API_ID_RESERVED_388 = 388, + HIP_API_ID_hipTexRefGetArray = 389, + HIP_API_ID_hipTexRefGetBorderColor = 390, + HIP_API_ID_hipStreamBeginCaptureToGraph = 391, + HIP_API_ID_hipGetFuncBySymbol = 392, + HIP_API_ID_hipMemcpy2DArrayToArray = 393, + HIP_API_ID_hipMemcpyAtoA = 394, + HIP_API_ID_hipMemcpyAtoD = 395, + HIP_API_ID_hipMemcpyAtoHAsync = 396, + HIP_API_ID_hipMemcpyDtoA = 397, + HIP_API_ID_hipMemcpyHtoAAsync = 398, + HIP_API_ID_hipSetValidDevices = 399, + HIP_API_ID_LAST = 399, HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice), HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties), @@ -414,24 +434,14 @@ enum hip_api_id_t { HIP_API_ID_hipGetTextureObjectResourceViewDesc = HIP_API_ID_NONE, HIP_API_ID_hipGetTextureObjectTextureDesc = HIP_API_ID_NONE, HIP_API_ID_hipGetTextureReference = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpy2DArrayToArray = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpyAtoA = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpyAtoD = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpyAtoHAsync = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpyDtoA = HIP_API_ID_NONE, - HIP_API_ID_hipMemcpyHtoAAsync = HIP_API_ID_NONE, - HIP_API_ID_hipSetValidDevices = HIP_API_ID_NONE, HIP_API_ID_hipTexObjectCreate = HIP_API_ID_NONE, HIP_API_ID_hipTexObjectDestroy = HIP_API_ID_NONE, HIP_API_ID_hipTexObjectGetResourceDesc = HIP_API_ID_NONE, HIP_API_ID_hipTexObjectGetResourceViewDesc = HIP_API_ID_NONE, HIP_API_ID_hipTexObjectGetTextureDesc = HIP_API_ID_NONE, HIP_API_ID_hipTexRefGetAddressMode = HIP_API_ID_NONE, - HIP_API_ID_hipTexRefGetArray = HIP_API_ID_NONE, - HIP_API_ID_hipTexRefGetBorderColor = HIP_API_ID_NONE, HIP_API_ID_hipTexRefGetFilterMode = HIP_API_ID_NONE, HIP_API_ID_hipTexRefGetMipmapFilterMode = HIP_API_ID_NONE, - HIP_API_ID_hipTexRefGetMipmappedArray = HIP_API_ID_NONE, HIP_API_ID_hipTexRefSetAddressMode = HIP_API_ID_NONE, HIP_API_ID_hipTexRefSetFilterMode = HIP_API_ID_NONE, HIP_API_ID_hipTexRefSetMipmapFilterMode = HIP_API_ID_NONE, @@ -510,8 +520,6 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipDriverGetVersion: return "hipDriverGetVersion"; case HIP_API_ID_hipDrvGraphAddMemcpyNode: return "hipDrvGraphAddMemcpyNode"; case HIP_API_ID_hipDrvGraphAddMemsetNode: return "hipDrvGraphAddMemsetNode"; - case HIP_API_ID_hipDrvGraphMemcpyNodeGetParams: return "hipDrvGraphMemcpyNodeGetParams"; - case HIP_API_ID_hipDrvGraphMemcpyNodeSetParams: return "hipDrvGraphMemcpyNodeSetParams"; case HIP_API_ID_hipDrvMemcpy2DUnaligned: return "hipDrvMemcpy2DUnaligned"; case HIP_API_ID_hipDrvMemcpy3D: return "hipDrvMemcpy3D"; case HIP_API_ID_hipDrvMemcpy3DAsync: return "hipDrvMemcpy3DAsync"; @@ -523,6 +531,7 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipEventQuery: return "hipEventQuery"; case HIP_API_ID_hipEventRecord: return "hipEventRecord"; case HIP_API_ID_hipEventSynchronize: return "hipEventSynchronize"; + case HIP_API_ID_hipExtGetLastError: return "hipExtGetLastError"; case HIP_API_ID_hipExtGetLinkTypeAndHopCount: return "hipExtGetLinkTypeAndHopCount"; case HIP_API_ID_hipExtLaunchKernel: return "hipExtLaunchKernel"; case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: return "hipExtLaunchMultiKernelMultiDevice"; @@ -550,8 +559,10 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipGetDevicePropertiesR0000: return "hipGetDevicePropertiesR0000"; case HIP_API_ID_hipGetDevicePropertiesR0600: return "hipGetDevicePropertiesR0600"; case HIP_API_ID_hipGetErrorString: return "hipGetErrorString"; + case HIP_API_ID_hipGetFuncBySymbol: return "hipGetFuncBySymbol"; case HIP_API_ID_hipGetLastError: return "hipGetLastError"; case HIP_API_ID_hipGetMipmappedArrayLevel: return "hipGetMipmappedArrayLevel"; + case HIP_API_ID_hipGetProcAddress: return "hipGetProcAddress"; case HIP_API_ID_hipGetSymbolAddress: return "hipGetSymbolAddress"; case HIP_API_ID_hipGetSymbolSize: return "hipGetSymbolSize"; case HIP_API_ID_hipGraphAddChildGraphNode: return "hipGraphAddChildGraphNode"; @@ -570,6 +581,7 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol: return "hipGraphAddMemcpyNodeFromSymbol"; case HIP_API_ID_hipGraphAddMemcpyNodeToSymbol: return "hipGraphAddMemcpyNodeToSymbol"; case HIP_API_ID_hipGraphAddMemsetNode: return "hipGraphAddMemsetNode"; + case HIP_API_ID_hipGraphAddNode: return "hipGraphAddNode"; case HIP_API_ID_hipGraphChildGraphNodeGetGraph: return "hipGraphChildGraphNodeGetGraph"; case HIP_API_ID_hipGraphClone: return "hipGraphClone"; case HIP_API_ID_hipGraphCreate: return "hipGraphCreate"; @@ -605,6 +617,7 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipGraphHostNodeSetParams: return "hipGraphHostNodeSetParams"; case HIP_API_ID_hipGraphInstantiate: return "hipGraphInstantiate"; case HIP_API_ID_hipGraphInstantiateWithFlags: return "hipGraphInstantiateWithFlags"; + case HIP_API_ID_hipGraphInstantiateWithParams: return "hipGraphInstantiateWithParams"; case HIP_API_ID_hipGraphKernelNodeCopyAttributes: return "hipGraphKernelNodeCopyAttributes"; case HIP_API_ID_hipGraphKernelNodeGetAttribute: return "hipGraphKernelNodeGetAttribute"; case HIP_API_ID_hipGraphKernelNodeGetParams: return "hipGraphKernelNodeGetParams"; @@ -704,6 +717,7 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipMemUnmap: return "hipMemUnmap"; case HIP_API_ID_hipMemcpy: return "hipMemcpy"; case HIP_API_ID_hipMemcpy2D: return "hipMemcpy2D"; + case HIP_API_ID_hipMemcpy2DArrayToArray: return "hipMemcpy2DArrayToArray"; case HIP_API_ID_hipMemcpy2DAsync: return "hipMemcpy2DAsync"; case HIP_API_ID_hipMemcpy2DFromArray: return "hipMemcpy2DFromArray"; case HIP_API_ID_hipMemcpy2DFromArrayAsync: return "hipMemcpy2DFromArrayAsync"; @@ -712,7 +726,11 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipMemcpy3D: return "hipMemcpy3D"; case HIP_API_ID_hipMemcpy3DAsync: return "hipMemcpy3DAsync"; case HIP_API_ID_hipMemcpyAsync: return "hipMemcpyAsync"; + case HIP_API_ID_hipMemcpyAtoA: return "hipMemcpyAtoA"; + case HIP_API_ID_hipMemcpyAtoD: return "hipMemcpyAtoD"; case HIP_API_ID_hipMemcpyAtoH: return "hipMemcpyAtoH"; + case HIP_API_ID_hipMemcpyAtoHAsync: return "hipMemcpyAtoHAsync"; + case HIP_API_ID_hipMemcpyDtoA: return "hipMemcpyDtoA"; case HIP_API_ID_hipMemcpyDtoD: return "hipMemcpyDtoD"; case HIP_API_ID_hipMemcpyDtoDAsync: return "hipMemcpyDtoDAsync"; case HIP_API_ID_hipMemcpyDtoH: return "hipMemcpyDtoH"; @@ -721,6 +739,7 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipMemcpyFromSymbol: return "hipMemcpyFromSymbol"; case HIP_API_ID_hipMemcpyFromSymbolAsync: return "hipMemcpyFromSymbolAsync"; case HIP_API_ID_hipMemcpyHtoA: return "hipMemcpyHtoA"; + case HIP_API_ID_hipMemcpyHtoAAsync: return "hipMemcpyHtoAAsync"; case HIP_API_ID_hipMemcpyHtoD: return "hipMemcpyHtoD"; case HIP_API_ID_hipMemcpyHtoDAsync: return "hipMemcpyHtoDAsync"; case HIP_API_ID_hipMemcpyParam2D: return "hipMemcpyParam2D"; @@ -772,11 +791,13 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipRuntimeGetVersion: return "hipRuntimeGetVersion"; case HIP_API_ID_hipSetDevice: return "hipSetDevice"; case HIP_API_ID_hipSetDeviceFlags: return "hipSetDeviceFlags"; + case HIP_API_ID_hipSetValidDevices: return "hipSetValidDevices"; case HIP_API_ID_hipSetupArgument: return "hipSetupArgument"; case HIP_API_ID_hipSignalExternalSemaphoresAsync: return "hipSignalExternalSemaphoresAsync"; case HIP_API_ID_hipStreamAddCallback: return "hipStreamAddCallback"; case HIP_API_ID_hipStreamAttachMemAsync: return "hipStreamAttachMemAsync"; case HIP_API_ID_hipStreamBeginCapture: return "hipStreamBeginCapture"; + case HIP_API_ID_hipStreamBeginCaptureToGraph: return "hipStreamBeginCaptureToGraph"; case HIP_API_ID_hipStreamCreate: return "hipStreamCreate"; case HIP_API_ID_hipStreamCreateWithFlags: return "hipStreamCreateWithFlags"; case HIP_API_ID_hipStreamCreateWithPriority: return "hipStreamCreateWithPriority"; @@ -797,6 +818,8 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipStreamWriteValue32: return "hipStreamWriteValue32"; case HIP_API_ID_hipStreamWriteValue64: return "hipStreamWriteValue64"; case HIP_API_ID_hipTexRefGetAddress: return "hipTexRefGetAddress"; + case HIP_API_ID_hipTexRefGetArray: return "hipTexRefGetArray"; + case HIP_API_ID_hipTexRefGetBorderColor: return "hipTexRefGetBorderColor"; case HIP_API_ID_hipTexRefGetFlags: return "hipTexRefGetFlags"; case HIP_API_ID_hipTexRefGetFormat: return "hipTexRefGetFormat"; case HIP_API_ID_hipTexRefGetMaxAnisotropy: return "hipTexRefGetMaxAnisotropy"; @@ -818,7 +841,6 @@ static inline const char* hip_api_name(const uint32_t id) { case HIP_API_ID_hipUserObjectRelease: return "hipUserObjectRelease"; case HIP_API_ID_hipUserObjectRetain: return "hipUserObjectRetain"; case HIP_API_ID_hipWaitExternalSemaphoresAsync: return "hipWaitExternalSemaphoresAsync"; - case HIP_API_ID_hipExtGetLastError: return "hipExtGetLastError"; }; return "unknown"; }; @@ -892,8 +914,6 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipDriverGetVersion", name) == 0) return HIP_API_ID_hipDriverGetVersion; if (strcmp("hipDrvGraphAddMemcpyNode", name) == 0) return HIP_API_ID_hipDrvGraphAddMemcpyNode; if (strcmp("hipDrvGraphAddMemsetNode", name) == 0) return HIP_API_ID_hipDrvGraphAddMemsetNode; - if (strcmp("hipDrvGraphMemcpyNodeGetParams", name) == 0) return HIP_API_ID_hipDrvGraphMemcpyNodeGetParams; - if (strcmp("hipDrvGraphMemcpyNodeSetParams", name) == 0) return HIP_API_ID_hipDrvGraphMemcpyNodeSetParams; if (strcmp("hipDrvMemcpy2DUnaligned", name) == 0) return HIP_API_ID_hipDrvMemcpy2DUnaligned; if (strcmp("hipDrvMemcpy3D", name) == 0) return HIP_API_ID_hipDrvMemcpy3D; if (strcmp("hipDrvMemcpy3DAsync", name) == 0) return HIP_API_ID_hipDrvMemcpy3DAsync; @@ -905,6 +925,7 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipEventQuery", name) == 0) return HIP_API_ID_hipEventQuery; if (strcmp("hipEventRecord", name) == 0) return HIP_API_ID_hipEventRecord; if (strcmp("hipEventSynchronize", name) == 0) return HIP_API_ID_hipEventSynchronize; + if (strcmp("hipExtGetLastError", name) == 0) return HIP_API_ID_hipExtGetLastError; if (strcmp("hipExtGetLinkTypeAndHopCount", name) == 0) return HIP_API_ID_hipExtGetLinkTypeAndHopCount; if (strcmp("hipExtLaunchKernel", name) == 0) return HIP_API_ID_hipExtLaunchKernel; if (strcmp("hipExtLaunchMultiKernelMultiDevice", name) == 0) return HIP_API_ID_hipExtLaunchMultiKernelMultiDevice; @@ -932,8 +953,10 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipGetDevicePropertiesR0000", name) == 0) return HIP_API_ID_hipGetDevicePropertiesR0000; if (strcmp("hipGetDevicePropertiesR0600", name) == 0) return HIP_API_ID_hipGetDevicePropertiesR0600; if (strcmp("hipGetErrorString", name) == 0) return HIP_API_ID_hipGetErrorString; + if (strcmp("hipGetFuncBySymbol", name) == 0) return HIP_API_ID_hipGetFuncBySymbol; if (strcmp("hipGetLastError", name) == 0) return HIP_API_ID_hipGetLastError; if (strcmp("hipGetMipmappedArrayLevel", name) == 0) return HIP_API_ID_hipGetMipmappedArrayLevel; + if (strcmp("hipGetProcAddress", name) == 0) return HIP_API_ID_hipGetProcAddress; if (strcmp("hipGetSymbolAddress", name) == 0) return HIP_API_ID_hipGetSymbolAddress; if (strcmp("hipGetSymbolSize", name) == 0) return HIP_API_ID_hipGetSymbolSize; if (strcmp("hipGraphAddChildGraphNode", name) == 0) return HIP_API_ID_hipGraphAddChildGraphNode; @@ -952,6 +975,7 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipGraphAddMemcpyNodeFromSymbol", name) == 0) return HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol; if (strcmp("hipGraphAddMemcpyNodeToSymbol", name) == 0) return HIP_API_ID_hipGraphAddMemcpyNodeToSymbol; if (strcmp("hipGraphAddMemsetNode", name) == 0) return HIP_API_ID_hipGraphAddMemsetNode; + if (strcmp("hipGraphAddNode", name) == 0) return HIP_API_ID_hipGraphAddNode; if (strcmp("hipGraphChildGraphNodeGetGraph", name) == 0) return HIP_API_ID_hipGraphChildGraphNodeGetGraph; if (strcmp("hipGraphClone", name) == 0) return HIP_API_ID_hipGraphClone; if (strcmp("hipGraphCreate", name) == 0) return HIP_API_ID_hipGraphCreate; @@ -987,6 +1011,7 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipGraphHostNodeSetParams", name) == 0) return HIP_API_ID_hipGraphHostNodeSetParams; if (strcmp("hipGraphInstantiate", name) == 0) return HIP_API_ID_hipGraphInstantiate; if (strcmp("hipGraphInstantiateWithFlags", name) == 0) return HIP_API_ID_hipGraphInstantiateWithFlags; + if (strcmp("hipGraphInstantiateWithParams", name) == 0) return HIP_API_ID_hipGraphInstantiateWithParams; if (strcmp("hipGraphKernelNodeCopyAttributes", name) == 0) return HIP_API_ID_hipGraphKernelNodeCopyAttributes; if (strcmp("hipGraphKernelNodeGetAttribute", name) == 0) return HIP_API_ID_hipGraphKernelNodeGetAttribute; if (strcmp("hipGraphKernelNodeGetParams", name) == 0) return HIP_API_ID_hipGraphKernelNodeGetParams; @@ -1086,6 +1111,7 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipMemUnmap", name) == 0) return HIP_API_ID_hipMemUnmap; if (strcmp("hipMemcpy", name) == 0) return HIP_API_ID_hipMemcpy; if (strcmp("hipMemcpy2D", name) == 0) return HIP_API_ID_hipMemcpy2D; + if (strcmp("hipMemcpy2DArrayToArray", name) == 0) return HIP_API_ID_hipMemcpy2DArrayToArray; if (strcmp("hipMemcpy2DAsync", name) == 0) return HIP_API_ID_hipMemcpy2DAsync; if (strcmp("hipMemcpy2DFromArray", name) == 0) return HIP_API_ID_hipMemcpy2DFromArray; if (strcmp("hipMemcpy2DFromArrayAsync", name) == 0) return HIP_API_ID_hipMemcpy2DFromArrayAsync; @@ -1094,7 +1120,11 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipMemcpy3D", name) == 0) return HIP_API_ID_hipMemcpy3D; if (strcmp("hipMemcpy3DAsync", name) == 0) return HIP_API_ID_hipMemcpy3DAsync; if (strcmp("hipMemcpyAsync", name) == 0) return HIP_API_ID_hipMemcpyAsync; + if (strcmp("hipMemcpyAtoA", name) == 0) return HIP_API_ID_hipMemcpyAtoA; + if (strcmp("hipMemcpyAtoD", name) == 0) return HIP_API_ID_hipMemcpyAtoD; if (strcmp("hipMemcpyAtoH", name) == 0) return HIP_API_ID_hipMemcpyAtoH; + if (strcmp("hipMemcpyAtoHAsync", name) == 0) return HIP_API_ID_hipMemcpyAtoHAsync; + if (strcmp("hipMemcpyDtoA", name) == 0) return HIP_API_ID_hipMemcpyDtoA; if (strcmp("hipMemcpyDtoD", name) == 0) return HIP_API_ID_hipMemcpyDtoD; if (strcmp("hipMemcpyDtoDAsync", name) == 0) return HIP_API_ID_hipMemcpyDtoDAsync; if (strcmp("hipMemcpyDtoH", name) == 0) return HIP_API_ID_hipMemcpyDtoH; @@ -1103,6 +1133,7 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipMemcpyFromSymbol", name) == 0) return HIP_API_ID_hipMemcpyFromSymbol; if (strcmp("hipMemcpyFromSymbolAsync", name) == 0) return HIP_API_ID_hipMemcpyFromSymbolAsync; if (strcmp("hipMemcpyHtoA", name) == 0) return HIP_API_ID_hipMemcpyHtoA; + if (strcmp("hipMemcpyHtoAAsync", name) == 0) return HIP_API_ID_hipMemcpyHtoAAsync; if (strcmp("hipMemcpyHtoD", name) == 0) return HIP_API_ID_hipMemcpyHtoD; if (strcmp("hipMemcpyHtoDAsync", name) == 0) return HIP_API_ID_hipMemcpyHtoDAsync; if (strcmp("hipMemcpyParam2D", name) == 0) return HIP_API_ID_hipMemcpyParam2D; @@ -1154,11 +1185,13 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipRuntimeGetVersion", name) == 0) return HIP_API_ID_hipRuntimeGetVersion; if (strcmp("hipSetDevice", name) == 0) return HIP_API_ID_hipSetDevice; if (strcmp("hipSetDeviceFlags", name) == 0) return HIP_API_ID_hipSetDeviceFlags; + if (strcmp("hipSetValidDevices", name) == 0) return HIP_API_ID_hipSetValidDevices; if (strcmp("hipSetupArgument", name) == 0) return HIP_API_ID_hipSetupArgument; if (strcmp("hipSignalExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipSignalExternalSemaphoresAsync; if (strcmp("hipStreamAddCallback", name) == 0) return HIP_API_ID_hipStreamAddCallback; if (strcmp("hipStreamAttachMemAsync", name) == 0) return HIP_API_ID_hipStreamAttachMemAsync; if (strcmp("hipStreamBeginCapture", name) == 0) return HIP_API_ID_hipStreamBeginCapture; + if (strcmp("hipStreamBeginCaptureToGraph", name) == 0) return HIP_API_ID_hipStreamBeginCaptureToGraph; if (strcmp("hipStreamCreate", name) == 0) return HIP_API_ID_hipStreamCreate; if (strcmp("hipStreamCreateWithFlags", name) == 0) return HIP_API_ID_hipStreamCreateWithFlags; if (strcmp("hipStreamCreateWithPriority", name) == 0) return HIP_API_ID_hipStreamCreateWithPriority; @@ -1179,6 +1212,8 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipStreamWriteValue32", name) == 0) return HIP_API_ID_hipStreamWriteValue32; if (strcmp("hipStreamWriteValue64", name) == 0) return HIP_API_ID_hipStreamWriteValue64; if (strcmp("hipTexRefGetAddress", name) == 0) return HIP_API_ID_hipTexRefGetAddress; + if (strcmp("hipTexRefGetArray", name) == 0) return HIP_API_ID_hipTexRefGetArray; + if (strcmp("hipTexRefGetBorderColor", name) == 0) return HIP_API_ID_hipTexRefGetBorderColor; if (strcmp("hipTexRefGetFlags", name) == 0) return HIP_API_ID_hipTexRefGetFlags; if (strcmp("hipTexRefGetFormat", name) == 0) return HIP_API_ID_hipTexRefGetFormat; if (strcmp("hipTexRefGetMaxAnisotropy", name) == 0) return HIP_API_ID_hipTexRefGetMaxAnisotropy; @@ -1200,7 +1235,6 @@ static inline uint32_t hipApiIdByName(const char* name) { if (strcmp("hipUserObjectRelease", name) == 0) return HIP_API_ID_hipUserObjectRelease; if (strcmp("hipUserObjectRetain", name) == 0) return HIP_API_ID_hipUserObjectRetain; if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipWaitExternalSemaphoresAsync; - if (strcmp("hipExtGetLastError", name) == 0) return HIP_API_ID_hipExtGetLastError; return HIP_API_ID_NONE; } @@ -1519,16 +1553,6 @@ typedef struct hip_api_data_s { HIP_MEMSET_NODE_PARAMS memsetParams__val; hipCtx_t ctx; } hipDrvGraphAddMemsetNode; - struct { - hipGraphNode_t hNode; - HIP_MEMCPY3D* nodeParams; - HIP_MEMCPY3D nodeParams__val; - } hipDrvGraphMemcpyNodeGetParams; - struct { - hipGraphNode_t hNode; - const HIP_MEMCPY3D* nodeParams; - HIP_MEMCPY3D nodeParams__val; - } hipDrvGraphMemcpyNodeSetParams; struct { const hip_Memcpy2D* pCopy; hip_Memcpy2D pCopy__val; @@ -1730,12 +1754,27 @@ typedef struct hip_api_data_s { hipDeviceProp_tR0600 prop__val; int deviceId; } hipGetDevicePropertiesR0600; + struct { + hipFunction_t* functionPtr; + hipFunction_t functionPtr__val; + const void* symbolPtr; + } hipGetFuncBySymbol; struct { hipArray_t* levelArray; hipArray_t levelArray__val; hipMipmappedArray_const_t mipmappedArray; unsigned int level; } hipGetMipmappedArrayLevel; + struct { + const char* symbol; + char symbol__val; + void** pfn; + void* pfn__val; + int hipVersion; + uint64_t flags; + hipDriverProcAddressQueryResult* symbolStatus; + hipDriverProcAddressQueryResult symbolStatus__val; + } hipGetProcAddress; struct { void** devPtr; void* devPtr__val; @@ -1906,6 +1945,16 @@ typedef struct hip_api_data_s { const hipMemsetParams* pMemsetParams; hipMemsetParams pMemsetParams__val; } hipGraphAddMemsetNode; + struct { + hipGraphNode_t* pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t* pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipGraphNodeParams* nodeParams; + hipGraphNodeParams nodeParams__val; + } hipGraphAddNode; struct { hipGraphNode_t node; hipGraph_t* pGraph; @@ -2108,15 +2157,22 @@ typedef struct hip_api_data_s { hipGraph_t graph; unsigned long long flags; } hipGraphInstantiateWithFlags; + struct { + hipGraphExec_t* pGraphExec; + hipGraphExec_t pGraphExec__val; + hipGraph_t graph; + hipGraphInstantiateParams* instantiateParams; + hipGraphInstantiateParams instantiateParams__val; + } hipGraphInstantiateWithParams; struct { hipGraphNode_t hSrc; hipGraphNode_t hDst; } hipGraphKernelNodeCopyAttributes; struct { hipGraphNode_t hNode; - hipKernelNodeAttrID attr; - hipKernelNodeAttrValue* value; - hipKernelNodeAttrValue value__val; + hipLaunchAttributeID attr; + hipLaunchAttributeValue* value; + hipLaunchAttributeValue value__val; } hipGraphKernelNodeGetAttribute; struct { hipGraphNode_t node; @@ -2125,9 +2181,9 @@ typedef struct hip_api_data_s { } hipGraphKernelNodeGetParams; struct { hipGraphNode_t hNode; - hipKernelNodeAttrID attr; - const hipKernelNodeAttrValue* value; - hipKernelNodeAttrValue value__val; + hipLaunchAttributeID attr; + const hipLaunchAttributeValue* value; + hipLaunchAttributeValue value__val; } hipGraphKernelNodeSetAttribute; struct { hipGraphNode_t node; @@ -2702,6 +2758,17 @@ typedef struct hip_api_data_s { size_t height; hipMemcpyKind kind; } hipMemcpy2D; + struct { + hipArray_t dst; + size_t wOffsetDst; + size_t hOffsetDst; + hipArray_const_t src; + size_t wOffsetSrc; + size_t hOffsetSrc; + size_t width; + size_t height; + hipMemcpyKind kind; + } hipMemcpy2DArrayToArray; struct { void* dst; size_t dpitch; @@ -2770,12 +2837,38 @@ typedef struct hip_api_data_s { hipMemcpyKind kind; hipStream_t stream; } hipMemcpyAsync; + struct { + hipArray_t dstArray; + size_t dstOffset; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + } hipMemcpyAtoA; + struct { + hipDeviceptr_t dstDevice; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + } hipMemcpyAtoD; struct { void* dst; hipArray_t srcArray; size_t srcOffset; size_t count; } hipMemcpyAtoH; + struct { + void* dstHost; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + hipStream_t stream; + } hipMemcpyAtoHAsync; + struct { + hipArray_t dstArray; + size_t dstOffset; + hipDeviceptr_t srcDevice; + size_t ByteCount; + } hipMemcpyDtoA; struct { hipDeviceptr_t dst; hipDeviceptr_t src; @@ -2827,6 +2920,13 @@ typedef struct hip_api_data_s { const void* srcHost; size_t count; } hipMemcpyHtoA; + struct { + hipArray_t dstArray; + size_t dstOffset; + const void* srcHost; + size_t ByteCount; + hipStream_t stream; + } hipMemcpyHtoAAsync; struct { hipDeviceptr_t dst; void* src; @@ -3142,6 +3242,11 @@ typedef struct hip_api_data_s { struct { unsigned int flags; } hipSetDeviceFlags; + struct { + int* device_arr; + int device_arr__val; + int len; + } hipSetValidDevices; struct { const void* arg; size_t size; @@ -3171,6 +3276,16 @@ typedef struct hip_api_data_s { hipStream_t stream; hipStreamCaptureMode mode; } hipStreamBeginCapture; + struct { + hipStream_t stream; + hipGraph_t graph; + const hipGraphNode_t* dependencies; + hipGraphNode_t dependencies__val; + const hipGraphEdgeData* dependencyData; + hipGraphEdgeData dependencyData__val; + size_t numDependencies; + hipStreamCaptureMode mode; + } hipStreamBeginCaptureToGraph; struct { hipStream_t* stream; hipStream_t stream__val; @@ -3284,6 +3399,18 @@ typedef struct hip_api_data_s { const textureReference* texRef; textureReference texRef__val; } hipTexRefGetAddress; + struct { + hipArray_t* pArray; + hipArray_t pArray__val; + const textureReference* texRef; + textureReference texRef__val; + } hipTexRefGetArray; + struct { + float* pBorderColor; + float pBorderColor__val; + const textureReference* texRef; + textureReference texRef__val; + } hipTexRefGetBorderColor; struct { unsigned int* pFlags; unsigned int pFlags__val; @@ -3729,15 +3856,21 @@ typedef struct hip_api_data_s { }; // hipDrvGraphAddMemcpyNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'), ('const HIP_MEMCPY3D*', 'copyParams'), ('hipCtx_t', 'ctx')] #define INIT_hipDrvGraphAddMemcpyNode_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipDrvGraphAddMemcpyNode.phGraphNode = (hipGraphNode_t*)phGraphNode; \ + cb_data.args.hipDrvGraphAddMemcpyNode.hGraph = (hipGraph_t)hGraph; \ + cb_data.args.hipDrvGraphAddMemcpyNode.dependencies = (const hipGraphNode_t*)dependencies; \ + cb_data.args.hipDrvGraphAddMemcpyNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipDrvGraphAddMemcpyNode.copyParams = (const HIP_MEMCPY3D*)copyParams; \ + cb_data.args.hipDrvGraphAddMemcpyNode.ctx = (hipCtx_t)ctx; \ }; // hipDrvGraphAddMemsetNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'), ('const HIP_MEMSET_NODE_PARAMS*', 'memsetParams'), ('hipCtx_t', 'ctx')] #define INIT_hipDrvGraphAddMemsetNode_CB_ARGS_DATA(cb_data) { \ -}; -// hipDrvGraphMemcpyNodeGetParams[('hipGraphNode_t', 'hNode'), ('HIP_MEMCPY3D*', 'nodeParams')] -#define INIT_hipDrvGraphMemcpyNodeGetParams_CB_ARGS_DATA(cb_data) { \ -}; -// hipDrvGraphMemcpyNodeSetParams[('hipGraphNode_t', 'hNode'), ('const HIP_MEMCPY3D*', 'nodeParams')] -#define INIT_hipDrvGraphMemcpyNodeSetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipDrvGraphAddMemsetNode.phGraphNode = (hipGraphNode_t*)phGraphNode; \ + cb_data.args.hipDrvGraphAddMemsetNode.hGraph = (hipGraph_t)hGraph; \ + cb_data.args.hipDrvGraphAddMemsetNode.dependencies = (const hipGraphNode_t*)dependencies; \ + cb_data.args.hipDrvGraphAddMemsetNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipDrvGraphAddMemsetNode.memsetParams = (const HIP_MEMSET_NODE_PARAMS*)memsetParams; \ + cb_data.args.hipDrvGraphAddMemsetNode.ctx = (hipCtx_t)ctx; \ }; // hipDrvMemcpy2DUnaligned[('const hip_Memcpy2D*', 'pCopy')] #define INIT_hipDrvMemcpy2DUnaligned_CB_ARGS_DATA(cb_data) { \ @@ -3791,6 +3924,9 @@ typedef struct hip_api_data_s { #define INIT_hipEventSynchronize_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipEventSynchronize.event = (hipEvent_t)event; \ }; +// hipExtGetLastError[] +#define INIT_hipExtGetLastError_CB_ARGS_DATA(cb_data) { \ +}; // hipExtGetLinkTypeAndHopCount[('int', 'device1'), ('int', 'device2'), ('unsigned int*', 'linktype'), ('unsigned int*', 'hopcount')] #define INIT_hipExtGetLinkTypeAndHopCount_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipExtGetLinkTypeAndHopCount.device1 = (int)device1; \ @@ -3948,6 +4084,11 @@ typedef struct hip_api_data_s { // hipGetErrorString[] #define INIT_hipGetErrorString_CB_ARGS_DATA(cb_data) { \ }; +// hipGetFuncBySymbol[('hipFunction_t*', 'functionPtr'), ('const void*', 'symbolPtr')] +#define INIT_hipGetFuncBySymbol_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGetFuncBySymbol.functionPtr = (hipFunction_t*)functionPtr; \ + cb_data.args.hipGetFuncBySymbol.symbolPtr = (const void*)symbolPtr; \ +}; // hipGetLastError[] #define INIT_hipGetLastError_CB_ARGS_DATA(cb_data) { \ }; @@ -3957,6 +4098,14 @@ typedef struct hip_api_data_s { cb_data.args.hipGetMipmappedArrayLevel.mipmappedArray = (hipMipmappedArray_const_t)mipmappedArray; \ cb_data.args.hipGetMipmappedArrayLevel.level = (unsigned int)level; \ }; +// hipGetProcAddress[('const char*', 'symbol'), ('void**', 'pfn'), ('int', 'hipVersion'), ('uint64_t', 'flags'), ('hipDriverProcAddressQueryResult*', 'symbolStatus')] +#define INIT_hipGetProcAddress_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGetProcAddress.symbol = (symbol) ? strdup(symbol) : NULL; \ + cb_data.args.hipGetProcAddress.pfn = (void**)pfn; \ + cb_data.args.hipGetProcAddress.hipVersion = (int)hipVersion; \ + cb_data.args.hipGetProcAddress.flags = (uint64_t)flags; \ + cb_data.args.hipGetProcAddress.symbolStatus = (hipDriverProcAddressQueryResult*)symbolStatus; \ +}; // hipGetSymbolAddress[('void**', 'devPtr'), ('const void*', 'symbol')] #define INIT_hipGetSymbolAddress_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipGetSymbolAddress.devPtr = (void**)devPtr; \ @@ -4007,9 +4156,19 @@ typedef struct hip_api_data_s { }; // hipGraphAddExternalSemaphoresSignalNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), ('const hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] #define INIT_hipGraphAddExternalSemaphoresSignalNode_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode = (hipGraphNode_t*)pGraphNode; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.pDependencies = (const hipGraphNode_t*)pDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.nodeParams = (const hipExternalSemaphoreSignalNodeParams*)nodeParams; \ }; // hipGraphAddExternalSemaphoresWaitNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), ('const hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] #define INIT_hipGraphAddExternalSemaphoresWaitNode_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode = (hipGraphNode_t*)pGraphNode; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.pDependencies = (const hipGraphNode_t*)pDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.nodeParams = (const hipExternalSemaphoreWaitNodeParams*)nodeParams; \ }; // hipGraphAddHostNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), ('const hipHostNodeParams*', 'pNodeParams')] #define INIT_hipGraphAddHostNode_CB_ARGS_DATA(cb_data) { \ @@ -4094,6 +4253,14 @@ typedef struct hip_api_data_s { cb_data.args.hipGraphAddMemsetNode.numDependencies = (size_t)numDependencies; \ cb_data.args.hipGraphAddMemsetNode.pMemsetParams = (const hipMemsetParams*)pMemsetParams; \ }; +// hipGraphAddNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), ('hipGraphNodeParams*', 'nodeParams')] +#define INIT_hipGraphAddNode_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphAddNode.pGraphNode = (hipGraphNode_t*)pGraphNode; \ + cb_data.args.hipGraphAddNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddNode.pDependencies = (const hipGraphNode_t*)pDependencies; \ + cb_data.args.hipGraphAddNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipGraphAddNode.nodeParams = (hipGraphNodeParams*)nodeParams; \ +}; // hipGraphChildGraphNodeGetGraph[('hipGraphNode_t', 'node'), ('hipGraph_t*', 'pGraph')] #define INIT_hipGraphChildGraphNodeGetGraph_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipGraphChildGraphNodeGetGraph.node = (hipGraphNode_t)node; \ @@ -4167,9 +4334,15 @@ typedef struct hip_api_data_s { }; // hipGraphExecExternalSemaphoresSignalNodeSetParams[('hipGraphExec_t', 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] #define INIT_hipGraphExecExternalSemaphoresSignalNodeSetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams.hGraphExec = (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams.nodeParams = (const hipExternalSemaphoreSignalNodeParams*)nodeParams; \ }; // hipGraphExecExternalSemaphoresWaitNodeSetParams[('hipGraphExec_t', 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] #define INIT_hipGraphExecExternalSemaphoresWaitNodeSetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hGraphExec = (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.nodeParams = (const hipExternalSemaphoreWaitNodeParams*)nodeParams; \ }; // hipGraphExecHostNodeSetParams[('hipGraphExec_t', 'hGraphExec'), ('hipGraphNode_t', 'node'), ('const hipHostNodeParams*', 'pNodeParams')] #define INIT_hipGraphExecHostNodeSetParams_CB_ARGS_DATA(cb_data) { \ @@ -4233,15 +4406,23 @@ typedef struct hip_api_data_s { }; // hipGraphExternalSemaphoresSignalNodeGetParams[('hipGraphNode_t', 'hNode'), ('hipExternalSemaphoreSignalNodeParams*', 'params_out')] #define INIT_hipGraphExternalSemaphoresSignalNodeGetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeGetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeGetParams.params_out = (hipExternalSemaphoreSignalNodeParams*)params_out; \ }; // hipGraphExternalSemaphoresSignalNodeSetParams[('hipGraphNode_t', 'hNode'), ('const hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] #define INIT_hipGraphExternalSemaphoresSignalNodeSetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeSetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeSetParams.nodeParams = (const hipExternalSemaphoreSignalNodeParams*)nodeParams; \ }; // hipGraphExternalSemaphoresWaitNodeGetParams[('hipGraphNode_t', 'hNode'), ('hipExternalSemaphoreWaitNodeParams*', 'params_out')] #define INIT_hipGraphExternalSemaphoresWaitNodeGetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeGetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out = (hipExternalSemaphoreWaitNodeParams*)params_out; \ }; // hipGraphExternalSemaphoresWaitNodeSetParams[('hipGraphNode_t', 'hNode'), ('const hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] #define INIT_hipGraphExternalSemaphoresWaitNodeSetParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeSetParams.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams = (const hipExternalSemaphoreWaitNodeParams*)nodeParams; \ }; // hipGraphGetEdges[('hipGraph_t', 'graph'), ('hipGraphNode_t*', 'from'), ('hipGraphNode_t*', 'to'), ('size_t*', 'numEdges')] #define INIT_hipGraphGetEdges_CB_ARGS_DATA(cb_data) { \ @@ -4286,27 +4467,27 @@ typedef struct hip_api_data_s { cb_data.args.hipGraphInstantiateWithFlags.graph = (hipGraph_t)graph; \ cb_data.args.hipGraphInstantiateWithFlags.flags = (unsigned long long)flags; \ }; +// hipGraphInstantiateWithParams[('hipGraphExec_t*', 'pGraphExec'), ('hipGraph_t', 'graph'), ('hipGraphInstantiateParams*', 'instantiateParams')] +#define INIT_hipGraphInstantiateWithParams_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipGraphInstantiateWithParams.pGraphExec = (hipGraphExec_t*)pGraphExec; \ + cb_data.args.hipGraphInstantiateWithParams.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphInstantiateWithParams.instantiateParams = (hipGraphInstantiateParams*)instantiateParams; \ +}; // hipGraphKernelNodeCopyAttributes[('hipGraphNode_t', 'hSrc'), ('hipGraphNode_t', 'hDst')] #define INIT_hipGraphKernelNodeCopyAttributes_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipGraphKernelNodeCopyAttributes.hSrc = (hipGraphNode_t)hSrc; \ cb_data.args.hipGraphKernelNodeCopyAttributes.hDst = (hipGraphNode_t)hDst; \ }; -// hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), ('hipKernelNodeAttrID', 'attr'), ('hipKernelNodeAttrValue*', 'value')] +// hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), ('hipLaunchAttributeID', 'attr'), ('hipLaunchAttributeValue*', 'value')] #define INIT_hipGraphKernelNodeGetAttribute_CB_ARGS_DATA(cb_data) { \ - cb_data.args.hipGraphKernelNodeGetAttribute.hNode = (hipGraphNode_t)hNode; \ - cb_data.args.hipGraphKernelNodeGetAttribute.attr = (hipKernelNodeAttrID)attr; \ - cb_data.args.hipGraphKernelNodeGetAttribute.value = (hipKernelNodeAttrValue*)value; \ }; // hipGraphKernelNodeGetParams[('hipGraphNode_t', 'node'), ('hipKernelNodeParams*', 'pNodeParams')] #define INIT_hipGraphKernelNodeGetParams_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipGraphKernelNodeGetParams.node = (hipGraphNode_t)node; \ cb_data.args.hipGraphKernelNodeGetParams.pNodeParams = (hipKernelNodeParams*)pNodeParams; \ }; -// hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), ('hipKernelNodeAttrID', 'attr'), ('const hipKernelNodeAttrValue*', 'value')] +// hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), ('hipLaunchAttributeID', 'attr'), ('const hipLaunchAttributeValue*', 'value')] #define INIT_hipGraphKernelNodeSetAttribute_CB_ARGS_DATA(cb_data) { \ - cb_data.args.hipGraphKernelNodeSetAttribute.hNode = (hipGraphNode_t)hNode; \ - cb_data.args.hipGraphKernelNodeSetAttribute.attr = (hipKernelNodeAttrID)attr; \ - cb_data.args.hipGraphKernelNodeSetAttribute.value = (const hipKernelNodeAttrValue*)value; \ }; // hipGraphKernelNodeSetParams[('hipGraphNode_t', 'node'), ('const hipKernelNodeParams*', 'pNodeParams')] #define INIT_hipGraphKernelNodeSetParams_CB_ARGS_DATA(cb_data) { \ @@ -4891,6 +5072,18 @@ typedef struct hip_api_data_s { cb_data.args.hipMemcpy2D.height = (size_t)height; \ cb_data.args.hipMemcpy2D.kind = (hipMemcpyKind)kind; \ }; +// hipMemcpy2DArrayToArray[('hipArray_t', 'dst'), ('size_t', 'wOffsetDst'), ('size_t', 'hOffsetDst'), ('hipArray_const_t', 'src'), ('size_t', 'wOffsetSrc'), ('size_t', 'hOffsetSrc'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy2DArrayToArray_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpy2DArrayToArray.dst = (hipArray_t)dst; \ + cb_data.args.hipMemcpy2DArrayToArray.wOffsetDst = (size_t)wOffsetDst; \ + cb_data.args.hipMemcpy2DArrayToArray.hOffsetDst = (size_t)hOffsetDst; \ + cb_data.args.hipMemcpy2DArrayToArray.src = (hipArray_const_t)src; \ + cb_data.args.hipMemcpy2DArrayToArray.wOffsetSrc = (size_t)wOffsetSrc; \ + cb_data.args.hipMemcpy2DArrayToArray.hOffsetSrc = (size_t)hOffsetSrc; \ + cb_data.args.hipMemcpy2DArrayToArray.width = (size_t)width; \ + cb_data.args.hipMemcpy2DArrayToArray.height = (size_t)height; \ + cb_data.args.hipMemcpy2DArrayToArray.kind = (hipMemcpyKind)kind; \ +}; // hipMemcpy2DAsync[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] #define INIT_hipMemcpy2DAsync_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipMemcpy2DAsync.dst = (void*)dst; \ @@ -4965,6 +5158,21 @@ typedef struct hip_api_data_s { cb_data.args.hipMemcpyAsync.kind = (hipMemcpyKind)kind; \ cb_data.args.hipMemcpyAsync.stream = (hipStream_t)stream; \ }; +// hipMemcpyAtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyAtoA_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpyAtoA.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyAtoA.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyAtoA.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoA.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoA.ByteCount = (size_t)ByteCount; \ +}; +// hipMemcpyAtoD[('hipDeviceptr_t', 'dstDevice'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyAtoD_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpyAtoD.dstDevice = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyAtoD.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoD.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoD.ByteCount = (size_t)ByteCount; \ +}; // hipMemcpyAtoH[('void*', 'dst'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'count')] #define INIT_hipMemcpyAtoH_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipMemcpyAtoH.dst = (void*)dstHost; \ @@ -4972,6 +5180,21 @@ typedef struct hip_api_data_s { cb_data.args.hipMemcpyAtoH.srcOffset = (size_t)srcOffset; \ cb_data.args.hipMemcpyAtoH.count = (size_t)ByteCount; \ }; +// hipMemcpyAtoHAsync[('void*', 'dstHost'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyAtoHAsync_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpyAtoHAsync.dstHost = (void*)dstHost; \ + cb_data.args.hipMemcpyAtoHAsync.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoHAsync.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoHAsync.ByteCount = (size_t)ByteCount; \ + cb_data.args.hipMemcpyAtoHAsync.stream = (hipStream_t)stream; \ +}; +// hipMemcpyDtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('hipDeviceptr_t', 'srcDevice'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyDtoA_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpyDtoA.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyDtoA.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyDtoA.srcDevice = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoA.ByteCount = (size_t)ByteCount; \ +}; // hipMemcpyDtoD[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), ('size_t', 'sizeBytes')] #define INIT_hipMemcpyDtoD_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipMemcpyDtoD.dst = (hipDeviceptr_t)dstDevice; \ @@ -5031,6 +5254,14 @@ typedef struct hip_api_data_s { cb_data.args.hipMemcpyHtoA.srcHost = (const void*)srcHost; \ cb_data.args.hipMemcpyHtoA.count = (size_t)ByteCount; \ }; +// hipMemcpyHtoAAsync[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('const void*', 'srcHost'), ('size_t', 'ByteCount'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyHtoAAsync_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipMemcpyHtoAAsync.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyHtoAAsync.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyHtoAAsync.srcHost = (const void*)srcHost; \ + cb_data.args.hipMemcpyHtoAAsync.ByteCount = (size_t)ByteCount; \ + cb_data.args.hipMemcpyHtoAAsync.stream = (hipStream_t)stream; \ +}; // hipMemcpyHtoD[('hipDeviceptr_t', 'dst'), ('void*', 'src'), ('size_t', 'sizeBytes')] #define INIT_hipMemcpyHtoD_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipMemcpyHtoD.dst = (hipDeviceptr_t)dstDevice; \ @@ -5369,6 +5600,11 @@ typedef struct hip_api_data_s { #define INIT_hipSetDeviceFlags_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipSetDeviceFlags.flags = (unsigned int)flags; \ }; +// hipSetValidDevices[('int*', 'device_arr'), ('int', 'len')] +#define INIT_hipSetValidDevices_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipSetValidDevices.device_arr = (int*)device_arr; \ + cb_data.args.hipSetValidDevices.len = (int)len; \ +}; // hipSetupArgument[('const void*', 'arg'), ('size_t', 'size'), ('size_t', 'offset')] #define INIT_hipSetupArgument_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipSetupArgument.arg = (const void*)arg; \ @@ -5401,6 +5637,15 @@ typedef struct hip_api_data_s { cb_data.args.hipStreamBeginCapture.stream = (hipStream_t)stream; \ cb_data.args.hipStreamBeginCapture.mode = (hipStreamCaptureMode)mode; \ }; +// hipStreamBeginCaptureToGraph[('hipStream_t', 'stream'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'dependencies'), ('const hipGraphEdgeData*', 'dependencyData'), ('size_t', 'numDependencies'), ('hipStreamCaptureMode', 'mode')] +#define INIT_hipStreamBeginCaptureToGraph_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipStreamBeginCaptureToGraph.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamBeginCaptureToGraph.graph = (hipGraph_t)graph; \ + cb_data.args.hipStreamBeginCaptureToGraph.dependencies = (const hipGraphNode_t*)dependencies; \ + cb_data.args.hipStreamBeginCaptureToGraph.dependencyData = (const hipGraphEdgeData*)dependencyData; \ + cb_data.args.hipStreamBeginCaptureToGraph.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipStreamBeginCaptureToGraph.mode = (hipStreamCaptureMode)mode; \ +}; // hipStreamCreate[('hipStream_t*', 'stream')] #define INIT_hipStreamCreate_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipStreamCreate.stream = (hipStream_t*)stream; \ @@ -5516,6 +5761,16 @@ typedef struct hip_api_data_s { cb_data.args.hipTexRefGetAddress.dev_ptr = (hipDeviceptr_t*)dptr; \ cb_data.args.hipTexRefGetAddress.texRef = (const textureReference*)texRef; \ }; +// hipTexRefGetArray[('hipArray_t*', 'pArray'), ('const textureReference*', 'texRef')] +#define INIT_hipTexRefGetArray_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipTexRefGetArray.pArray = (hipArray_t*)pArray; \ + cb_data.args.hipTexRefGetArray.texRef = (const textureReference*)texRef; \ +}; +// hipTexRefGetBorderColor[('float*', 'pBorderColor'), ('const textureReference*', 'texRef')] +#define INIT_hipTexRefGetBorderColor_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipTexRefGetBorderColor.pBorderColor = (float*)pBorderColor; \ + cb_data.args.hipTexRefGetBorderColor.texRef = (const textureReference*)texRef; \ +}; // hipTexRefGetFlags[('unsigned int*', 'pFlags'), ('const textureReference*', 'texRef')] #define INIT_hipTexRefGetFlags_CB_ARGS_DATA(cb_data) { \ cb_data.args.hipTexRefGetFlags.pFlags = (unsigned int*)pFlags; \ @@ -5534,6 +5789,8 @@ typedef struct hip_api_data_s { }; // hipTexRefGetMipMappedArray[('hipMipmappedArray_t*', 'pArray'), ('const textureReference*', 'texRef')] #define INIT_hipTexRefGetMipMappedArray_CB_ARGS_DATA(cb_data) { \ + cb_data.args.hipTexRefGetMipMappedArray.pArray = (hipMipmappedArray_t*)pArray; \ + cb_data.args.hipTexRefGetMipMappedArray.texRef = (const textureReference*)texRef; \ }; // hipTexRefGetMipmapLevelBias[('float*', 'pbias'), ('const textureReference*', 'texRef')] #define INIT_hipTexRefGetMipmapLevelBias_CB_ARGS_DATA(cb_data) { \ @@ -5633,9 +5890,6 @@ typedef struct hip_api_data_s { cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = (unsigned int)numExtSems; \ cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \ }; -// hipExtGetLastError[] -#define INIT_hipExtGetLastError_CB_ARGS_DATA(cb_data) { \ -}; #define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data) // Macros for non-public API primitives @@ -5663,20 +5917,6 @@ typedef struct hip_api_data_s { #define INIT_hipGetTextureObjectTextureDesc_CB_ARGS_DATA(cb_data) {}; // hipGetTextureReference() #define INIT_hipGetTextureReference_CB_ARGS_DATA(cb_data) {}; -// hipMemcpy2DArrayToArray() -#define INIT_hipMemcpy2DArrayToArray_CB_ARGS_DATA(cb_data) {}; -// hipMemcpyAtoA() -#define INIT_hipMemcpyAtoA_CB_ARGS_DATA(cb_data) {}; -// hipMemcpyAtoD() -#define INIT_hipMemcpyAtoD_CB_ARGS_DATA(cb_data) {}; -// hipMemcpyAtoHAsync() -#define INIT_hipMemcpyAtoHAsync_CB_ARGS_DATA(cb_data) {}; -// hipMemcpyDtoA() -#define INIT_hipMemcpyDtoA_CB_ARGS_DATA(cb_data) {}; -// hipMemcpyHtoAAsync() -#define INIT_hipMemcpyHtoAAsync_CB_ARGS_DATA(cb_data) {}; -// hipSetValidDevices() -#define INIT_hipSetValidDevices_CB_ARGS_DATA(cb_data) {}; // hipTexObjectCreate() #define INIT_hipTexObjectCreate_CB_ARGS_DATA(cb_data) {}; // hipTexObjectDestroy() @@ -5689,16 +5929,10 @@ typedef struct hip_api_data_s { #define INIT_hipTexObjectGetTextureDesc_CB_ARGS_DATA(cb_data) {}; // hipTexRefGetAddressMode() #define INIT_hipTexRefGetAddressMode_CB_ARGS_DATA(cb_data) {}; -// hipTexRefGetArray() -#define INIT_hipTexRefGetArray_CB_ARGS_DATA(cb_data) {}; -// hipTexRefGetBorderColor() -#define INIT_hipTexRefGetBorderColor_CB_ARGS_DATA(cb_data) {}; // hipTexRefGetFilterMode() #define INIT_hipTexRefGetFilterMode_CB_ARGS_DATA(cb_data) {}; // hipTexRefGetMipmapFilterMode() #define INIT_hipTexRefGetMipmapFilterMode_CB_ARGS_DATA(cb_data) {}; -// hipTexRefGetMipmappedArray() -#define INIT_hipTexRefGetMipmappedArray_CB_ARGS_DATA(cb_data) {}; // hipTexRefSetAddressMode() #define INIT_hipTexRefSetAddressMode_CB_ARGS_DATA(cb_data) {}; // hipTexRefSetFilterMode() @@ -5968,14 +6202,6 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { if (data->args.hipDrvGraphAddMemsetNode.dependencies) data->args.hipDrvGraphAddMemsetNode.dependencies__val = *(data->args.hipDrvGraphAddMemsetNode.dependencies); if (data->args.hipDrvGraphAddMemsetNode.memsetParams) data->args.hipDrvGraphAddMemsetNode.memsetParams__val = *(data->args.hipDrvGraphAddMemsetNode.memsetParams); break; -// hipDrvGraphMemcpyNodeGetParams[('hipGraphNode_t', 'hNode'), ('HIP_MEMCPY3D*', 'nodeParams')] - case HIP_API_ID_hipDrvGraphMemcpyNodeGetParams: - if (data->args.hipDrvGraphMemcpyNodeGetParams.nodeParams) data->args.hipDrvGraphMemcpyNodeGetParams.nodeParams__val = *(data->args.hipDrvGraphMemcpyNodeGetParams.nodeParams); - break; -// hipDrvGraphMemcpyNodeSetParams[('hipGraphNode_t', 'hNode'), ('const HIP_MEMCPY3D*', 'nodeParams')] - case HIP_API_ID_hipDrvGraphMemcpyNodeSetParams: - if (data->args.hipDrvGraphMemcpyNodeSetParams.nodeParams) data->args.hipDrvGraphMemcpyNodeSetParams.nodeParams__val = *(data->args.hipDrvGraphMemcpyNodeSetParams.nodeParams); - break; // hipDrvMemcpy2DUnaligned[('const hip_Memcpy2D*', 'pCopy')] case HIP_API_ID_hipDrvMemcpy2DUnaligned: if (data->args.hipDrvMemcpy2DUnaligned.pCopy) data->args.hipDrvMemcpy2DUnaligned.pCopy__val = *(data->args.hipDrvMemcpy2DUnaligned.pCopy); @@ -6017,6 +6243,9 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipEventSynchronize[('hipEvent_t', 'event')] case HIP_API_ID_hipEventSynchronize: break; +// hipExtGetLastError[] + case HIP_API_ID_hipExtGetLastError: + break; // hipExtGetLinkTypeAndHopCount[('int', 'device1'), ('int', 'device2'), ('unsigned int*', 'linktype'), ('unsigned int*', 'hopcount')] case HIP_API_ID_hipExtGetLinkTypeAndHopCount: if (data->args.hipExtGetLinkTypeAndHopCount.linktype) data->args.hipExtGetLinkTypeAndHopCount.linktype__val = *(data->args.hipExtGetLinkTypeAndHopCount.linktype); @@ -6122,16 +6351,23 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipGetErrorString[] case HIP_API_ID_hipGetErrorString: break; +// hipGetFuncBySymbol[('hipFunction_t*', 'functionPtr'), ('const void*', 'symbolPtr')] + case HIP_API_ID_hipGetFuncBySymbol: + if (data->args.hipGetFuncBySymbol.functionPtr) data->args.hipGetFuncBySymbol.functionPtr__val = *(data->args.hipGetFuncBySymbol.functionPtr); + break; // hipGetLastError[] case HIP_API_ID_hipGetLastError: break; -// hipExtGetLastError[] - case HIP_API_ID_hipExtGetLastError: - break; // hipGetMipmappedArrayLevel[('hipArray_t*', 'levelArray'), ('hipMipmappedArray_const_t', 'mipmappedArray'), ('unsigned int', 'level')] case HIP_API_ID_hipGetMipmappedArrayLevel: if (data->args.hipGetMipmappedArrayLevel.levelArray) data->args.hipGetMipmappedArrayLevel.levelArray__val = *(data->args.hipGetMipmappedArrayLevel.levelArray); break; +// hipGetProcAddress[('const char*', 'symbol'), ('void**', 'pfn'), ('int', 'hipVersion'), ('uint64_t', 'flags'), ('hipDriverProcAddressQueryResult*', 'symbolStatus')] + case HIP_API_ID_hipGetProcAddress: + if (data->args.hipGetProcAddress.symbol) data->args.hipGetProcAddress.symbol__val = *(data->args.hipGetProcAddress.symbol); + if (data->args.hipGetProcAddress.pfn) data->args.hipGetProcAddress.pfn__val = *(data->args.hipGetProcAddress.pfn); + if (data->args.hipGetProcAddress.symbolStatus) data->args.hipGetProcAddress.symbolStatus__val = *(data->args.hipGetProcAddress.symbolStatus); + break; // hipGetSymbolAddress[('void**', 'devPtr'), ('const void*', 'symbol')] case HIP_API_ID_hipGetSymbolAddress: if (data->args.hipGetSymbolAddress.devPtr) data->args.hipGetSymbolAddress.devPtr__val = *(data->args.hipGetSymbolAddress.devPtr); @@ -6227,6 +6463,12 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { if (data->args.hipGraphAddMemsetNode.pDependencies) data->args.hipGraphAddMemsetNode.pDependencies__val = *(data->args.hipGraphAddMemsetNode.pDependencies); if (data->args.hipGraphAddMemsetNode.pMemsetParams) data->args.hipGraphAddMemsetNode.pMemsetParams__val = *(data->args.hipGraphAddMemsetNode.pMemsetParams); break; +// hipGraphAddNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), ('hipGraphNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphAddNode: + if (data->args.hipGraphAddNode.pGraphNode) data->args.hipGraphAddNode.pGraphNode__val = *(data->args.hipGraphAddNode.pGraphNode); + if (data->args.hipGraphAddNode.pDependencies) data->args.hipGraphAddNode.pDependencies__val = *(data->args.hipGraphAddNode.pDependencies); + if (data->args.hipGraphAddNode.nodeParams) data->args.hipGraphAddNode.nodeParams__val = *(data->args.hipGraphAddNode.nodeParams); + break; // hipGraphChildGraphNodeGetGraph[('hipGraphNode_t', 'node'), ('hipGraph_t*', 'pGraph')] case HIP_API_ID_hipGraphChildGraphNodeGetGraph: if (data->args.hipGraphChildGraphNodeGetGraph.pGraph) data->args.hipGraphChildGraphNodeGetGraph.pGraph__val = *(data->args.hipGraphChildGraphNodeGetGraph.pGraph); @@ -6363,10 +6605,15 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { case HIP_API_ID_hipGraphInstantiateWithFlags: if (data->args.hipGraphInstantiateWithFlags.pGraphExec) data->args.hipGraphInstantiateWithFlags.pGraphExec__val = *(data->args.hipGraphInstantiateWithFlags.pGraphExec); break; +// hipGraphInstantiateWithParams[('hipGraphExec_t*', 'pGraphExec'), ('hipGraph_t', 'graph'), ('hipGraphInstantiateParams*', 'instantiateParams')] + case HIP_API_ID_hipGraphInstantiateWithParams: + if (data->args.hipGraphInstantiateWithParams.pGraphExec) data->args.hipGraphInstantiateWithParams.pGraphExec__val = *(data->args.hipGraphInstantiateWithParams.pGraphExec); + if (data->args.hipGraphInstantiateWithParams.instantiateParams) data->args.hipGraphInstantiateWithParams.instantiateParams__val = *(data->args.hipGraphInstantiateWithParams.instantiateParams); + break; // hipGraphKernelNodeCopyAttributes[('hipGraphNode_t', 'hSrc'), ('hipGraphNode_t', 'hDst')] case HIP_API_ID_hipGraphKernelNodeCopyAttributes: break; -// hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), ('hipKernelNodeAttrID', 'attr'), ('hipKernelNodeAttrValue*', 'value')] +// hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), ('hipLaunchAttributeID', 'attr'), ('hipLaunchAttributeValue*', 'value')] case HIP_API_ID_hipGraphKernelNodeGetAttribute: if (data->args.hipGraphKernelNodeGetAttribute.value) data->args.hipGraphKernelNodeGetAttribute.value__val = *(data->args.hipGraphKernelNodeGetAttribute.value); break; @@ -6374,7 +6621,7 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { case HIP_API_ID_hipGraphKernelNodeGetParams: if (data->args.hipGraphKernelNodeGetParams.pNodeParams) data->args.hipGraphKernelNodeGetParams.pNodeParams__val = *(data->args.hipGraphKernelNodeGetParams.pNodeParams); break; -// hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), ('hipKernelNodeAttrID', 'attr'), ('const hipKernelNodeAttrValue*', 'value')] +// hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), ('hipLaunchAttributeID', 'attr'), ('const hipLaunchAttributeValue*', 'value')] case HIP_API_ID_hipGraphKernelNodeSetAttribute: if (data->args.hipGraphKernelNodeSetAttribute.value) data->args.hipGraphKernelNodeSetAttribute.value__val = *(data->args.hipGraphKernelNodeSetAttribute.value); break; @@ -6748,6 +6995,9 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipMemcpy2D[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] case HIP_API_ID_hipMemcpy2D: break; +// hipMemcpy2DArrayToArray[('hipArray_t', 'dst'), ('size_t', 'wOffsetDst'), ('size_t', 'hOffsetDst'), ('hipArray_const_t', 'src'), ('size_t', 'wOffsetSrc'), ('size_t', 'hOffsetSrc'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpy2DArrayToArray: + break; // hipMemcpy2DAsync[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] case HIP_API_ID_hipMemcpy2DAsync: break; @@ -6774,9 +7024,21 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipMemcpyAsync[('void*', 'dst'), ('const void*', 'src'), ('size_t', 'sizeBytes'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] case HIP_API_ID_hipMemcpyAsync: break; +// hipMemcpyAtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] + case HIP_API_ID_hipMemcpyAtoA: + break; +// hipMemcpyAtoD[('hipDeviceptr_t', 'dstDevice'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] + case HIP_API_ID_hipMemcpyAtoD: + break; // hipMemcpyAtoH[('void*', 'dst'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'count')] case HIP_API_ID_hipMemcpyAtoH: break; +// hipMemcpyAtoHAsync[('void*', 'dstHost'), ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyAtoHAsync: + break; +// hipMemcpyDtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('hipDeviceptr_t', 'srcDevice'), ('size_t', 'ByteCount')] + case HIP_API_ID_hipMemcpyDtoA: + break; // hipMemcpyDtoD[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), ('size_t', 'sizeBytes')] case HIP_API_ID_hipMemcpyDtoD: break; @@ -6801,6 +7063,9 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipMemcpyHtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('const void*', 'srcHost'), ('size_t', 'count')] case HIP_API_ID_hipMemcpyHtoA: break; +// hipMemcpyHtoAAsync[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('const void*', 'srcHost'), ('size_t', 'ByteCount'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyHtoAAsync: + break; // hipMemcpyHtoD[('hipDeviceptr_t', 'dst'), ('void*', 'src'), ('size_t', 'sizeBytes')] case HIP_API_ID_hipMemcpyHtoD: break; @@ -6988,6 +7253,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipSetDeviceFlags[('unsigned int', 'flags')] case HIP_API_ID_hipSetDeviceFlags: break; +// hipSetValidDevices[('int*', 'device_arr'), ('int', 'len')] + case HIP_API_ID_hipSetValidDevices: + if (data->args.hipSetValidDevices.device_arr) data->args.hipSetValidDevices.device_arr__val = *(data->args.hipSetValidDevices.device_arr); + break; // hipSetupArgument[('const void*', 'arg'), ('size_t', 'size'), ('size_t', 'offset')] case HIP_API_ID_hipSetupArgument: break; @@ -7005,6 +7274,11 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { // hipStreamBeginCapture[('hipStream_t', 'stream'), ('hipStreamCaptureMode', 'mode')] case HIP_API_ID_hipStreamBeginCapture: break; +// hipStreamBeginCaptureToGraph[('hipStream_t', 'stream'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'dependencies'), ('const hipGraphEdgeData*', 'dependencyData'), ('size_t', 'numDependencies'), ('hipStreamCaptureMode', 'mode')] + case HIP_API_ID_hipStreamBeginCaptureToGraph: + if (data->args.hipStreamBeginCaptureToGraph.dependencies) data->args.hipStreamBeginCaptureToGraph.dependencies__val = *(data->args.hipStreamBeginCaptureToGraph.dependencies); + if (data->args.hipStreamBeginCaptureToGraph.dependencyData) data->args.hipStreamBeginCaptureToGraph.dependencyData__val = *(data->args.hipStreamBeginCaptureToGraph.dependencyData); + break; // hipStreamCreate[('hipStream_t*', 'stream')] case HIP_API_ID_hipStreamCreate: if (data->args.hipStreamCreate.stream) data->args.hipStreamCreate.stream__val = *(data->args.hipStreamCreate.stream); @@ -7083,6 +7357,16 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) { if (data->args.hipTexRefGetAddress.dev_ptr) data->args.hipTexRefGetAddress.dev_ptr__val = *(data->args.hipTexRefGetAddress.dev_ptr); if (data->args.hipTexRefGetAddress.texRef) data->args.hipTexRefGetAddress.texRef__val = *(data->args.hipTexRefGetAddress.texRef); break; +// hipTexRefGetArray[('hipArray_t*', 'pArray'), ('const textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetArray: + if (data->args.hipTexRefGetArray.pArray) data->args.hipTexRefGetArray.pArray__val = *(data->args.hipTexRefGetArray.pArray); + if (data->args.hipTexRefGetArray.texRef) data->args.hipTexRefGetArray.texRef__val = *(data->args.hipTexRefGetArray.texRef); + break; +// hipTexRefGetBorderColor[('float*', 'pBorderColor'), ('const textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetBorderColor: + if (data->args.hipTexRefGetBorderColor.pBorderColor) data->args.hipTexRefGetBorderColor.pBorderColor__val = *(data->args.hipTexRefGetBorderColor.pBorderColor); + if (data->args.hipTexRefGetBorderColor.texRef) data->args.hipTexRefGetBorderColor.texRef__val = *(data->args.hipTexRefGetBorderColor.texRef); + break; // hipTexRefGetFlags[('unsigned int*', 'pFlags'), ('const textureReference*', 'texRef')] case HIP_API_ID_hipTexRefGetFlags: if (data->args.hipTexRefGetFlags.pFlags) data->args.hipTexRefGetFlags.pFlags__val = *(data->args.hipTexRefGetFlags.pFlags); @@ -7636,20 +7920,6 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", ctx="; roctracer::hip_support::detail::operator<<(oss, data->args.hipDrvGraphAddMemsetNode.ctx); oss << ")"; break; - case HIP_API_ID_hipDrvGraphMemcpyNodeGetParams: - oss << "hipDrvGraphMemcpyNodeGetParams("; - oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipDrvGraphMemcpyNodeGetParams.hNode); - if (data->args.hipDrvGraphMemcpyNodeGetParams.nodeParams == NULL) oss << ", nodeParams=NULL"; - else { oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipDrvGraphMemcpyNodeGetParams.nodeParams__val); } - oss << ")"; - break; - case HIP_API_ID_hipDrvGraphMemcpyNodeSetParams: - oss << "hipDrvGraphMemcpyNodeSetParams("; - oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipDrvGraphMemcpyNodeSetParams.hNode); - if (data->args.hipDrvGraphMemcpyNodeSetParams.nodeParams == NULL) oss << ", nodeParams=NULL"; - else { oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipDrvGraphMemcpyNodeSetParams.nodeParams__val); } - oss << ")"; - break; case HIP_API_ID_hipDrvMemcpy2DUnaligned: oss << "hipDrvMemcpy2DUnaligned("; if (data->args.hipDrvMemcpy2DUnaligned.pCopy == NULL) oss << "pCopy=NULL"; @@ -7721,6 +7991,10 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << "event="; roctracer::hip_support::detail::operator<<(oss, data->args.hipEventSynchronize.event); oss << ")"; break; + case HIP_API_ID_hipExtGetLastError: + oss << "hipExtGetLastError("; + oss << ")"; + break; case HIP_API_ID_hipExtGetLinkTypeAndHopCount: oss << "hipExtGetLinkTypeAndHopCount("; oss << "device1="; roctracer::hip_support::detail::operator<<(oss, data->args.hipExtGetLinkTypeAndHopCount.device1); @@ -7929,12 +8203,15 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << "hipGetErrorString("; oss << ")"; break; - case HIP_API_ID_hipGetLastError: - oss << "hipGetLastError("; + case HIP_API_ID_hipGetFuncBySymbol: + oss << "hipGetFuncBySymbol("; + if (data->args.hipGetFuncBySymbol.functionPtr == NULL) oss << "functionPtr=NULL"; + else { oss << "functionPtr="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetFuncBySymbol.functionPtr__val); } + oss << ", symbolPtr="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetFuncBySymbol.symbolPtr); oss << ")"; break; - case HIP_API_ID_hipExtGetLastError: - oss << "hipExtGetLastError("; + case HIP_API_ID_hipGetLastError: + oss << "hipGetLastError("; oss << ")"; break; case HIP_API_ID_hipGetMipmappedArrayLevel: @@ -7945,6 +8222,18 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", level="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetMipmappedArrayLevel.level); oss << ")"; break; + case HIP_API_ID_hipGetProcAddress: + oss << "hipGetProcAddress("; + if (data->args.hipGetProcAddress.symbol == NULL) oss << "symbol=NULL"; + else { oss << "symbol="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetProcAddress.symbol__val); } + if (data->args.hipGetProcAddress.pfn == NULL) oss << ", pfn=NULL"; + else { oss << ", pfn="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetProcAddress.pfn__val); } + oss << ", hipVersion="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetProcAddress.hipVersion); + oss << ", flags="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetProcAddress.flags); + if (data->args.hipGetProcAddress.symbolStatus == NULL) oss << ", symbolStatus=NULL"; + else { oss << ", symbolStatus="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGetProcAddress.symbolStatus__val); } + oss << ")"; + break; case HIP_API_ID_hipGetSymbolAddress: oss << "hipGetSymbolAddress("; if (data->args.hipGetSymbolAddress.devPtr == NULL) oss << "devPtr=NULL"; @@ -8151,6 +8440,18 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da else { oss << ", pMemsetParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddMemsetNode.pMemsetParams__val); } oss << ")"; break; + case HIP_API_ID_hipGraphAddNode: + oss << "hipGraphAddNode("; + if (data->args.hipGraphAddNode.pGraphNode == NULL) oss << "pGraphNode=NULL"; + else { oss << "pGraphNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddNode.pGraphNode__val); } + oss << ", graph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddNode.graph); + if (data->args.hipGraphAddNode.pDependencies == NULL) oss << ", pDependencies=NULL"; + else { oss << ", pDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddNode.pDependencies__val); } + oss << ", numDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddNode.numDependencies); + if (data->args.hipGraphAddNode.nodeParams == NULL) oss << ", nodeParams=NULL"; + else { oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddNode.nodeParams__val); } + oss << ")"; + break; case HIP_API_ID_hipGraphChildGraphNodeGetGraph: oss << "hipGraphChildGraphNodeGetGraph("; oss << "node="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphChildGraphNodeGetGraph.node); @@ -8423,6 +8724,15 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", flags="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphInstantiateWithFlags.flags); oss << ")"; break; + case HIP_API_ID_hipGraphInstantiateWithParams: + oss << "hipGraphInstantiateWithParams("; + if (data->args.hipGraphInstantiateWithParams.pGraphExec == NULL) oss << "pGraphExec=NULL"; + else { oss << "pGraphExec="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphInstantiateWithParams.pGraphExec__val); } + oss << ", graph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphInstantiateWithParams.graph); + if (data->args.hipGraphInstantiateWithParams.instantiateParams == NULL) oss << ", instantiateParams=NULL"; + else { oss << ", instantiateParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphInstantiateWithParams.instantiateParams__val); } + oss << ")"; + break; case HIP_API_ID_hipGraphKernelNodeCopyAttributes: oss << "hipGraphKernelNodeCopyAttributes("; oss << "hSrc="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphKernelNodeCopyAttributes.hSrc); @@ -9215,6 +9525,19 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", kind="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2D.kind); oss << ")"; break; + case HIP_API_ID_hipMemcpy2DArrayToArray: + oss << "hipMemcpy2DArrayToArray("; + oss << "dst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.dst); + oss << ", wOffsetDst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.wOffsetDst); + oss << ", hOffsetDst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.hOffsetDst); + oss << ", src="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.src); + oss << ", wOffsetSrc="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.wOffsetSrc); + oss << ", hOffsetSrc="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.hOffsetSrc); + oss << ", width="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.width); + oss << ", height="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.height); + oss << ", kind="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DArrayToArray.kind); + oss << ")"; + break; case HIP_API_ID_hipMemcpy2DAsync: oss << "hipMemcpy2DAsync("; oss << "dst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2DAsync.dst); @@ -9299,6 +9622,23 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAsync.stream); oss << ")"; break; + case HIP_API_ID_hipMemcpyAtoA: + oss << "hipMemcpyAtoA("; + oss << "dstArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoA.dstArray); + oss << ", dstOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoA.dstOffset); + oss << ", srcArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoA.srcArray); + oss << ", srcOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoA.srcOffset); + oss << ", ByteCount="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoA.ByteCount); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAtoD: + oss << "hipMemcpyAtoD("; + oss << "dstDevice="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoD.dstDevice); + oss << ", srcArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoD.srcArray); + oss << ", srcOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoD.srcOffset); + oss << ", ByteCount="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoD.ByteCount); + oss << ")"; + break; case HIP_API_ID_hipMemcpyAtoH: oss << "hipMemcpyAtoH("; oss << "dst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoH.dst); @@ -9307,6 +9647,23 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", count="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoH.count); oss << ")"; break; + case HIP_API_ID_hipMemcpyAtoHAsync: + oss << "hipMemcpyAtoHAsync("; + oss << "dstHost="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoHAsync.dstHost); + oss << ", srcArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoHAsync.srcArray); + oss << ", srcOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoHAsync.srcOffset); + oss << ", ByteCount="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoHAsync.ByteCount); + oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyAtoHAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoA: + oss << "hipMemcpyDtoA("; + oss << "dstArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyDtoA.dstArray); + oss << ", dstOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyDtoA.dstOffset); + oss << ", srcDevice="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyDtoA.srcDevice); + oss << ", ByteCount="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyDtoA.ByteCount); + oss << ")"; + break; case HIP_API_ID_hipMemcpyDtoD: oss << "hipMemcpyDtoD("; oss << "dst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyDtoD.dst); @@ -9374,6 +9731,15 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", count="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoA.count); oss << ")"; break; + case HIP_API_ID_hipMemcpyHtoAAsync: + oss << "hipMemcpyHtoAAsync("; + oss << "dstArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoAAsync.dstArray); + oss << ", dstOffset="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoAAsync.dstOffset); + oss << ", srcHost="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoAAsync.srcHost); + oss << ", ByteCount="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoAAsync.ByteCount); + oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoAAsync.stream); + oss << ")"; + break; case HIP_API_ID_hipMemcpyHtoD: oss << "hipMemcpyHtoD("; oss << "dst="; roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpyHtoD.dst); @@ -9797,6 +10163,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << "flags="; roctracer::hip_support::detail::operator<<(oss, data->args.hipSetDeviceFlags.flags); oss << ")"; break; + case HIP_API_ID_hipSetValidDevices: + oss << "hipSetValidDevices("; + if (data->args.hipSetValidDevices.device_arr == NULL) oss << "device_arr=NULL"; + else { oss << "device_arr="; roctracer::hip_support::detail::operator<<(oss, data->args.hipSetValidDevices.device_arr__val); } + oss << ", len="; roctracer::hip_support::detail::operator<<(oss, data->args.hipSetValidDevices.len); + oss << ")"; + break; case HIP_API_ID_hipSetupArgument: oss << "hipSetupArgument("; oss << "arg="; roctracer::hip_support::detail::operator<<(oss, data->args.hipSetupArgument.arg); @@ -9836,6 +10209,18 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da oss << ", mode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCapture.mode); oss << ")"; break; + case HIP_API_ID_hipStreamBeginCaptureToGraph: + oss << "hipStreamBeginCaptureToGraph("; + oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.stream); + oss << ", graph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.graph); + if (data->args.hipStreamBeginCaptureToGraph.dependencies == NULL) oss << ", dependencies=NULL"; + else { oss << ", dependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.dependencies__val); } + if (data->args.hipStreamBeginCaptureToGraph.dependencyData == NULL) oss << ", dependencyData=NULL"; + else { oss << ", dependencyData="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.dependencyData__val); } + oss << ", numDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.numDependencies); + oss << ", mode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamBeginCaptureToGraph.mode); + oss << ")"; + break; case HIP_API_ID_hipStreamCreate: oss << "hipStreamCreate("; if (data->args.hipStreamCreate.stream == NULL) oss << "stream=NULL"; @@ -9989,6 +10374,22 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da else { oss << ", texRef="; roctracer::hip_support::detail::operator<<(oss, data->args.hipTexRefGetAddress.texRef__val); } oss << ")"; break; + case HIP_API_ID_hipTexRefGetArray: + oss << "hipTexRefGetArray("; + if (data->args.hipTexRefGetArray.pArray == NULL) oss << "pArray=NULL"; + else { oss << "pArray="; roctracer::hip_support::detail::operator<<(oss, data->args.hipTexRefGetArray.pArray__val); } + if (data->args.hipTexRefGetArray.texRef == NULL) oss << ", texRef=NULL"; + else { oss << ", texRef="; roctracer::hip_support::detail::operator<<(oss, data->args.hipTexRefGetArray.texRef__val); } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetBorderColor: + oss << "hipTexRefGetBorderColor("; + if (data->args.hipTexRefGetBorderColor.pBorderColor == NULL) oss << "pBorderColor=NULL"; + else { oss << "pBorderColor="; roctracer::hip_support::detail::operator<<(oss, data->args.hipTexRefGetBorderColor.pBorderColor__val); } + if (data->args.hipTexRefGetBorderColor.texRef == NULL) oss << ", texRef=NULL"; + else { oss << ", texRef="; roctracer::hip_support::detail::operator<<(oss, data->args.hipTexRefGetBorderColor.texRef__val); } + oss << ")"; + break; case HIP_API_ID_hipTexRefGetFlags: oss << "hipTexRefGetFlags("; if (data->args.hipTexRefGetFlags.pFlags == NULL) oss << "pFlags=NULL"; diff --git a/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h b/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h index d201ab517c9b..307e75c21e76 100644 --- a/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h +++ b/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h @@ -34,6 +34,7 @@ enum HipVdiOpId { // Types of ROCclr commands enum HipVdiCommandKind { kHipVdiCommandKernel = 0x11F0, + kHipVdiCommandTask = 0x11F1, kHipVdiMemcpyDeviceToHost = 0x11F3, kHipHipVdiMemcpyHostToDevice = 0x11F4, kHipVdiMemcpyDeviceToDevice = 0x11F5, @@ -41,7 +42,7 @@ enum HipVdiCommandKind { kHipVdiMemcpyHostToDeviceRect = 0x1202, kHipVdiMemcpyDeviceToDeviceRect = 0x1203, kHipVdiFillMemory = 0x1207, -}; +}; /** * @brief Initializes activity callback diff --git a/third_party/amd/backend/include/hip/amd_detail/host_defines.h b/third_party/amd/backend/include/hip/amd_detail/host_defines.h index 0fad2b47042b..e7e8364969f7 100644 --- a/third_party/amd/backend/include/hip/amd_detail/host_defines.h +++ b/third_party/amd/backend/include/hip/amd_detail/host_defines.h @@ -127,6 +127,10 @@ template struct is_trivial : public integral_constant { }; + + +template struct conditional { using type = T; }; +template struct conditional { using type = F; }; } typedef __hip_internal::uint8_t __hip_uint8_t; typedef __hip_internal::uint16_t __hip_uint16_t; diff --git a/third_party/amd/backend/include/hip/hip_ext.h b/third_party/amd/backend/include/hip/hip_ext.h index 5d5d9b6fa26b..319f5694d021 100644 --- a/third_party/amd/backend/include/hip/hip_ext.h +++ b/third_party/amd/backend/include/hip/hip_ext.h @@ -64,6 +64,8 @@ THE SOFTWARE. * Currently, timing between startEvent and stopEvent does not include the time it takes to perform * a system scope release/cache flush - only the time it takes to issues writes to cache. * + * @note For this HIP API, the flag 'hipExtAnyOrderLaunch' is not supported on AMD GFX9xx boards. + * */ HIP_PUBLIC_API extern "C" hipError_t hipExtModuleLaunchKernel(hipFunction_t f, uint32_t globalWorkSizeX, @@ -78,6 +80,7 @@ HIP_PUBLIC_API * @brief This HIP API is deprecated, please use hipExtModuleLaunchKernel() instead. * */ +DEPRECATED("use hipExtModuleLaunchKernel instead") HIP_PUBLIC_API extern "C" hipError_t hipHccModuleLaunchKernel(hipFunction_t f, uint32_t globalWorkSizeX, uint32_t globalWorkSizeY, uint32_t globalWorkSizeZ, @@ -85,8 +88,7 @@ extern "C" hipError_t hipHccModuleLaunchKernel(hipFunction_t f, uint32_t globalW uint32_t localWorkSizeZ, size_t sharedMemBytes, hipStream_t hStream, void** kernelParams, void** extra, hipEvent_t startEvent __dparm(NULL), - hipEvent_t stopEvent __dparm(NULL)) - __attribute__((deprecated("use hipExtModuleLaunchKernel instead"))); + hipEvent_t stopEvent __dparm(NULL)); #if defined(__cplusplus) diff --git a/third_party/amd/backend/include/hip/hip_fp8.h b/third_party/amd/backend/include/hip/hip_fp8.h new file mode 100644 index 000000000000..82f47afcba08 --- /dev/null +++ b/third_party/amd/backend/include/hip/hip_fp8.h @@ -0,0 +1,33 @@ +/* +Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_FP8_H +#define HIP_INCLUDE_HIP_HIP_FP8_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +// We only have fnuz defs for now, which are not supported by other platforms +#include +#endif + +#endif // HIP_INCLUDE_HIP_HIP_FP8_H diff --git a/third_party/amd/backend/include/hip/hip_runtime_api.h b/third_party/amd/backend/include/hip/hip_runtime_api.h index 498173bbb158..0323d77d5117 100644 --- a/third_party/amd/backend/include/hip/hip_runtime_api.h +++ b/third_party/amd/backend/include/hip/hip_runtime_api.h @@ -102,7 +102,7 @@ typedef struct hipDeviceProp_t { char luid[8]; ///< 8-byte unique identifier. Only valid on windows unsigned int luidDeviceNodeMask; ///< LUID node mask size_t totalGlobalMem; ///< Size of global memory region (in bytes). - size_t sharedMemPerBlock; ///< Size of shared memory region (in bytes). + size_t sharedMemPerBlock; ///< Size of shared memory per block (in bytes). int regsPerBlock; ///< Registers per block. int warpSize; ///< Warp size. size_t memPitch; ///< Maximum pitch in bytes allowed by memory copies @@ -111,7 +111,8 @@ typedef struct hipDeviceProp_t { int maxThreadsDim[3]; ///< Max number of threads in each dimension (XYZ) of a block. int maxGridSize[3]; ///< Max grid dimensions (XYZ). int clockRate; ///< Max clock frequency of the multiProcessors in khz. - size_t totalConstMem; ///< Size of shared memory region (in bytes). + size_t totalConstMem; ///< Size of shared constant memory region on the device + ///< (in bytes). int major; ///< Major compute capability. On HCC, this is an approximation and features may ///< differ from CUDA CC. See the arch feature flags for portable ways to query ///< feature caps. @@ -538,6 +539,12 @@ typedef enum hipDeviceAttribute_t { // Extended attributes for vendors } hipDeviceAttribute_t; +typedef enum hipDriverProcAddressQueryResult { + HIP_GET_PROC_ADDRESS_SUCCESS = 0, + HIP_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1, + HIP_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2 +} hipDriverProcAddressQueryResult; + enum hipComputeMode { hipComputeModeDefault = 0, hipComputeModeExclusive = 1, @@ -740,6 +747,9 @@ enum hipLimit_t { /** Memory allocated will be uncached. */ #define hipDeviceMallocUncached 0x3 +/** Memory allocated will be contiguous. */ +#define hipDeviceMallocContiguous 0x4 + //Flags that can be used with hipHostRegister. /** Memory is Mapped and Portable.*/ #define hipHostRegisterDefault 0x0 @@ -798,6 +808,8 @@ enum hipLimit_t { /** Implicit stream per application thread.*/ #define hipStreamPerThread ((hipStream_t)2) +#define hipStreamLegacy ((hipStream_t)1) + // Indicates that the external memory object is a dedicated resource #define hipExternalMemoryDedicated 0x1 /** @@ -973,7 +985,8 @@ typedef struct hipMemPoolProps { * Windows-specific LPSECURITYATTRIBUTES required when @p hipMemHandleTypeWin32 is specified */ void* win32SecurityAttributes; - unsigned char reserved[64]; ///< Reserved for future use, must be 0 + size_t maxSize; ///< Maximum pool size. When set to 0, defaults to a system dependent value + unsigned char reserved[56]; ///< Reserved for future use, must be 0 } hipMemPoolProps; /** * Opaque data structure for exporting a pool allocation @@ -1269,13 +1282,7 @@ typedef struct hipMemAllocNodeParams { void* dptr; ///< Returned device address of the allocation } hipMemAllocNodeParams; -/** - * Kernel node attributeID - */ -typedef enum hipKernelNodeAttrID { - hipKernelNodeAttributeAccessPolicyWindow = 1, - hipKernelNodeAttributeCooperative = 2, -} hipKernelNodeAttrID; + typedef enum hipAccessProperty { hipAccessPropertyNormal = 0, hipAccessPropertyStreaming = 1, @@ -1288,10 +1295,39 @@ typedef struct hipAccessPolicyWindow { hipAccessProperty missProp; size_t num_bytes; } hipAccessPolicyWindow; -typedef union hipKernelNodeAttrValue { - hipAccessPolicyWindow accessPolicyWindow; - int cooperative; -} hipKernelNodeAttrValue; + +/** + * Launch Attribute ID + */ +typedef enum hipLaunchAttributeID { + hipLaunchAttributeAccessPolicyWindow = 1, /**< Valid for Streams, graph nodes, launches*/ + hipLaunchAttributeCooperative = 2, /**< Valid for graph nodes, launches */ + hipLaunchAttributePriority = 8, /**< Valid for graph node, streams, launches */ +} hipLaunchAttributeID; + +/** + * Launch Attribute Value + */ +typedef union hipLaunchAttributeValue { + hipAccessPolicyWindow accessPolicyWindow; /**< Value of launch attribute:: + hipLaunchAttributePolicyWindow. */ + int cooperative; /**< Value of launch attribute ::hipLaunchAttributeCooperative */ + int priority; /**< Value of launch attribute :: hipLaunchAttributePriority. Execution + priority of kernel. */ +} hipLaunchAttributeValue; + +/** + * Kernel node attributeID + */ +#define hipKernelNodeAttrID hipLaunchAttributeID +#define hipKernelNodeAttributeAccessPolicyWindow hipLaunchAttributeAccessPolicyWindow +#define hipKernelNodeAttributeCooperative hipLaunchAttributeCooperative +#define hipKernelNodeAttributePriority hipLaunchAttributePriority + +/** + * Kernel node attribute value + */ +#define hipKernelNodeAttrValue hipLaunchAttributeValue /** * Memset node params @@ -1383,6 +1419,34 @@ enum hipGraphDebugDotFlags { hipGraphDebugDotFlagsHandles = 1 << 10 /**< Adds node handles and every kernel function handle to output */ }; + +/** +* hipGraphInstantiateWithParams results +*/ +typedef enum hipGraphInstantiateResult { + hipGraphInstantiateSuccess = 0, /**< Instantiation Success */ + hipGraphInstantiateError = 1, /**< Instantiation failed for an + unexpected reason which is described in the return value of the function */ + hipGraphInstantiateInvalidStructure = 2, /**< Instantiation failed due + to invalid structure, such as cycles */ + hipGraphInstantiateNodeOperationNotSupported = 3, /**< Instantiation for device launch failed + because the graph contained an unsupported operation */ + hipGraphInstantiateMultipleDevicesNotSupported = 4, /**< Instantiation for device launch failed + due to the nodes belonging to different contexts */ +}hipGraphInstantiateResult; + +/** + * Graph Instantiation parameters +*/ +typedef struct hipGraphInstantiateParams { + hipGraphNode_t errNode_out; /**< The node which caused instantiation to fail, if any*/ + unsigned long long flags; /**< Instantiation flags */ + hipGraphInstantiateResult result_out; /**< Whether instantiation was successful. + If it failed, the reason why */ + hipStream_t uploadStream; /**< Upload stream */ +} hipGraphInstantiateParams; + + /** * Memory allocation properties */ @@ -1557,6 +1621,44 @@ typedef struct hipGraphNodeParams { long long reserved2; } hipGraphNodeParams; + +/** + * This port activates when the kernel has finished executing. + */ +#define hipGraphKernelNodePortDefault 0 + +/** + * This port activates when all blocks of the kernel have begun execution. + */ +#define hipGraphKernelNodePortLaunchCompletion 2 + +/** + * This port activates when all blocks of the kernel have performed + * hipTriggerProgrammaticLaunchCompletion() or have terminated. + * It must be used with edge type hipGraphDependencyTypeProgrammatic. + */ +#define hipGraphKernelNodePortProgrammatic 1 + +typedef enum hipGraphDependencyType { + hipGraphDependencyTypeDefault = 0, + hipGraphDependencyTypeProgrammatic = 1 +}hipGraphDependencyType; + +typedef struct hipGraphEdgeData { + unsigned char + from_port; ///< This indicates when the dependency is triggered from the upstream node on the + ///< edge. The meaning is specfic to the node type. A value of 0 in all cases + ///< means full completion of the upstream node, with memory visibility to the + ///< downstream node or portion thereof (indicated by to_port). Only kernel nodes + ///< define non-zero ports. A kernel node can use the following output port types: + ///< hipGraphKernelNodePortDefault, hipGraphKernelNodePortProgrammatic, or + ///< hipGraphKernelNodePortLaunchCompletion. + unsigned char reserved[5]; ///< These bytes are unused and must be zeroed + unsigned char + to_port; ///< Currently no node types define non-zero ports. This field must be set to zero. + unsigned char type; ///< This should be populated with a value from hipGraphDependencyType +} hipGraphEdgeData; + // Doxygen end group GlobalDefs /** * @} @@ -1585,6 +1687,7 @@ typedef struct hipGraphNodeParams { */ // TODO-ctx - more description on error codes. hipError_t hipInit(unsigned int flags); + /** * @brief Returns the approximate HIP driver version. * @@ -1755,6 +1858,18 @@ hipError_t hipDeviceReset(void); * @see #hipGetDevice, #hipGetDeviceCount */ hipError_t hipSetDevice(int deviceId); +/** + * @brief Set a list of devices that can be used. + * + * @param[in] device_arr List of devices to try + * @param[in] len Number of devices in specified list + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see #hipGetDevice, #hipGetDeviceCount. #hipSetDevice. #hipGetDeviceProperties. #hipSetDeviceFlags. #hipChooseDevice + * + * */ +hipError_t hipSetValidDevices(int* device_arr, int len); /** * @brief Return the default device id for the calling host thread. * @@ -2100,7 +2215,7 @@ hipError_t hipIpcGetEventHandle(hipIpcEventHandle_t* handle, hipEvent_t event); /** * @brief Opens an interprocess event handles. * - * Opens an interprocess event handle exported from another process with cudaIpcGetEventHandle. The returned + * Opens an interprocess event handle exported from another process with hipIpcGetEventHandle. The returned * hipEvent_t behaves like a locally created event with the hipEventDisableTiming flag specified. This event * need be freed with hipEventDestroy. Operations on the imported event after the exported event has been freed * with hipEventDestroy will result in undefined behavior. If the function is called within the same process where @@ -2276,7 +2391,7 @@ hipError_t hipDrvGetErrorString(hipError_t hipError, const char** errorString); * Create a new asynchronous stream. @p stream returns an opaque handle that can be used to * reference the newly created stream in subsequent hipStream* commands. The stream is allocated on * the heap and will remain allocated even if the handle goes out-of-scope. To release the memory - * used by the stream, applicaiton must call hipStreamDestroy. + * used by the stream, application must call hipStreamDestroy. * * @return #hipSuccess, #hipErrorInvalidValue * @@ -2293,7 +2408,7 @@ hipError_t hipStreamCreate(hipStream_t* stream); * Create a new asynchronous stream. @p stream returns an opaque handle that can be used to * reference the newly created stream in subsequent hipStream* commands. The stream is allocated on * the heap and will remain allocated even if the handle goes out-of-scope. To release the memory - * used by the stream, applicaiton must call hipStreamDestroy. Flags controls behavior of the + * used by the stream, application must call hipStreamDestroy. Flags controls behavior of the * stream. See #hipStreamDefault, #hipStreamNonBlocking. * * @@ -2311,7 +2426,7 @@ hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags); * Create a new asynchronous stream with the specified priority. @p stream returns an opaque handle * that can be used to reference the newly created stream in subsequent hipStream* commands. The * stream is allocated on the heap and will remain allocated even if the handle goes out-of-scope. - * To release the memory used by the stream, applicaiton must call hipStreamDestroy. Flags controls + * To release the memory used by the stream, application must call hipStreamDestroy. Flags controls * behavior of the stream. See #hipStreamDefault, #hipStreamNonBlocking. * * @@ -2329,7 +2444,7 @@ hipError_t hipStreamCreateWithPriority(hipStream_t* stream, unsigned int flags, * and greatest stream priority respectively. Stream priorities follow a convention where lower numbers * imply greater priorities. The range of meaningful stream priorities is given by * [*greatestPriority, *leastPriority]. If the user attempts to create a stream with a priority value - * that is outside the the meaningful range as specified by this API, the priority is automatically + * that is outside the meaningful range as specified by this API, the priority is automatically * clamped to within the valid range. */ hipError_t hipDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority); @@ -2401,8 +2516,8 @@ hipError_t hipStreamSynchronize(hipStream_t stream); * All future work submitted to @p stream will wait until @p event reports completion before * beginning execution. * - * This function only waits for commands in the current stream to complete. Notably,, this function - * does not impliciy wait for commands in the default stream to complete, even if the specified + * This function only waits for commands in the current stream to complete. Notably, this function + * does not implicitly wait for commands in the default stream to complete, even if the specified * stream is created with hipStreamNonBlocking = 0. * * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamCreateWithPriority, hipStreamSynchronize, hipStreamDestroy @@ -2688,7 +2803,7 @@ hipError_t hipEventCreate(hipEvent_t* event); * * If hipEventRecord() has been previously called on this event, then this call will overwrite any * existing state in event. - * + * * If this function is called on an event that is currently being recorded, results are undefined * - either outstanding recording may save state into the event, and the order is not guaranteed. * @@ -2730,7 +2845,6 @@ hipError_t hipEventDestroy(hipEvent_t event); * If hipEventRecord() has not been called on @p event, this function returns #hipSuccess when no * event is captured. * - * This function needs to support hipEventBlockingSync parameter. * * @param[in] event Event on which to wait. * @@ -3252,7 +3366,7 @@ hipError_t hipStreamAttachMemAsync(hipStream_t stream, * * Inserts a memory allocation operation into @p stream. * A pointer to the allocated memory is returned immediately in *dptr. - * The allocation must not be accessed until the the allocation operation completes. + * The allocation must not be accessed until the allocation operation completes. * The allocation comes from the memory pool associated with the stream's device. * * @note The default memory pool of a device contains device memory from that device. @@ -3504,7 +3618,7 @@ hipError_t hipMemPoolDestroy(hipMemPool_t mem_pool); * * Inserts an allocation operation into @p stream. * A pointer to the allocated memory is returned immediately in @p dev_ptr. - * The allocation must not be accessed until the the allocation operation completes. + * The allocation must not be accessed until the allocation operation completes. * The allocation comes from the specified memory pool. * * @note The specified memory pool may be from a device different than that of the specified @p stream. @@ -3915,6 +4029,68 @@ hipError_t hipMemcpyDtoH(void* dst, hipDeviceptr_t src, size_t sizeBytes); * hipMemHostAlloc, hipMemHostGetDevicePointer */ hipError_t hipMemcpyDtoD(hipDeviceptr_t dst, hipDeviceptr_t src, size_t sizeBytes); +/** + * @brief Copies from one 1D array to device memory. + * + * @param[out] dstDevice Destination device pointer + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, #hipErrorInvalidContext, + * #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, + * hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpyAtoA, + * hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, + * hipMemcpyDtoDAsync, hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, hipMemGetInfo, + * hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoD(hipDeviceptr_t dstDevice, hipArray_t srcArray, size_t srcOffset, + size_t ByteCount); +/** + * @brief Copies from device memory to a 1D array. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcDevice Source device pointer + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, #hipErrorInvalidContext, + * #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, + * hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpyAtoA, + * hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, + * hipMemcpyDtoDAsync, hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, hipMemGetInfo, + * hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoA(hipArray_t dstArray, size_t dstOffset, hipDeviceptr_t srcDevice, + size_t ByteCount); + +/** + * @brief Copies from one 1D array to another. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, #hipErrorInvalidContext, + * #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, + * hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpyAtoA, + * hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, + * hipMemcpyDtoDAsync, hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, hipMemGetInfo, + * hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoA(hipArray_t dstArray, size_t dstOffset, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount); /** * @brief Copy data from Host to Device asynchronously * @@ -3973,7 +4149,48 @@ hipError_t hipMemcpyDtoHAsync(void* dst, hipDeviceptr_t src, size_t sizeBytes, h */ hipError_t hipMemcpyDtoDAsync(hipDeviceptr_t dst, hipDeviceptr_t src, size_t sizeBytes, hipStream_t stream); - +/** + * @brief Copies from one 1D array to host memory. + * + * @param[out] dstHost Destination pointer + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, #hipErrorInvalidContext, + * #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, + * hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpyAtoA, + * hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, + * hipMemcpyDtoDAsync, hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, hipMemGetInfo, + * hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoHAsync(void* dstHost, hipArray_t srcArray, size_t srcOffset, + size_t ByteCount, hipStream_t stream); +/** + * @brief Copies from host memory to a 1D array. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcHost Source host pointer + * @param[in] ByteCount Size of memory copy in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, #hipErrorInvalidContext, + * #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, + * hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpyAtoA, + * hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, + * hipMemcpyDtoDAsync, hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, hipMemGetInfo, + * hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyHtoAAsync(hipArray_t dstArray, size_t dstOffset, const void* srcHost, + size_t ByteCount, hipStream_t stream); /** * @brief Returns a global pointer from a module. * Returns in *dptr and *bytes the pointer and size of the global of name name located in module hmod. @@ -4002,6 +4219,8 @@ hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes, */ hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol); + + /** * @brief Gets the size of the given symbol on the device. * @@ -4013,14 +4232,38 @@ hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol); */ hipError_t hipGetSymbolSize(size_t* size, const void* symbol); +/** + * @brief Gets the pointer of requested HIP driver function. + * + * @param[in] symbol The Symbol name of the driver function to request. + * @param[out] pfn Output pointer to the requested driver function. + * @param[in] hipVersion The HIP version for the requested driver function symbol. + * HIP version is defined as 100*version_major + version_minor. For example, in HIP 6.1, the + * hipversion is 601, for the symbol function "hipGetDeviceProperties", the specified hipVersion 601 + * is greater or equal to the version 600, the symbol function will be handle properly as backend + * compatible function. + * + * @param[in] flags Currently only default flag is suppported. + * @param[out] symbolStatus Optional enumeration for returned status of searching for symbol driver + * function based on the input hipVersion. + * + * Returns hipSuccess if the returned pfn is addressed to the pointer of found driver function. + * + * @return #hipSuccess, #hipErrorInvalidValue. + */ +hipError_t hipGetProcAddress(const char* symbol, void** pfn, int hipVersion, uint64_t flags, + hipDriverProcAddressQueryResult* symbolStatus); + /** * @brief Copies data to the given symbol on the device. * Symbol HIP APIs allow a kernel to define a device-side data symbol which can be accessed on * the host side. The symbol can be in __constant or device space. * Note that the symbol name needs to be encased in the HIP_SYMBOL macro. * This also applies to hipMemcpyFromSymbol, hipGetSymbolAddress, and hipGetSymbolSize. - * For detail usage, see the example at - * https://github.com/ROCm/HIP/blob/develop/docs/user_guide/hip_porting_guide.md + * For detailed usage, see the + * memcpyToSymbol example + * in the HIP Porting Guide. + * * * @param[out] symbol pointer to the device symbole * @param[in] src pointer to the source address @@ -4520,6 +4763,27 @@ hipError_t hipMemcpy2DToArray(hipArray_t dst, size_t wOffset, size_t hOffset, co hipError_t hipMemcpy2DToArrayAsync(hipArray_t dst, size_t wOffset, size_t hOffset, const void* src, size_t spitch, size_t width, size_t height, hipMemcpyKind kind, hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] wOffsetDst Destination starting X offset + * @param[in] hOffsetDst Destination starting Y offset + * @param[in] src Source memory address + * @param[in] wOffsetSrc Source starting X offset + * @param[in] hOffsetSrc Source starting Y offset (columns in bytes) + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2D, hipMemcpyFromArray, hipMemcpyToSymbol, + * hipMemcpyAsync + */ +hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, size_t hOffsetDst, + hipArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, + size_t width, size_t height, hipMemcpyKind kind); /** * @brief Copies data between host and device. * @@ -4734,7 +4998,7 @@ hipError_t hipDeviceDisablePeerAccess(int peerDeviceId); * @param [out] psize - Size of allocation * @param [in] dptr- Device Pointer * - * @returns #hipSuccess, #hipErrorInvalidDevicePointer + * @returns #hipSuccess, #hipErrorNotFound * * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, hipCtxGetCurrent, * hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize, hipCtxGetDevice @@ -5225,6 +5489,16 @@ hipError_t hipFuncGetAttributes(struct hipFuncAttributes* attr, const void* func * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDeviceFunction */ hipError_t hipFuncGetAttribute(int* value, hipFunction_attribute attrib, hipFunction_t hfunc); +/** + * @brief Gets pointer to device entry function that matches entry function symbolPtr. + * + * @param [out] functionPtr Device entry function + * @param [in] symbolPtr Pointer to device entry function to search for + * + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction + * + */ +hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr); /** * @brief returns the handle of the texture reference with the name from the module. * @@ -5646,12 +5920,26 @@ hipError_t hipLaunchKernel(const void* function_address, /** * @brief Enqueues a host function call in a stream. * - * @param [in] stream - stream to enqueue work to. - * @param [in] fn - function to call once operations enqueued preceeding are complete. + * @param [in] stream - The stream to enqueue work in. + * @param [in] fn - The function to call once enqueued preceeding operations are complete. * @param [in] userData - User-specified data to be passed to the function. + * * @returns #hipSuccess, #hipErrorInvalidResourceHandle, #hipErrorInvalidValue, * #hipErrorNotSupported - * @warning : This API is marked as beta, meaning, while this is feature complete, + * + * The host function to call in this API will be executed after the preceding operations in + * the stream are complete. The function is a blocking operation that blocks operations in the + * stream that follow it, until the function is returned. + * Event synchronization and internal callback functions make sure enqueued operations will + * execute in order, in the stream. + * + * The host function must not make any HIP API calls. The host function is non-reentrant. It must + * not perform sychronization with any operation that may depend on other processing execution + * but is not enqueued to run earlier in the stream. + * + * Host functions that are enqueued respectively in different non-blocking streams can run concurrently. + * + * @warning This API is marked as beta, meaning, while this is feature complete, * it is still open to changes and may have outstanding issues. */ hipError_t hipLaunchHostFunc(hipStream_t stream, hipHostFn_t fn, void* userData); @@ -6181,7 +6469,7 @@ hipError_t hipGetTextureAlignmentOffset( DEPRECATED(DEPRECATED_MSG) hipError_t hipUnbindTexture(const textureReference* tex); /** - * @brief Gets the the address for a texture reference. + * @brief Gets the address for a texture reference. * * @param [out] dev_ptr Pointer of device address. * @param [in] texRef Pointer of texture reference. @@ -6564,6 +6852,30 @@ int hipGetStreamDeviceId(hipStream_t stream); */ hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode); +/** +* @brief Begins graph capture on a stream to an existing graph. +* +* @param [in] stream - Stream to initiate capture. +* @param [in] graph - Graph to capture into. +* @param [in] dependencies - Dependencies of the first node captured in the stream. Can be NULL if +* numDependencies is 0. +* @param [in] dependencyData - Optional array of data associated with each dependency. +* @param [in] numDependencies - Number of dependencies. +* @param [in] mode - Controls the interaction of this capture sequence with other API calls that +are not safe. +* +* @returns #hipSuccess, #hipErrorInvalidValue +* +* @warning : param "const hipGraphEdgeData* dependencyData" is currently not supported and has to +passed as nullptr. This API is marked as beta, meaning, while this is feature complete, it is still +open to changes and may have outstanding issues. +* +*/ +hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph, + const hipGraphNode_t* dependencies, + const hipGraphEdgeData* dependencyData, + size_t numDependencies, hipStreamCaptureMode mode); + /** * @brief Ends capture on a stream, returning the captured graph. * @@ -6902,6 +7214,19 @@ hipError_t hipGraphInstantiate(hipGraphExec_t* pGraphExec, hipGraph_t graph, hipError_t hipGraphInstantiateWithFlags(hipGraphExec_t* pGraphExec, hipGraph_t graph, unsigned long long flags); +/** + * @brief Creates an executable graph from a graph. + * + * @param [out] pGraphExec - pointer to instantiated executable graph that is created. + * @param [in] graph - instance of graph to instantiate. + * @param [in] instantiateParams - Graph Instantiate Params + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphInstantiateWithParams(hipGraphExec_t* pGraphExec, hipGraph_t graph, + hipGraphInstantiateParams *instantiateParams); /** * @brief launches an executable graph in a stream * @@ -6926,6 +7251,22 @@ hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream); */ hipError_t hipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream); +/** + * @brief Creates a kernel execution node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - pointer to the dependencies on the kernel execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] nodeParams - pointer to the parameters for the node. + * @returns #hipSuccess, #hipErrorInvalidValue. + * @warning : This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipGraphNodeParams *nodeParams); + /** * @brief Destroys an executable graph * @@ -8906,6 +9247,7 @@ static inline hipError_t hipMallocManaged(T** devPtr, size_t size, return hipMallocManaged((void**)devPtr, size, flags); } + #endif #endif // doxygen end HIP API diff --git a/third_party/amd/backend/include/hip/hip_version.h b/third_party/amd/backend/include/hip/hip_version.h index 0c64f38b1f01..bab5288f806f 100644 --- a/third_party/amd/backend/include/hip/hip_version.h +++ b/third_party/amd/backend/include/hip/hip_version.h @@ -4,9 +4,9 @@ #define HIP_VERSION_H #define HIP_VERSION_MAJOR 6 -#define HIP_VERSION_MINOR 1 -#define HIP_VERSION_PATCH 40091 -#define HIP_VERSION_GITHASH "a8dbc0c19" +#define HIP_VERSION_MINOR 2 +#define HIP_VERSION_PATCH 41134 +#define HIP_VERSION_GITHASH "65d174c3e" #define HIP_VERSION_BUILD_ID 0 #define HIP_VERSION_BUILD_NAME "" #define HIP_VERSION (HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + HIP_VERSION_PATCH) diff --git a/third_party/amd/backend/include/hip/hiprtc.h b/third_party/amd/backend/include/hip/hiprtc.h index 88e9094d848c..e10acbfe09c0 100644 --- a/third_party/amd/backend/include/hip/hiprtc.h +++ b/third_party/amd/backend/include/hip/hiprtc.h @@ -67,32 +67,32 @@ typedef enum hiprtcResult { */ typedef enum hiprtcJIT_option { - HIPRTC_JIT_MAX_REGISTERS = 0, ///< Maximum registers may be used in a thread, passed to compiler - HIPRTC_JIT_THREADS_PER_BLOCK, ///< Number of thread per block - HIPRTC_JIT_WALL_TIME, ///< Value for total wall clock time - HIPRTC_JIT_INFO_LOG_BUFFER, ///< Pointer to the buffer with logged information - HIPRTC_JIT_INFO_LOG_BUFFER_SIZE_BYTES, ///< Size of the buffer in bytes for logged info - HIPRTC_JIT_ERROR_LOG_BUFFER, ///< Pointer to the buffer with logged error(s) - HIPRTC_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, ///< Size of the buffer in bytes for logged error(s) - HIPRTC_JIT_OPTIMIZATION_LEVEL, ///< Value of optimization level for generated codes - HIPRTC_JIT_TARGET_FROM_HIPCONTEXT, ///< The target context, which is the default - HIPRTC_JIT_TARGET, ///< JIT target - HIPRTC_JIT_FALLBACK_STRATEGY, ///< Fallback strategy - HIPRTC_JIT_GENERATE_DEBUG_INFO, ///< Generate debug information - HIPRTC_JIT_LOG_VERBOSE, ///< Generate log verbose - HIPRTC_JIT_GENERATE_LINE_INFO, ///< Generate line number information - HIPRTC_JIT_CACHE_MODE, ///< Set cache mode - HIPRTC_JIT_NEW_SM3X_OPT, ///< @deprecated New SM3X option. - HIPRTC_JIT_FAST_COMPILE, ///< Set fast compile - HIPRTC_JIT_GLOBAL_SYMBOL_NAMES, ///< Array of device symbol names to be relocated to the host - HIPRTC_JIT_GLOBAL_SYMBOL_ADDRESS, ///< Array of host addresses to be relocated to the device - HIPRTC_JIT_GLOBAL_SYMBOL_COUNT, ///< Number of symbol count. - HIPRTC_JIT_LTO, ///< @deprecated Enable link-time optimization for device code - HIPRTC_JIT_FTZ, ///< @deprecated Set single-precision denormals. - HIPRTC_JIT_PREC_DIV, ///< @deprecated Set single-precision floating-point division and + HIPRTC_JIT_MAX_REGISTERS = 0, ///< CUDA Only Maximum registers may be used in a thread, passed to compiler + HIPRTC_JIT_THREADS_PER_BLOCK, ///< CUDA Only Number of thread per block + HIPRTC_JIT_WALL_TIME, ///< CUDA Only Value for total wall clock time + HIPRTC_JIT_INFO_LOG_BUFFER, ///< CUDA Only Pointer to the buffer with logged information + HIPRTC_JIT_INFO_LOG_BUFFER_SIZE_BYTES, ///< CUDA Only Size of the buffer in bytes for logged info + HIPRTC_JIT_ERROR_LOG_BUFFER, ///< CUDA Only Pointer to the buffer with logged error(s) + HIPRTC_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, ///< CUDA Only Size of the buffer in bytes for logged error(s) + HIPRTC_JIT_OPTIMIZATION_LEVEL, ///< Value of optimization level for generated codes, acceptable options -O0, -O1, -O2, -O3 + HIPRTC_JIT_TARGET_FROM_HIPCONTEXT, ///< CUDA Only The target context, which is the default + HIPRTC_JIT_TARGET, ///< CUDA Only JIT target + HIPRTC_JIT_FALLBACK_STRATEGY, ///< CUDA Only Fallback strategy + HIPRTC_JIT_GENERATE_DEBUG_INFO, ///< CUDA Only Generate debug information + HIPRTC_JIT_LOG_VERBOSE, ///< CUDA Only Generate log verbose + HIPRTC_JIT_GENERATE_LINE_INFO, ///< CUDA Only Generate line number information + HIPRTC_JIT_CACHE_MODE, ///< CUDA Only Set cache mode + HIPRTC_JIT_NEW_SM3X_OPT, ///< @deprecated CUDA Only New SM3X option. + HIPRTC_JIT_FAST_COMPILE, ///< CUDA Only Set fast compile + HIPRTC_JIT_GLOBAL_SYMBOL_NAMES, ///< CUDA Only Array of device symbol names to be relocated to the host + HIPRTC_JIT_GLOBAL_SYMBOL_ADDRESS, ///< CUDA Only Array of host addresses to be relocated to the device + HIPRTC_JIT_GLOBAL_SYMBOL_COUNT, ///< CUDA Only Number of symbol count. + HIPRTC_JIT_LTO, ///< @deprecated CUDA Only Enable link-time optimization for device code + HIPRTC_JIT_FTZ, ///< @deprecated CUDA Only Set single-precision denormals. + HIPRTC_JIT_PREC_DIV, ///< @deprecated CUDA Only Set single-precision floating-point division and ///< reciprocals - HIPRTC_JIT_PREC_SQRT, ///< @deprecated Set single-precision floating-point square root - HIPRTC_JIT_FMA, ///< @deprecated Enable floating-point multiplies and adds/subtracts operations + HIPRTC_JIT_PREC_SQRT, ///< @deprecated CUDA Only Set single-precision floating-point square root + HIPRTC_JIT_FMA, ///< @deprecated CUDA Only Enable floating-point multiplies and adds/subtracts operations HIPRTC_JIT_NUM_OPTIONS, ///< Number of options HIPRTC_JIT_IR_TO_ISA_OPT_EXT = 10000, ///< Linker options to be passed on to compiler /// @note Only supported for the AMD platform. diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 51aa389a0681..74f15d7d7ab6 100644 --- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -75,7 +75,8 @@ enum { ELFABIVERSION_AMDGPU_HSA_V2 = 0, ELFABIVERSION_AMDGPU_HSA_V3 = 1, ELFABIVERSION_AMDGPU_HSA_V4 = 2, - ELFABIVERSION_AMDGPU_HSA_V5 = 3 + ELFABIVERSION_AMDGPU_HSA_V5 = 3, + ELFABIVERSION_AMDGPU_HSA_V6 = 4, }; // AMDGPU specific e_flags. @@ -87,6 +88,7 @@ enum : unsigned { EF_AMDGPU_MACH_NONE = 0x000, // AMDGCN-based processors. + // clang-format off EF_AMDGPU_MACH_AMDGCN_GFX600 = 0x020, EF_AMDGPU_MACH_AMDGCN_GFX601 = 0x021, EF_AMDGPU_MACH_AMDGCN_GFX700 = 0x022, @@ -127,13 +129,25 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX1036 = 0x045, EF_AMDGPU_MACH_AMDGCN_GFX1101 = 0x046, EF_AMDGPU_MACH_AMDGCN_GFX1102 = 0x047, + EF_AMDGPU_MACH_AMDGCN_GFX1200 = 0x048, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X49 = 0x049, EF_AMDGPU_MACH_AMDGCN_GFX1151 = 0x04a, EF_AMDGPU_MACH_AMDGCN_GFX941 = 0x04b, EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d, + EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050, + EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051, + EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052, + EF_AMDGPU_MACH_AMDGCN_GFX10_3_GENERIC = 0x053, + EF_AMDGPU_MACH_AMDGCN_GFX11_GENERIC = 0x054, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X55 = 0x055, + // clang-format on // First/last AMDGCN-based processors. EF_AMDGPU_MACH_AMDGCN_FIRST = EF_AMDGPU_MACH_AMDGCN_GFX600, - EF_AMDGPU_MACH_AMDGCN_LAST = EF_AMDGPU_MACH_AMDGCN_GFX942, + EF_AMDGPU_MACH_AMDGCN_LAST = EF_AMDGPU_MACH_AMDGCN_GFX11_GENERIC, // Indicates if the "xnack" target feature is enabled for all code contained // in the object. @@ -159,8 +173,7 @@ enum : unsigned { // XNACK selection mask for EF_AMDGPU_FEATURE_XNACK_* values. // - // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4, - // ELFABIVERSION_AMDGPU_HSA_V5. + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4. EF_AMDGPU_FEATURE_XNACK_V4 = 0x300, // XNACK is not supported. EF_AMDGPU_FEATURE_XNACK_UNSUPPORTED_V4 = 0x000, @@ -173,8 +186,7 @@ enum : unsigned { // SRAMECC selection mask for EF_AMDGPU_FEATURE_SRAMECC_* values. // - // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4, - // ELFABIVERSION_AMDGPU_HSA_V5. + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4. EF_AMDGPU_FEATURE_SRAMECC_V4 = 0xc00, // SRAMECC is not supported. EF_AMDGPU_FEATURE_SRAMECC_UNSUPPORTED_V4 = 0x000, @@ -184,6 +196,21 @@ enum : unsigned { EF_AMDGPU_FEATURE_SRAMECC_OFF_V4 = 0x800, // SRAMECC is on. EF_AMDGPU_FEATURE_SRAMECC_ON_V4 = 0xc00, + + // Generic target versioning. This is contained in the list byte of EFLAGS. + EF_AMDGPU_GENERIC_VERSION = 0xff000000, + EF_AMDGPU_GENERIC_VERSION_OFFSET = 24, + EF_AMDGPU_GENERIC_VERSION_MIN = 1, + EF_AMDGPU_GENERIC_VERSION_MAX = 0xff, +}; + +// ELF Relocation types for AMDGPU. +enum : unsigned { + R_AMDGPU_ABS32_LO = 1, + R_AMDGPU_ABS32_HI = 2, + R_AMDGPU_ABS64 = 3, + R_AMDGPU_ABS32 = 6, + R_AMDGPU_RELATIVE64 = 13, }; } // end namespace ELF @@ -245,14 +272,14 @@ typedef enum { // ELF Symbol Flag Enumeration Values. #define STF_AMDGPU_HSA_CONST AMDGPU_HSA_SYMBOL_FLAG_CONST -// AMD GPU Relocation Type Enumeration Values. -#define R_AMDGPU_NONE 0 -#define R_AMDGPU_32_LOW 1 -#define R_AMDGPU_32_HIGH 2 -#define R_AMDGPU_64 3 -#define R_AMDGPU_INIT_SAMPLER 4 -#define R_AMDGPU_INIT_IMAGE 5 -#define R_AMDGPU_RELATIVE64 13 +// Legacy/V1 AMD GPU Relocation Type Enumeration Values. +#define R_AMDGPU_V1_NONE 0 +#define R_AMDGPU_V1_32_LOW 1 +#define R_AMDGPU_V1_32_HIGH 2 +#define R_AMDGPU_V1_64 3 +#define R_AMDGPU_V1_INIT_SAMPLER 4 +#define R_AMDGPU_V1_INIT_IMAGE 5 +#define R_AMDGPU_V1_RELATIVE64 13 // AMD GPU Note Type Enumeration Values. #define NT_AMD_HSA_CODE_OBJECT_VERSION 1 diff --git a/third_party/amd/backend/include/hsa/hsa.h b/third_party/amd/backend/include/hsa/hsa.h index 9520bd870c9c..1ad714c44c2d 100644 --- a/third_party/amd/backend/include/hsa/hsa.h +++ b/third_party/amd/backend/include/hsa/hsa.h @@ -598,10 +598,14 @@ typedef enum { * AqlProfile extension. */ HSA_EXTENSION_AMD_AQLPROFILE = 0x202, + /** + * PC Sampling extension. + */ + HSA_EXTENSION_AMD_PC_SAMPLING = 0x203, /** * Last AMD extension. */ - HSA_AMD_LAST_EXTENSION = 0x202 + HSA_AMD_LAST_EXTENSION = 0x203 } hsa_extension_t; /** @@ -5656,7 +5660,12 @@ typedef enum { * undefined if the symbol is not an indirect function. The type of this * attribute is uint32_t. */ - HSA_CODE_SYMBOL_INFO_INDIRECT_FUNCTION_CALL_CONVENTION = 16 + HSA_CODE_SYMBOL_INFO_INDIRECT_FUNCTION_CALL_CONVENTION = 16, + /** + * Wavefront size used by the kernel. The value of this attribute is either + * 32 or 64. The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_WAVEFRONT_SIZE = 19 } hsa_code_symbol_info_t; /** diff --git a/third_party/amd/backend/include/hsa/hsa_api_trace.h b/third_party/amd/backend/include/hsa/hsa_api_trace.h index e46c777af13f..2a0f59df3b82 100644 --- a/third_party/amd/backend/include/hsa/hsa_api_trace.h +++ b/third_party/amd/backend/include/hsa/hsa_api_trace.h @@ -44,39 +44,26 @@ #define HSA_RUNTIME_INC_HSA_API_TRACE_H #include "hsa.h" +#include "hsa_api_trace_version.h" #ifdef AMD_INTERNAL_BUILD #include "hsa_ext_image.h" #include "hsa_ext_amd.h" #include "hsa_ext_finalize.h" #include "hsa_amd_tool.h" +#include "hsa_ven_amd_pc_sampling.h" #else #include "inc/hsa_ext_image.h" #include "inc/hsa_ext_amd.h" #include "inc/hsa_ext_finalize.h" #include "inc/hsa_amd_tool.h" +#include "inc/hsa_ven_amd_pc_sampling.h" #endif #include #include #include -// Major Ids of the Api tables exported by Hsa Core Runtime -#define HSA_API_TABLE_MAJOR_VERSION 0x03 -#define HSA_CORE_API_TABLE_MAJOR_VERSION 0x02 -#define HSA_AMD_EXT_API_TABLE_MAJOR_VERSION 0x02 -#define HSA_FINALIZER_API_TABLE_MAJOR_VERSION 0x02 -#define HSA_IMAGE_API_TABLE_MAJOR_VERSION 0x02 -#define HSA_AQLPROFILE_API_TABLE_MAJOR_VERSION 0x01 -#define HSA_TOOLS_API_TABLE_MAJOR_VERSION 0x01 - -// Step Ids of the Api tables exported by Hsa Core Runtime -#define HSA_API_TABLE_STEP_VERSION 0x00 -#define HSA_CORE_API_TABLE_STEP_VERSION 0x00 -#define HSA_AMD_EXT_API_TABLE_STEP_VERSION 0x01 -#define HSA_FINALIZER_API_TABLE_STEP_VERSION 0x00 -#define HSA_IMAGE_API_TABLE_STEP_VERSION 0x00 -#define HSA_AQLPROFILE_API_TABLE_STEP_VERSION 0x00 -#define HSA_TOOLS_API_TABLE_STEP_VERSION 0x00 +// Table MAJOR_VERSION and STEP_VERSION defines have moved to hsa_api_trace_version.h // Min function used to copy Api Tables static inline uint32_t Min(const uint32_t a, const uint32_t b) { @@ -191,6 +178,19 @@ struct ImageExtTable { decltype(hsa_ext_image_create_with_layout)* hsa_ext_image_create_with_layout_fn; }; +// Table to export HSA PC Sampling Extension Apis +struct PcSamplingExtTable { + ApiTableVersion version; + decltype(hsa_ven_amd_pcs_iterate_configuration)* hsa_ven_amd_pcs_iterate_configuration_fn; + decltype(hsa_ven_amd_pcs_create)* hsa_ven_amd_pcs_create_fn; + decltype(hsa_ven_amd_pcs_create_from_id)* hsa_ven_amd_pcs_create_from_id_fn; + decltype(hsa_ven_amd_pcs_destroy)* hsa_ven_amd_pcs_destroy_fn; + decltype(hsa_ven_amd_pcs_start)* hsa_ven_amd_pcs_start_fn; + decltype(hsa_ven_amd_pcs_stop)* hsa_ven_amd_pcs_stop_fn; + decltype(hsa_ven_amd_pcs_flush)* hsa_ven_amd_pcs_flush_fn; +}; + + // Table to export AMD Extension Apis struct AmdExtTable { ApiTableVersion version; @@ -263,6 +263,8 @@ struct AmdExtTable { decltype(hsa_amd_vmem_get_alloc_properties_from_handle)* hsa_amd_vmem_get_alloc_properties_from_handle_fn; decltype(hsa_amd_agent_set_async_scratch_limit)* hsa_amd_agent_set_async_scratch_limit_fn; + decltype(hsa_amd_queue_get_info)* hsa_amd_queue_get_info_fn; + decltype(hsa_amd_vmem_address_reserve_align)* hsa_amd_vmem_address_reserve_align_fn; }; // Table to export HSA Core Runtime Apis @@ -464,6 +466,9 @@ struct HsaApiTable { // Table of function pointers for tools to use ToolsApiTable* tools_; + + // Table of function pointers to AMD PC Sampling Extension + PcSamplingExtTable* pc_sampling_ext_; }; // Structure containing instances of different api tables @@ -474,6 +479,7 @@ struct HsaApiTableContainer { FinalizerExtTable finalizer_ext; ImageExtTable image_ext; ToolsApiTable tools; + PcSamplingExtTable pc_sampling_ext; // Default initialization of a container instance HsaApiTableContainer() { @@ -505,6 +511,11 @@ struct HsaApiTableContainer { tools.version.minor_id = sizeof(ToolsApiTable); tools.version.step_id = HSA_TOOLS_API_TABLE_STEP_VERSION; root.tools_ = &tools; + + pc_sampling_ext.version.major_id = HSA_PC_SAMPLING_API_TABLE_MAJOR_VERSION; + pc_sampling_ext.version.minor_id = sizeof(PcSamplingExtTable); + pc_sampling_ext.version.step_id = HSA_PC_SAMPLING_API_TABLE_STEP_VERSION; + root.pc_sampling_ext_ = &pc_sampling_ext; } }; @@ -562,5 +573,7 @@ static void inline copyTables(const HsaApiTable* src, HsaApiTable* dest) { copyElement(&dest->image_ext_->version, &src->image_ext_->version); if ((offsetof(HsaApiTable, tools_) < dest->version.minor_id)) copyElement(&dest->tools_->version, &src->tools_->version); + if ((offsetof(HsaApiTable, pc_sampling_ext_) < dest->version.minor_id)) + copyElement(&dest->pc_sampling_ext_->version, &src->pc_sampling_ext_->version); } #endif diff --git a/third_party/amd/backend/include/hsa/hsa_api_trace_version.h b/third_party/amd/backend/include/hsa/hsa_api_trace_version.h new file mode 100644 index 000000000000..3393a776207b --- /dev/null +++ b/third_party/amd/backend/include/hsa/hsa_api_trace_version.h @@ -0,0 +1,68 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2024, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H +#define HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H + +// CODE IN THIS FILE **MUST** BE C-COMPATIBLE + +// Major Ids of the Api tables exported by Hsa Core Runtime +#define HSA_API_TABLE_MAJOR_VERSION 0x03 +#define HSA_CORE_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_AMD_EXT_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_FINALIZER_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_IMAGE_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_AQLPROFILE_API_TABLE_MAJOR_VERSION 0x01 +#define HSA_TOOLS_API_TABLE_MAJOR_VERSION 0x01 +#define HSA_PC_SAMPLING_API_TABLE_MAJOR_VERSION 0x01 + +// Step Ids of the Api tables exported by Hsa Core Runtime +#define HSA_API_TABLE_STEP_VERSION 0x01 +#define HSA_CORE_API_TABLE_STEP_VERSION 0x00 +#define HSA_AMD_EXT_API_TABLE_STEP_VERSION 0x03 +#define HSA_FINALIZER_API_TABLE_STEP_VERSION 0x00 +#define HSA_IMAGE_API_TABLE_STEP_VERSION 0x00 +#define HSA_AQLPROFILE_API_TABLE_STEP_VERSION 0x00 +#define HSA_TOOLS_API_TABLE_STEP_VERSION 0x00 +#define HSA_PC_SAMPLING_API_TABLE_STEP_VERSION 0x00 + +#endif // HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H diff --git a/third_party/amd/backend/include/hsa/hsa_ext_amd.h b/third_party/amd/backend/include/hsa/hsa_ext_amd.h index 187bcd958707..f9f60edeb9d0 100644 --- a/third_party/amd/backend/include/hsa/hsa_ext_amd.h +++ b/third_party/amd/backend/include/hsa/hsa_ext_amd.h @@ -47,16 +47,19 @@ #include "hsa.h" #include "hsa_ext_image.h" +#include "hsa_ven_amd_pc_sampling.h" -/* +/** * - 1.0 - initial version * - 1.1 - dmabuf export * - 1.2 - hsa_amd_memory_async_copy_on_engine * - 1.3 - HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED pool * - 1.4 - Virtual Memory API + * - 1.5 - hsa_amd_agent_info: HSA_AMD_AGENT_INFO_MEMORY_PROPERTIES + * - 1.6 - Virtual Memory API: hsa_amd_vmem_address_reserve_align */ #define HSA_AMD_INTERFACE_VERSION_MAJOR 1 -#define HSA_AMD_INTERFACE_VERSION_MINOR 4 +#define HSA_AMD_INTERFACE_VERSION_MINOR 6 #ifdef __cplusplus extern "C" { @@ -221,6 +224,11 @@ enum { * Exceeded number of VGPRs available on this agent */ HSA_STATUS_ERROR_OUT_OF_REGISTERS = 45, + + /** + * Resource is busy or temporarily unavailable + */ + HSA_STATUS_ERROR_RESOURCE_BUSY = 46, }; /** @@ -1176,7 +1184,11 @@ typedef enum hsa_amd_memory_pool_flag_s { * connection. Atomic memory operations on these memory buffers are not * guaranteed to be visible at system scope. */ - HSA_AMD_MEMORY_POOL_PCIE_FLAG = 1, + HSA_AMD_MEMORY_POOL_PCIE_FLAG = (1 << 0), + /** + * Allocates physically contiguous memory + */ + HSA_AMD_MEMORY_POOL_CONTIGUOUS_FLAG = (1 << 1), } hsa_amd_memory_pool_flag_t; @@ -2783,7 +2795,7 @@ hsa_status_t hsa_amd_portable_export_dmabuf(const void* ptr, size_t size, int* d */ hsa_status_t hsa_amd_portable_close_dmabuf(int dmabuf); -/* +/** * @brief Allocate a reserved address range * * Reserve a virtual address range. The size must be a multiple of the system page size. @@ -2803,11 +2815,39 @@ hsa_status_t hsa_amd_portable_close_dmabuf(int dmabuf); * * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources to allocate an address * range of this size. + * + * Note that this API will be deprecated in a future release and replaced by + * hsa_amd_vmem_address_reserve_align */ hsa_status_t hsa_amd_vmem_address_reserve(void** va, size_t size, uint64_t address, uint64_t flags); -/* +/** + * @brief Allocate a reserved address range + * + * Reserve a virtual address range. The size must be a multiple of the system page size. + * If it is not possible to allocate the address specified by @p address, then @p va will be + * a different address range. + * Address range should be released by calling hsa_amd_vmem_address_free. + * + * @param[out] va virtual address allocated + * @param[in] size of address range requested + * @param[in] address requested + * @param[in] alignment requested. 0 for default. Must be >= page-size and a power of 2 + * @param[in] flags currently unsupported + * + * @retval ::HSA_STATUS_SUCCESS Address range allocated successfully + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources to allocate an address + * range of this size. + */ +hsa_status_t hsa_amd_vmem_address_reserve_align(void** va, size_t size, uint64_t address, + uint64_t alignment, uint64_t flags); + +/** * @brief Free a reserved address range * * Free a previously allocated address range. The size must match the size of a previously @@ -2841,7 +2881,7 @@ typedef enum { MEMORY_TYPE_PINNED, } hsa_amd_memory_type_t; -/* +/** * @brief Create a virtual memory handle * * Create a virtual memory handle within this pool @@ -2870,7 +2910,7 @@ hsa_status_t hsa_amd_vmem_handle_create(hsa_amd_memory_pool_t pool, size_t size, hsa_amd_memory_type_t type, uint64_t flags, hsa_amd_vmem_alloc_handle_t* memory_handle); -/* +/** * @brief Release a virtual memory handle * * @param[in] memory handle that was previously allocated @@ -2881,7 +2921,7 @@ hsa_status_t hsa_amd_vmem_handle_create(hsa_amd_memory_pool_t pool, size_t size, */ hsa_status_t hsa_amd_vmem_handle_release(hsa_amd_vmem_alloc_handle_t memory_handle); -/* +/** * @brief Map a virtual memory handle * * Map a virtual memory handle to a reserved address range. The virtual address requested must be @@ -2907,7 +2947,7 @@ hsa_status_t hsa_amd_vmem_handle_release(hsa_amd_vmem_alloc_handle_t memory_hand hsa_status_t hsa_amd_vmem_map(void* va, size_t size, size_t in_offset, hsa_amd_vmem_alloc_handle_t memory_handle, uint64_t flags); -/* +/** * @brief Unmap a virtual memory handle * * Unmap previously mapped virtual address range @@ -2930,7 +2970,7 @@ typedef struct hsa_amd_memory_access_desc_s { hsa_agent_t agent_handle; } hsa_amd_memory_access_desc_t; -/* +/** * @brief Make a memory mapping accessible * * Make previously mapped virtual address accessible to specific agents. @p size must be equal to @@ -2959,7 +2999,7 @@ hsa_status_t hsa_amd_vmem_set_access(void* va, size_t size, const hsa_amd_memory_access_desc_t* desc, size_t desc_cnt); -/* +/** * @brief Get current access permissions for memory mapping * * Get access permissions for memory mapping for specific agent. @@ -2980,7 +3020,7 @@ hsa_status_t hsa_amd_vmem_set_access(void* va, size_t size, hsa_status_t hsa_amd_vmem_get_access(void* va, hsa_access_permission_t* perms, hsa_agent_t agent_handle); -/* +/** * @brief Get an exportable shareable handle * * Get an exportable shareable handle for a memory_handle. This shareabl handle can then be used to @@ -3003,7 +3043,7 @@ hsa_status_t hsa_amd_vmem_get_access(void* va, hsa_access_permission_t* perms, hsa_status_t hsa_amd_vmem_export_shareable_handle(int* dmabuf_fd, hsa_amd_vmem_alloc_handle_t handle, uint64_t flags); -/* +/** * @brief Import a shareable handle * * Import a shareable handle for a memory handle. Importing a shareable handle that has been closed @@ -3023,7 +3063,7 @@ hsa_status_t hsa_amd_vmem_export_shareable_handle(int* dmabuf_fd, hsa_status_t hsa_amd_vmem_import_shareable_handle(int dmabuf_fd, hsa_amd_vmem_alloc_handle_t* handle); -/* +/** * @brief Returns memory handle for mapped memory * * Return a memory handle for previously mapped memory. The handle will be the same value of handle @@ -3040,19 +3080,19 @@ hsa_status_t hsa_amd_vmem_import_shareable_handle(int dmabuf_fd, hsa_status_t hsa_amd_vmem_retain_alloc_handle(hsa_amd_vmem_alloc_handle_t* memory_handle, void* addr); -/* -* @brief Returns the current allocation properties of a handle -* -* Returns the allocation properties of an existing handle -* -* @param[in] memory_handle memory handle to be queried -* @param[out] pool memory pool that owns this handle -* @param[out] memory type - -* @retval ::HSA_STATUS_SUCCESS -* -* @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory_handle -*/ +/** + * @brief Returns the current allocation properties of a handle + * + * Returns the allocation properties of an existing handle + * + * @param[in] memory_handle memory handle to be queried + * @param[out] pool memory pool that owns this handle + * @param[out] memory type + + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory_handle + */ hsa_status_t hsa_amd_vmem_get_alloc_properties_from_handle( hsa_amd_vmem_alloc_handle_t memory_handle, hsa_amd_memory_pool_t* pool, hsa_amd_memory_type_t* type); @@ -3084,6 +3124,22 @@ hsa_status_t hsa_amd_vmem_get_alloc_properties_from_handle( */ hsa_status_t HSA_API hsa_amd_agent_set_async_scratch_limit(hsa_agent_t agent, size_t threshold); +typedef enum { + /* + * Returns the agent that owns the underlying HW queue. + * The type of this attribute is hsa_agent_t. + */ + HSA_AMD_QUEUE_INFO_AGENT, + /* + * Returns the doorbell ID of the completion signal of the queue + * The type of this attribute is uint64_t. + */ + HSA_AMD_QUEUE_INFO_DOORBELL_ID, +} hsa_queue_info_attribute_t; + +hsa_status_t hsa_amd_queue_get_info(hsa_queue_t* queue, hsa_queue_info_attribute_t attribute, + void* value); + #ifdef __cplusplus } // end extern "C" block #endif diff --git a/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h b/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h index 32ca6b7320bb..0022c0d8b8b6 100644 --- a/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h +++ b/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h @@ -149,61 +149,61 @@ hsa_status_t hsa_ven_amd_aqlprofile_validate_event( // All parameters are generic and if not applicable for a specific // profile configuration then error status will be returned. typedef enum { - /* - * Select the target compute unit (wgp) for profiling. - */ + /** + * Select the target compute unit (wgp) for profiling. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET = 0, - /* - * VMID Mask - */ + /** + * VMID Mask + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_VM_ID_MASK = 1, - /* - * Legacy. Deprecated. - */ + /** + * Legacy. Deprecated. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_MASK = 2, - /* - * Legacy. Deprecated. - */ + /** + * Legacy. Deprecated. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_TOKEN_MASK = 3, - /* - * Legacy. Deprecated. - */ + /** + * Legacy. Deprecated. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_TOKEN_MASK2 = 4, - /* - * Shader engine mask for selection. - */ + /** + * Shader engine mask for selection. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SE_MASK = 5, - /* - * Legacy. Deprecated. - */ + /** + * Legacy. Deprecated. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SAMPLE_RATE = 6, - /* - * Legacy. Deprecated. - */ + /** + * Legacy. Deprecated. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_K_CONCURRENT = 7, - /* - * Set SIMD Mask (GFX9) or SIMD ID for collection (Navi) - */ + /** + * Set SIMD Mask (GFX9) or SIMD ID for collection (Navi) + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION = 8, - /* - * Set true for occupancy collection only. - */ + /** + * Set true for occupancy collection only. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_OCCUPANCY_MODE = 9, - /* - * ATT collection max data size, in MB. Shared among shader engines. - */ + /** + * ATT collection max data size, in MB. Shared among shader engines. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE = 10, - /* - * Mask of which compute units to generate perfcounters. GFX9 only. - */ + /** + * Mask of which compute units to generate perfcounters. GFX9 only. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_MASK = 240, - /* - * Select collection period for perfcounters. GFX9 only. - */ + /** + * Select collection period for perfcounters. GFX9 only. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_CTRL = 241, - /* - * Select perfcounter ID (SQ block) for collection. GFX9 only. - */ + /** + * Select perfcounter ID (SQ block) for collection. GFX9 only. + */ HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_NAME = 242, } hsa_ven_amd_aqlprofile_parameter_name_t; @@ -365,11 +365,11 @@ hsa_status_t hsa_ven_amd_aqlprofile_error_string( /** * @brief Callback for iteration of all possible event coordinate IDs and coordinate names. -*/ + */ typedef hsa_status_t(*hsa_ven_amd_aqlprofile_eventname_callback_t)(int id, const char* name); /** * @brief Iterate over all possible event coordinate IDs and their names. -*/ + */ hsa_status_t hsa_ven_amd_aqlprofile_iterate_event_ids(hsa_ven_amd_aqlprofile_eventname_callback_t); /** @@ -380,7 +380,7 @@ hsa_status_t hsa_ven_amd_aqlprofile_iterate_event_ids(hsa_ven_amd_aqlprofile_eve * @param coordinate The coordinate, in the range [0,extent-1]. * @param name Coordinate name as in _iterate_event_ids. * @param userdata Userdata returned from _iterate_event_coord function. -*/ + */ typedef hsa_status_t(*hsa_ven_amd_aqlprofile_coordinate_callback_t)( int position, int id, @@ -397,7 +397,7 @@ typedef hsa_status_t(*hsa_ven_amd_aqlprofile_coordinate_callback_t)( * @param[in] sample_id aqlprofile_info_data_t.sample_id returned from _aqlprofile_iterate_data. * @param[in] callback Callback function to return the coordinates. * @param[in] userdata Arbitrary data pointer to be sent back to the user via callback. -*/ + */ hsa_status_t hsa_ven_amd_aqlprofile_iterate_event_coord( hsa_agent_t agent, hsa_ven_amd_aqlprofile_event_t event, diff --git a/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h b/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h new file mode 100644 index 000000000000..019f0ea5c960 --- /dev/null +++ b/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h @@ -0,0 +1,416 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_VEN_AMD_PC_SAMPLING_H +#define HSA_VEN_AMD_PC_SAMPLING_H + +#include "hsa.h" + +#ifdef __cplusplus +extern "C" { +#endif /*__cplusplus*/ + + +/** + * @brief HSA AMD Vendor PC Sampling APIs + * EXPERIMENTAL: All PC Sampling APIs are currently in an experimental phase and the APIs may be + * modified extensively in the future + */ + +/** + * @brief PC Sampling sample data for hosttrap sampling method + */ +typedef struct { + uint64_t pc; + uint64_t exec_mask; + uint32_t workgroup_id_x; + uint32_t workgroup_id_y; + uint32_t workgroup_id_z; + uint32_t wave_in_wg : 6; + uint32_t chiplet : 3; // Currently not used + uint32_t reserved : 23; + uint32_t hw_id; + uint32_t reserved0; + uint64_t reserved1; + uint64_t timestamp; + uint64_t correlation_id; +} perf_sample_hosttrap_v1_t; + +/** + * @brief PC Sampling sample data for stochastic sampling method + */ +typedef struct { + uint64_t pc; + uint64_t exec_mask; + uint32_t workgroup_id_x; + uint32_t workgroup_id_y; + uint32_t workgroup_id_z; + uint32_t wave_in_wg : 6; + uint32_t chiplet : 3; // Currently not used + uint32_t reserved : 23; + uint32_t hw_id; + uint32_t perf_snapshot_data; + uint32_t perf_snapshot_data1; + uint32_t perf_snapshot_data2; + uint64_t timestamp; + uint64_t correlation_id; +} perf_sample_snapshot_v1_t; + +/** + * @brief PC Sampling method kinds + */ +typedef enum { + HSA_VEN_AMD_PCS_METHOD_HOSTTRAP_V1, + HSA_VEN_AMD_PCS_METHOD_STOCHASTIC_V1 +} hsa_ven_amd_pcs_method_kind_t; + +/** + * @brief PC Sampling interval unit type + */ +typedef enum { + HSA_VEN_AMD_PCS_INTERVAL_UNITS_MICRO_SECONDS, + HSA_VEN_AMD_PCS_INTERVAL_UNITS_CLOCK_CYCLES, + HSA_VEN_AMD_PCS_INTERVAL_UNITS_INSTRUCTIONS +} hsa_ven_amd_pcs_units_t; + +/** + * @brief HSA callback function to perform the copy onto a destination buffer + * + * If data_size is 0, HSA will stop current copy operation and keep remaining data in internal + * buffers. Remaining contents of HSA internal buffers will be included in next + * hsa_ven_amd_pcs_data_ready_callback_t. HSA internal buffers can also be drained by calling + * hsa_ven_amd_pcs_flush. + * + * @param[in] hsa_callback_data private data to pass back to HSA. Provided in + * hsa_ven_amd_pcs_data_ready_callback_t + * + * @param[in] data_size size of destination buffer in bytes. + * @param[in] destination destination buffer + * @retval TBD: but could be used to indicate that there is no more data to be read. + * Or indicate an error and abort of current copy operations + */ +typedef hsa_status_t (*hsa_ven_amd_pcs_data_copy_callback_t)(void* hsa_callback_data, + size_t data_size, void* destination); + +/** + * @brief HSA callback function to to indicate that there is data ready to be copied + * + * When the client receives this callback, the client should call back @p data_copy_callback for HSA + * to perform the copy operation into an available buffer. @p data_copy_callback can be called back + * multiple times with smaller @p data_size to split the copy operation. + * + * This callback must not call ::hsa_ven_amd_pcs_flush. + * + * @param[in] client_callback_data client private data passed in via + * hsa_ven_amd_pcs_create/hsa_ven_amd_pcs_create_from_id + * @param[in] data_size size of data available to be copied + * @param[in] lost_sample_count number of lost samples since last call to + * hsa_ven_amd_pcs_data_ready_callback_t. + * @param[in] data_copy_callback callback function for HSA to perform the actual copy + * @param[in] hsa_callback_data private data to pass back to HSA + */ +typedef void (*hsa_ven_amd_pcs_data_ready_callback_t)( + void* client_callback_data, size_t data_size, size_t lost_sample_count, + hsa_ven_amd_pcs_data_copy_callback_t data_copy_callback, void* hsa_callback_data); + +/** + * @brief Opaque handle representing a sampling session. + * Two sessions having same handle value represent the same session + */ +typedef struct { + uint64_t handle; +} hsa_ven_amd_pcs_t; + +/** + * @brief PC Sampling configuration flag options + */ +typedef enum { + /* The interval for this sampling method have to be a power of 2 */ + HSA_VEN_AMD_PCS_CONFIGURATION_FLAGS_INTERVAL_POWER_OF_2 = (1 << 0) +} hsa_ven_amd_pcs_configuration_flags_t; + +/** + * @brief PC Sampling method information + * Used to provide client with list of supported PC Sampling methods + */ +typedef struct { + hsa_ven_amd_pcs_method_kind_t method; + hsa_ven_amd_pcs_units_t units; + size_t min_interval; + size_t max_interval; + uint64_t flags; +} hsa_ven_amd_pcs_configuration_t; + +/** + * @brief Callback function to iterate through list of supported PC Sampling configurations + * + * @param[in] configuration one entry for supported PC Sampling method and configuration options + * @param[in] callback_data client private callback data that was passed in when calling + * hsa_ven_amd_pcs_iterate_configuration + */ +typedef hsa_status_t (*hsa_ven_amd_pcs_iterate_configuration_callback_t)( + const hsa_ven_amd_pcs_configuration_t* configuration, void* callback_data); + +/** + * @brief Iterate through list of current supported PC Sampling configurations for this @p agent + * + * HSA will callback @p configuration_callback for each currently available PC Sampling + * configuration. The list of currently available configurations may not be the complete list of + * configurations supported on the @p agent. The list of currently available configurations may be + * reduced if the @p agent is currently handling other PC sampling sessions. + * + * @param[in] agent target agent + * @param[in] configuration_callback callback function to iterate through list of configurations + * @param[in] callback_data client private callback data + **/ +hsa_status_t hsa_ven_amd_pcs_iterate_configuration( + hsa_agent_t agent, hsa_ven_amd_pcs_iterate_configuration_callback_t configuration_callback, + void* callback_data); + +/** + * @brief Create a PC Sampling session on @p agent + * + * Allocate the resources required for a PC Sampling session. The @p method, @p units, @p interval + * parameters must be a legal configuration value, as described by the + * hsa_ven_amd_pcs_configuration_t configurations passed to the callbacks of + * hsa_ven_amd_pcs_iterate_configuration for this @p agent. + * A successfull call may restrict the list of possible PC sampling methods available to subsequent + * calls to hsa_ven_amd_pcs_iterate_configuration on the same agent as agents have limitations + * on what types of PC sampling they can perform concurrently. + * For all successful calls, hsa_ven_amd_pcs_destroy should be called to free this session. + * The session will be in a stopped/inactive state after this call + * + * @param[in] agent target agent + * @param[in] method method to use + * @param[in] units sampling units + * @param[in] interval sampling interval in @p units + * @param[in] latency expected latency in microseconds for client to provide a buffer for the data + * copy callback once HSA calls @p data_ready_callback. This is a performance hint to avoid the + * buffer filling up before the client is notified that data is ready. HSA-runtime will estimate + * how many samples are received within @p latency and call @p data_ready_callback ahead of time so + * that the client has @p latency time to allocate the buffer before the HSA-runtime internal + * buffers are full. The value of latency can be 0. + * @param[in] buffer_size size of client buffer in bytes. @p data_ready_callback will be called once + * HSA-runtime has enough samples to fill @p buffer_size. This needs to be a multiple of size of + * perf_sample_hosttrap_v1_t or size of perf_sample_snapshot_v1_t. + * @param[in] data_ready_callback client callback function that will be called when: + * 1. There is enough samples fill a buffer with @p buffer_size - estimated samples received + * within @p latency period. + * OR + * 2. When hsa_ven_amd_pcs_flush is called. + * @param[in] client_callback_data client private data to be provided back when data_ready_callback + * is called. + * @param[out] pc_sampling PC sampling session handle used to reference this session when calling + * hsa_ven_amd_pcs_start, hsa_ven_amd_pcs_stop, hsa_ven_amd_pcs_destroy + * + * @retval ::HSA_STATUS_SUCCESS session created successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT invalid parameters + * @retval ::HSA_STATUS_ERROR_RESOURCE_BUSY agent currently handling another PC Sampling session and + * cannot handle the type requested. + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Failed to allocate resources + * @retval ::HSA_STATUS_ERROR Unexpected error + **/ +hsa_status_t hsa_ven_amd_pcs_create(hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, + size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void* client_callback_data, hsa_ven_amd_pcs_t* pc_sampling); + + +/** + * @brief Creates a PC Sampling session on @p agent. Assumes that the caller provides the + * @p pcs_id generated by the previous call to the underlying driver that reserved PC sampling + * on the @p agent. + * + * Similar to the @ref hsa_ven_amd_pcs_create with the difference that it inherits an existing + * PC sampling session that was previously created in the underlying driver. + * + * Allocate the resources required for a PC Sampling session. The @p method, @p units, @p interval + * parameters must be a legal configuration value, and match the parameters that we used to create + * the underlying PC Sampling session in the underlying driver. + * A successfull call may restrict the list of possible PC sampling methods available to subsequent + * calls to hsa_ven_amd_pcs_iterate_configuration on the same agent as agents have limitations + * on what types of PC sampling they can perform concurrently. + * For all successful calls, hsa_ven_amd_pcs_destroy should be called to free this session. + * The session will be in a stopped/inactive state after this call + * + * @param[in] pcs_id ID that uniquely identifies the PC sampling session within underlying driver + * @param[in] agent target agent + * @param[in] method method to use + * @param[in] units sampling units + * @param[in] interval sampling interval in @p units + * @param[in] latency expected latency in microseconds for client to provide a buffer for the data + * copy callback once HSA calls @p data_ready_callback. This is a performance hint to avoid the + * buffer filling up before the client is notified that data is ready. HSA-runtime will estimate + * how many samples are received within @p latency and call @p data_ready_callback ahead of time so + * that the client has @p latency time to allocate the buffer before the HSA-runtime internal + * buffers are full. The value of latency can be 0. + * @param[in] buffer_size size of client buffer in bytes. @p data_ready_callback will be called once + * HSA-runtime has enough samples to fill @p buffer_size. This needs to be a multiple of size of + * perf_sample_hosttrap_v1_t or size of perf_sample_snapshot_v1_t. + * @param[in] data_ready_callback client callback function that will be called when: + * 1. There is enough samples fill a buffer with @p buffer_size - estimated samples received + * within @p latency period. + * OR + * 2. When hsa_ven_amd_pcs_flush is called. + * @param[in] client_callback_data client private data to be provided back when data_ready_callback + * is called. + * @param[out] pc_sampling PC sampling session handle used to reference this session when calling + * hsa_ven_amd_pcs_start, hsa_ven_amd_pcs_stop, hsa_ven_amd_pcs_destroy + * + * @retval ::HSA_STATUS_SUCCESS session created successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT invalid parameters + * @retval ::HSA_STATUS_ERROR_RESOURCE_BUSY agent currently handling another PC Sampling session and + * cannot handle the type requested. + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Failed to allocate resources + * @retval ::HSA_STATUS_ERROR Unexpected error + **/ +hsa_status_t hsa_ven_amd_pcs_create_from_id( + uint32_t pcs_id, hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, void* client_callback_data, + hsa_ven_amd_pcs_t* pc_sampling); + +/** + * @brief Free a PC Sampling session on @p agent + * + * Free all the resources allocated for a PC Sampling session on @p agent + * Internal buffers for this session will be lost. + * If the session was active, the session will be stopped before it is destroyed. + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session destroyed successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + * @retval ::HSA_STATUS_ERROR unexpected error + */ +hsa_status_t hsa_ven_amd_pcs_destroy(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Start a PC Sampling session + * + * Activate a PC Sampling session that was previous created. + * The session with be in a active state after this call + * If the session was already active, this will result in a no-op and will return HSA_STATUS_SUCCESS + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session started successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + * @retval ::HSA_STATUS_ERROR unexpected error + */ +hsa_status_t hsa_ven_amd_pcs_start(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Stop a PC Sampling session + * + * Stop a session that is currently active + * After a session is stopped HSA may still have some PC Sampling data in its internal buffers. + * The internal buffers can be drained using hsa_ven_amd_pcs_flush. If the internal + * buffers are not drained and the session is started again, the internal buffers will be available + * on the next data_ready_callback. + * If the session was already inactive, this will result in a no-op and will return + * HSA_STATUS_SUCCESS + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session stopped successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + */ +hsa_status_t hsa_ven_amd_pcs_stop(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Flush internal buffers for a PC Sampling session + * + * Drain internal buffers for a PC Sampling session. If internal buffers have available data, + * this trigger a data_ready_callback. + * + * The function blocks until all PC samples associated with the @p pc_sampling session + * generated prior to the function call have been communicated by invocations of + * @p data_ready_callback having completed execution. + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session flushed successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + */ +hsa_status_t hsa_ven_amd_pcs_flush(hsa_ven_amd_pcs_t pc_sampling); + +#define hsa_ven_amd_pc_sampling_1_00 + +/** + * @brief The function pointer table for the PC Sampling v1.00 extension. Can be returned by + * ::hsa_system_get_extension_table or ::hsa_system_get_major_extension_table. + */ +typedef struct hsa_ven_amd_pc_sampling_1_00_pfn_t { + hsa_status_t (*hsa_ven_amd_pcs_iterate_configuration)( + hsa_agent_t agent, hsa_ven_amd_pcs_iterate_configuration_callback_t configuration_callback, + void* callback_data); + + hsa_status_t (*hsa_ven_amd_pcs_create)(hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, + size_t latency, size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void* client_callback_data, + hsa_ven_amd_pcs_t* pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_create_from_id)( + uint32_t pcs_id, hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, void* client_callback_data, + hsa_ven_amd_pcs_t* pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_destroy)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_start)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_stop)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_flush)(hsa_ven_amd_pcs_t pc_sampling); + +} hsa_ven_amd_pc_sampling_1_00_pfn_t; + +#ifdef __cplusplus +} // end extern "C" block +#endif /*__cplusplus*/ + +#endif /* HSA_VEN_AMD_PC_SAMPLING_H */ diff --git a/third_party/amd/backend/include/roctracer/hip_ostream_ops.h b/third_party/amd/backend/include/roctracer/hip_ostream_ops.h index 13ee9ac2d379..eba2592fa305 100644 --- a/third_party/amd/backend/include/roctracer/hip_ostream_ops.h +++ b/third_party/amd/backend/include/roctracer/hip_ostream_ops.h @@ -2795,6 +2795,11 @@ inline static std::ostream& operator<<(std::ostream& out, const hipMemPoolProps& roctracer::hip_support::detail::operator<<(out, 0); std::operator<<(out, ", "); } + if (std::string("hipMemPoolProps::maxSize").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSize="); + roctracer::hip_support::detail::operator<<(out, v.maxSize); + std::operator<<(out, ", "); + } if (std::string("hipMemPoolProps::location").find(HIP_structs_regex) != std::string::npos) { std::operator<<(out, "location="); roctracer::hip_support::detail::operator<<(out, v.location); @@ -3229,17 +3234,22 @@ inline static std::ostream& operator<<(std::ostream& out, const hipAccessPolicyW std::operator<<(out, '}'); return out; } -inline static std::ostream& operator<<(std::ostream& out, const hipKernelNodeAttrValue& v) +inline static std::ostream& operator<<(std::ostream& out, const hipLaunchAttributeValue& v) { std::operator<<(out, '{'); HIP_depth_max_cnt++; if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { - if (std::string("hipKernelNodeAttrValue::cooperative").find(HIP_structs_regex) != std::string::npos) { + if (std::string("hipLaunchAttributeValue::priority").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "priority="); + roctracer::hip_support::detail::operator<<(out, v.priority); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchAttributeValue::cooperative").find(HIP_structs_regex) != std::string::npos) { std::operator<<(out, "cooperative="); roctracer::hip_support::detail::operator<<(out, v.cooperative); std::operator<<(out, ", "); } - if (std::string("hipKernelNodeAttrValue::accessPolicyWindow").find(HIP_structs_regex) != std::string::npos) { + if (std::string("hipLaunchAttributeValue::accessPolicyWindow").find(HIP_structs_regex) != std::string::npos) { std::operator<<(out, "accessPolicyWindow="); roctracer::hip_support::detail::operator<<(out, v.accessPolicyWindow); } @@ -3287,6 +3297,35 @@ inline static std::ostream& operator<<(std::ostream& out, const HIP_MEMSET_NODE_ std::operator<<(out, '}'); return out; } +inline static std::ostream& operator<<(std::ostream& out, const hipGraphInstantiateParams& v) +{ + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipGraphInstantiateParams::uploadStream").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "uploadStream="); + roctracer::hip_support::detail::operator<<(out, v.uploadStream); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::result_out").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "result_out="); + roctracer::hip_support::detail::operator<<(out, v.result_out); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::flags").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::errNode_out").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "errNode_out="); + roctracer::hip_support::detail::operator<<(out, v.errNode_out); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} inline static std::ostream& operator<<(std::ostream& out, const hipMemAllocationProp& v) { std::operator<<(out, '{'); @@ -3513,6 +3552,35 @@ inline static std::ostream& operator<<(std::ostream& out, const hipGraphNodePara std::operator<<(out, '}'); return out; } +inline static std::ostream& operator<<(std::ostream& out, const hipGraphEdgeData& v) +{ + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipGraphEdgeData::type").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::to_port").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "to_port="); + roctracer::hip_support::detail::operator<<(out, v.to_port); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::reserved").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::from_port").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "from_port="); + roctracer::hip_support::detail::operator<<(out, v.from_port); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} inline static std::ostream& operator<<(std::ostream& out, const hipDeviceProp_tR0000& v) { std::operator<<(out, '{'); @@ -4352,7 +4420,7 @@ inline static std::ostream& operator<<(std::ostream& out, const hipAccessPolicyW return out; } -inline static std::ostream& operator<<(std::ostream& out, const hipKernelNodeAttrValue& v) +inline static std::ostream& operator<<(std::ostream& out, const hipLaunchAttributeValue& v) { roctracer::hip_support::detail::operator<<(out, v); return out; @@ -4364,6 +4432,12 @@ inline static std::ostream& operator<<(std::ostream& out, const HIP_MEMSET_NODE_ return out; } +inline static std::ostream& operator<<(std::ostream& out, const hipGraphInstantiateParams& v) +{ + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + inline static std::ostream& operator<<(std::ostream& out, const hipMemAllocationProp& v) { roctracer::hip_support::detail::operator<<(out, v); @@ -4424,6 +4498,12 @@ inline static std::ostream& operator<<(std::ostream& out, const hipGraphNodePara return out; } +inline static std::ostream& operator<<(std::ostream& out, const hipGraphEdgeData& v) +{ + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + inline static std::ostream& operator<<(std::ostream& out, const hipDeviceProp_tR0000& v) { roctracer::hip_support::detail::operator<<(out, v); diff --git a/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h b/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h index 353ddc6ba4ca..7dfd39dd099f 100644 --- a/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h +++ b/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h @@ -785,6 +785,236 @@ inline static std::ostream& operator<<(std::ostream& out, const hsa_ext_images_1 std::operator<<(out, '}'); return out; } +inline static std::ostream& operator<<(std::ostream& out, const perf_sample_hosttrap_v1_t& v) +{ + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("perf_sample_hosttrap_v1_t::correlation_id").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "correlation_id="); + roctracer::hsa_support::detail::operator<<(out, v.correlation_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::timestamp").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "timestamp="); + roctracer::hsa_support::detail::operator<<(out, v.timestamp); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved1").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved0").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::hw_id").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hw_id="); + roctracer::hsa_support::detail::operator<<(out, v.hw_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hsa_support::detail::operator<<(out, v.reserved); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::chiplet").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "chiplet="); + roctracer::hsa_support::detail::operator<<(out, v.chiplet); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::wave_in_wg").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "wave_in_wg="); + roctracer::hsa_support::detail::operator<<(out, v.wave_in_wg); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_z").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_z="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_z); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_y").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_y="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_y); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_x").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_x="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_x); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::exec_mask").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "exec_mask="); + roctracer::hsa_support::detail::operator<<(out, v.exec_mask); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::pc").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "pc="); + roctracer::hsa_support::detail::operator<<(out, v.pc); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream& operator<<(std::ostream& out, const perf_sample_snapshot_v1_t& v) +{ + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("perf_sample_snapshot_v1_t::correlation_id").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "correlation_id="); + roctracer::hsa_support::detail::operator<<(out, v.correlation_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::timestamp").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "timestamp="); + roctracer::hsa_support::detail::operator<<(out, v.timestamp); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data2").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data2="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data2); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data1").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data1="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data1); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::hw_id").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hw_id="); + roctracer::hsa_support::detail::operator<<(out, v.hw_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::reserved").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hsa_support::detail::operator<<(out, v.reserved); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::chiplet").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "chiplet="); + roctracer::hsa_support::detail::operator<<(out, v.chiplet); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::wave_in_wg").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "wave_in_wg="); + roctracer::hsa_support::detail::operator<<(out, v.wave_in_wg); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_z").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_z="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_z); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_y").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_y="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_y); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_x").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_x="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_x); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::exec_mask").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "exec_mask="); + roctracer::hsa_support::detail::operator<<(out, v.exec_mask); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::pc").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "pc="); + roctracer::hsa_support::detail::operator<<(out, v.pc); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pcs_t& v) +{ + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pcs_t::handle").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pcs_configuration_t& v) +{ + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pcs_configuration_t::flags").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hsa_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::max_interval").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "max_interval="); + roctracer::hsa_support::detail::operator<<(out, v.max_interval); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::min_interval").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "min_interval="); + roctracer::hsa_support::detail::operator<<(out, v.min_interval); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::units").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "units="); + roctracer::hsa_support::detail::operator<<(out, v.units); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::method").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "method="); + roctracer::hsa_support::detail::operator<<(out, v.method); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pc_sampling_1_00_pfn_t& v) +{ + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_flush").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_flush="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_flush); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_stop").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_stop="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_stop); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_start").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_start="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_start); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_destroy").find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_destroy="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_destroy); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} inline static std::ostream& operator<<(std::ostream& out, const hsa_amd_vendor_packet_header_t& v) { std::operator<<(out, '{'); @@ -1360,6 +1590,36 @@ inline static std::ostream& operator<<(std::ostream& out, const hsa_ext_images_1 return out; } +inline static std::ostream& operator<<(std::ostream& out, const perf_sample_hosttrap_v1_t& v) +{ + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream& operator<<(std::ostream& out, const perf_sample_snapshot_v1_t& v) +{ + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pcs_t& v) +{ + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pcs_configuration_t& v) +{ + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream& operator<<(std::ostream& out, const hsa_ven_amd_pc_sampling_1_00_pfn_t& v) +{ + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + inline static std::ostream& operator<<(std::ostream& out, const hsa_amd_vendor_packet_header_t& v) { roctracer::hsa_support::detail::operator<<(out, v); diff --git a/third_party/amd/backend/include/roctracer/hsa_prof_str.h b/third_party/amd/backend/include/roctracer/hsa_prof_str.h index 28b2bf54d7c4..3747659f7924 100644 --- a/third_party/amd/backend/include/roctracer/hsa_prof_str.h +++ b/third_party/amd/backend/include/roctracer/hsa_prof_str.h @@ -22,9 +22,9 @@ /* HSA API tracing primitives 'CoreApi', header 'hsa.h', 125 funcs - 'AmdExt', header 'hsa_ext_amd.h', 68 funcs + 'AmdExt', header 'hsa_ext_amd.h', 70 funcs 'ImageExt', header 'hsa_ext_image.h', 13 funcs - 'AmdExt', header 'hsa_api_trace.h', 68 funcs + 'AmdExt', header 'hsa_api_trace.h', 70 funcs */ #ifndef HSA_PROF_STR_H_ @@ -229,24 +229,26 @@ enum hsa_api_id_t { HSA_API_ID_hsa_amd_vmem_retain_alloc_handle = 190, HSA_API_ID_hsa_amd_vmem_get_alloc_properties_from_handle = 191, HSA_API_ID_hsa_amd_agent_set_async_scratch_limit = 192, + HSA_API_ID_hsa_amd_queue_get_info = 193, + HSA_API_ID_hsa_amd_vmem_address_reserve_align = 194, /* block: ImageExt API */ - HSA_API_ID_hsa_ext_image_get_capability = 193, - HSA_API_ID_hsa_ext_image_data_get_info = 194, - HSA_API_ID_hsa_ext_image_create = 195, - HSA_API_ID_hsa_ext_image_import = 196, - HSA_API_ID_hsa_ext_image_export = 197, - HSA_API_ID_hsa_ext_image_copy = 198, - HSA_API_ID_hsa_ext_image_clear = 199, - HSA_API_ID_hsa_ext_image_destroy = 200, - HSA_API_ID_hsa_ext_sampler_create = 201, - HSA_API_ID_hsa_ext_sampler_destroy = 202, - HSA_API_ID_hsa_ext_image_get_capability_with_layout = 203, - HSA_API_ID_hsa_ext_image_data_get_info_with_layout = 204, - HSA_API_ID_hsa_ext_image_create_with_layout = 205, + HSA_API_ID_hsa_ext_image_get_capability = 195, + HSA_API_ID_hsa_ext_image_data_get_info = 196, + HSA_API_ID_hsa_ext_image_create = 197, + HSA_API_ID_hsa_ext_image_import = 198, + HSA_API_ID_hsa_ext_image_export = 199, + HSA_API_ID_hsa_ext_image_copy = 200, + HSA_API_ID_hsa_ext_image_clear = 201, + HSA_API_ID_hsa_ext_image_destroy = 202, + HSA_API_ID_hsa_ext_sampler_create = 203, + HSA_API_ID_hsa_ext_sampler_destroy = 204, + HSA_API_ID_hsa_ext_image_get_capability_with_layout = 205, + HSA_API_ID_hsa_ext_image_data_get_info_with_layout = 206, + HSA_API_ID_hsa_ext_image_create_with_layout = 207, - HSA_API_ID_DISPATCH = 206, - HSA_API_ID_NUMBER = 207, + HSA_API_ID_DISPATCH = 208, + HSA_API_ID_NUMBER = 209, }; /* Declarations of APIs intended for use only by tools. */ typedef void (*hsa_amd_queue_intercept_packet_writer)(const void*, uint64_t); @@ -261,9 +263,9 @@ struct hsa_api_data_t { uint32_t phase; union { uint64_t uint64_t_retval; - uint32_t uint32_t_retval; - hsa_signal_value_t hsa_signal_value_t_retval; hsa_status_t hsa_status_t_retval; + hsa_signal_value_t hsa_signal_value_t_retval; + uint32_t uint32_t_retval; }; union { /* block: CoreApi API */ @@ -1236,6 +1238,18 @@ struct hsa_api_data_t { hsa_agent_t agent; size_t threshold; } hsa_amd_agent_set_async_scratch_limit; + struct { + hsa_queue_t* queue; + hsa_queue_info_attribute_t attribute; + void* value; + } hsa_amd_queue_get_info; + struct { + void** va; + size_t size; + uint64_t address; + uint64_t alignment; + uint64_t flags; + } hsa_amd_vmem_address_reserve_align; /* block: ImageExt API */ struct { @@ -2888,6 +2902,24 @@ inline std::ostream& operator<< (std::ostream& out, const hsa_api_data_pair_t& d out << ") = " << api_data.hsa_status_t_retval; break; } + case HSA_API_ID_hsa_amd_queue_get_info: { + out << "hsa_amd_queue_get_info("; + out << api_data.args.hsa_amd_queue_get_info.queue << ", "; + out << api_data.args.hsa_amd_queue_get_info.attribute << ", "; + out << api_data.args.hsa_amd_queue_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_address_reserve_align: { + out << "hsa_amd_vmem_address_reserve_align("; + out << api_data.args.hsa_amd_vmem_address_reserve_align.va << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.size << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.address << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.alignment << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } /* block: ImageExt API */ case HSA_API_ID_hsa_ext_image_get_capability: { diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index a7395f86dc50..486fd60293da 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -21,8 +21,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ -#define TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" + // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on @@ -40,10 +41,4 @@ #define GET_OP_CLASSES #include "amd/include/Dialect/TritonAMDGPU/IR/Ops.h.inc" -namespace mlir { -namespace triton { -namespace amdgpu {} // namespace amdgpu -} // namespace triton -} // namespace mlir - -#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 31a43acd2f89..c0aa08421bdd 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. + }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index d5956cf7a33c..c0c18b07e907 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect { }]; let dependentDialects = []; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 538e31378fe8..b2f857e40a7f 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -31,10 +31,12 @@ include "mlir/IR/EnumAttr.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" + class TT_AMDGPU_Op traits = []> : Op { } @@ -44,6 +46,74 @@ class TT_AMDGPU_Op traits = []> : // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +//===----------------------------------------------------------------------===// +// ExtractSliceOp +//===----------------------------------------------------------------------===// + +def ExtractSliceOp + : TT_AMDGPU_Op<"extract_slice", [Pure]> { + let summary = "extract slice operation"; + let description = [{ + The "extract_slice" operation enables extracting a slice of a tensor in + registers. + + The "extract_slice" operation supports the following arguments: + + * source: the base tensor on which to create a view tensor + * offsets: offsets into the base tensor at which to create the view + + Example 1: + + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> + %1 = ttg.convert_layout %0 : tensor<128x128xf16, #blocked> + -> tensor<128x128xf16, #blocked1> + // create a slice of base tensor %1 with static offsets + %2 = amdgpu.extract_slice %0 [0, 0] : + tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1> + ``` + + Example 1 shows how "extract_slice" operation may be used. In this example a + new slice of 128x32 is created. "extract_slice" works on tensors with layout + where the desired slice has the same layout as the source tensor. + "%0" cannot be sliced directly as the resulting slice cannot have the same + layout as "%0". Therefore it needs to be converted to a layout suitable + for slicing. "#blocked1" layout is appropriate for this as it keeps the + sizePerThread the same thus keeping coalescing properties the same. + In order to utilize all threads in a warp, "threadsPerWarp" is set to + [16,4] for this new layout. This layout conversion carried out before + using "extract_slice" ensures slicing still uses all threads efficiently. The + size of the slice is determined by the result type. + }]; + + let arguments = (ins AnyRankedTensor:$source, + DenseI64ArrayAttr:$static_offsets); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + // Build a ExtractSliceOp with static offsets and the same result type + OpBuilder<(ins "RankedTensorType":$resultType, + "Value":$source, + "ArrayRef": $static_offsets)>, + ]; + + let extraClassDeclaration = [{ + std::array getArrayAttrMaxRanks() { + unsigned rank = getSource().getType().getRank(); + return {rank, rank, rank}; + } + }]; + + let assemblyFormat = [{ + $source $static_offsets attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} + def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { let summary = "A placeholder op for instruction scheduling hints within a basic block"; let description = [{ @@ -57,7 +127,29 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { interleave for better instruction level parallelism. }]; - let assemblyFormat = [{attr-dict}]; + let arguments = (ins + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + BoolAttr:$isBufferLoadsAEnabled, + BoolAttr:$isBufferLoadsBEnabled, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins), [{ + auto ctx = $_state.getContext(); + auto noneType = NoneType::get(ctx); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType); + build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, false, false, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; } // diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h index ac37aab817fa..0c60759a8cb7 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h @@ -20,8 +20,8 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_ -#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_GCN_FORMAT_H_ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ #include "mlir/IR/Value.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -31,10 +31,13 @@ #include namespace mlir { + class ConversionPatternRewriter; class Location; -namespace triton { +} // namespace mlir + +namespace mlir::triton { using llvm::StringRef; class GCNInstr; @@ -397,7 +400,6 @@ struct GCNMemInstr : public GCNInstrBase { } }; -} // namespace triton -} // namespace mlir +} // namespace mlir::triton -#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index 67ff40d5b9bc..2043e124beee 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -1,5 +1,5 @@ -#ifndef TRITONAMDGPU_CONVERSION_PASSES_H -#define TRITONAMDGPU_CONVERSION_PASSES_H +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -13,12 +13,16 @@ namespace mlir { class ModuleOp; template class OperationPass; -namespace triton { +} // namespace mlir + +namespace mlir::triton { #define GEN_PASS_DECL #include "TritonAMDGPUToLLVM/Passes.h.inc" -namespace AMD { +} // namespace mlir::triton + +namespace mlir::triton::AMD { std::unique_ptr> createDecomposeUnsupportedConversionsPass(StringRef targetArch); @@ -29,21 +33,24 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch); /// @return created pass std::unique_ptr> createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0); -} // namespace AMD +} // namespace mlir::triton::AMD + +namespace mlir::triton { std::unique_ptr> createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); -std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); std::unique_ptr> -createInsertInstructionSchedHintsPass(); +createConvertBuiltinFuncToLLVMPass(bool ftz); std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPUInsertInstructionSchedHintsPass(); +std::unique_ptr> +createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch, + int32_t numStages, + StringRef variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace triton - -} // namespace mlir +} // namespace mlir::triton -#endif +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index ccb2b1898f42..3bcdc77022bd 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -34,7 +34,6 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod "mlir::gpu::GPUDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect", - "mlir::tensor::TensorDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", "mlir::ROCDL::ROCDLDialect"]; @@ -49,27 +48,38 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Builtin Func to LLVM"; - let constructor = "mlir::triton::createConvertBuiltinFuncToLLVMPass()"; + let constructor = "mlir::triton::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)"; let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let options = [ + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, + ]; } -def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; - let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; } -def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ - Option<"variant", "variant", "std::string", /*default*/"\"default\"", + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942">, + Option<"numStages", "num_stages", "int32_t", /*default*/"2", + "number of pipeline stages">, + Option<"variant", "variant", "std::string", /*default*/"\"none\"", "instruction scheduling variant">, ]; } diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h new file mode 100644 index 000000000000..cd9407ed2b34 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -0,0 +1,14 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index a49e442d3984..6cf5548053b5 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H -#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ #include "llvm/ADT/StringRef.h" @@ -19,6 +19,17 @@ enum class ISAFamily { // Deduces the corresponding ISA family for the given target gfx |arch|. ISAFamily deduceISAFamily(llvm::StringRef arch); +// Here is a partial definition of DppCtrl enums. For the complete definition, +// please check: +// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 +enum class DppCtrl : uint32_t { + QUAD_PERM_FIRST = 0, + ROW_SHL0 = 0x100, + ROW_SHR0 = 0x110, + BCAST15 = 0x142, + BCAST31 = 0x143 +}; + } // namespace mlir::triton::AMD -#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ diff --git a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h index 121bb617265f..57c8b6a58724 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h @@ -1,5 +1,5 @@ -#ifndef TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_MFMAGROUP_H_ -#define TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_MFMAGROUP_H_ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "llvm/ADT/DenseMap.h" @@ -91,4 +91,4 @@ class MfmaInsn { }; } // namespace mlir -#endif // TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_MFMAGROUP_H_ +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 841137887ba0..630a1e903562 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -1,12 +1,14 @@ -#ifndef TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_PASSES_H_ -#define TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_PASSES_H_ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); +std::unique_ptr createTritonAMDGPUStreamPipelinePass(int numStages = 2, + int prefetch = 0); std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), @@ -23,9 +25,11 @@ std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); +std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "TritonAMDGPUTransforms/Passes.h.inc" } // namespace mlir -#endif +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index d59935e796fa..6bee6da5fb45 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -3,7 +3,7 @@ include "mlir/Pass/PassBase.td" -def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { +def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ @@ -11,14 +11,17 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir tile }]; - let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; + let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; - let dependentDialects = []; + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"numStages", "num_stages", "int32_t", /*default*/"2", - "Number of Pipeline stages"> + "Number of Pipeline stages">, + Option<"prefetch", "prefetch", + "int32_t", /*default*/"0", + "Enable prefetch from shared memory"> ]; } @@ -111,4 +114,14 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", " let dependentDialects = []; } +def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> { + let summary = "Convert memory operations to buffer operations"; + + let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible"; + + let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()"; + + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; +} + #endif diff --git a/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h b/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h index fbfa235fc6bb..0e8b7a624010 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h @@ -4,8 +4,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ -#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ #include "mlir/Transforms/DialectConversion.h" @@ -35,4 +35,4 @@ class TritonGPUConversionTarget : public ConversionTarget { } // namespace mlir -#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/amd/language/hip/libdevice.py b/third_party/amd/language/hip/libdevice.py index 6b40a40c9cd7..a69d4406cc12 100644 --- a/third_party/amd/language/hip/libdevice.py +++ b/third_party/amd/language/hip/libdevice.py @@ -66,6 +66,13 @@ def exp(arg0, _builder=None): }, is_pure=True, _builder=_builder) +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + @core.extern def fast_dividef(arg0, arg1, _builder=None): return core.extern_elementwise("", "", [arg0, arg1], { diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index a82a77e9f57e..0e2a9304ebfe 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -24,6 +24,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" @@ -45,5 +48,87 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::amdgpu { + +LogicalResult ExtractSliceOp::verify() { + auto srcTy = getSource().getType(); + auto srcLayout = srcTy.getEncoding(); + auto srcElementType = getElementTypeOrSelf(srcTy); + auto resultTy = getResult().getType(); + auto resultLayout = resultTy.getEncoding(); + auto resultElementType = getElementTypeOrSelf(resultTy); + + if (srcElementType != resultElementType) { + return emitError("result element type must match source element type"); + } + if (srcLayout != resultLayout) { + return emitError("result layout must match source layout"); + } + if (srcTy.getRank() != resultTy.getRank()) { + return emitError("result rank must be equal to source rank"); + } + if (srcTy.getRank() != 2) { + return emitError("currently only 2D tensors are supported"); + } + + auto srcShape = srcTy.getShape(); + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + + // ExtractSlice only supports slicing where offsets and sizes are multiples of + // shapePerCTATile. This condition ensures that slice has the same layout as + // the original tensor. + + auto offsets = getStaticOffsets(); + if (offsets.size() != 2) { + return emitError("invalid offset shape ") << offsets; + } + + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + auto resultDimSize = resultTy.getDimSize(i); + auto srcDimSize = srcTy.getDimSize(i); + if (resultDimSize == 0) { + return emitError("result tensor dimension size zero at dimension ") << i; + } + if (srcDimSize == 0) { + return emitError("source tensor dimension size zero at dimension ") << i; + } + if (resultDimSize > srcDimSize) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + if (offsets[i] + resultDimSize > srcDimSize) { + return emitError("invalid offset ") + << offsets[i] << " at dimension " << i; + } + sizes.push_back(resultDimSize); + } + + if (sizes[0] % shapePerCTATile[0] != 0 || + sizes[1] % shapePerCTATile[1] != 0) { + return emitError() << "sizes [" << sizes + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + if (offsets[0] % shapePerCTATile[0] != 0 || + offsets[1] % shapePerCTATile[1] != 0) { + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + return success(); +} +} // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index e6da8f28777e..4aebabc0a275 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp + ExtractSliceOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp new file mode 100644 index 000000000000..eb2cde1a93f2 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -0,0 +1,143 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +// clang-format off +//===--------------------------------------------------------------------------------===// +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +// # WO # W1 # | # +// # # # | # +// # # # # # | # +// # W2 # W3 # .... | # +// # # # | SkipElems # +// # # # # # | # +// # | # +// # Slice | # +// # . / \ | # +// # . / \ | # +// # . / \| # +// # # # # # # # +// # # W0 # W1 # # +// # # # # # +// # # # # # # tensorStride # +// # # W2 # W3 # --------------------------------# +// # # # # # +// # # # # # # # +// # tensorStride # W0 # W1 # # +// # ---------------------------------- # # # # +// # # # # # # # +// # # W2 # W3 # # +// # # # # # +// # # # # # # ---> lastIdx # +// # . # +// # . # +// # . # +// # # +// # # +// # # +// # # +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +//===--------------------------------------------------------------------------------===// +// clang-format on + +namespace { +struct ExtractSliceOpConversion + : public ConvertOpToLLVMPattern { + explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) { + } + + LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto srcTy = cast(op.getSource().getType()); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultTy = cast(op.getType()); + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); + auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); + auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); + auto totalSizePerThread = product(sizePerThread); + auto order = triton::gpu::getOrder(srcLayout); + + // Calculate valid total number of workers in each dimension + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + + // Rank == 2 checked in the verifier + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + sizes.push_back(resultTy.getDimSize(i)); + } + + auto offsets = op.getStaticOffsets(); + + // Calculate offsets and sizes in terms of CTA units. + std::array CTAOffsets{offsets[0] / shapePerCTATile[0], + offsets[1] / shapePerCTATile[1]}; + std::array CTASizes{sizes[0] / shapePerCTATile[0], + sizes[1] / shapePerCTATile[1]}; + std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], + srcShape[1] / shapePerCTATile[1]}; + + // The diagram above illustrates the graphical representation of the + // skipElems, tensorStride, and lastIdx variables. + auto skipElems = CTAOffsets[order[1]] * + (elemsPerThread[order[0]] * sizePerThread[order[1]]) + + CTAOffsets[order[0]] * totalSizePerThread; + auto tensorStride = + (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread; + auto lastIdx = + (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * + elemsPerThread[order[0]] * sizePerThread[order[1]] + + (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread; + + assert(lastIdx <= vals.size()); + + SmallVector resultVals; + for (int i = skipElems; i < lastIdx; i += tensorStride) { + for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) { + assert(i < lastIdx); + resultVals.push_back(vals[i]); + } + } + Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + + rewriter.replaceOp(op, ret); + return success(); + } + + LogicalResult + matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = op.getSource().getType(); + if (isa( + op.getSource().getType().getEncoding())) { + return processLayout(op, adaptor, rewriter); + } + return failure(); + } +}; +} // namespace + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index 5d172fea9cfa..c7c2f56d31de 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -1,9 +1,10 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" namespace mlir::triton::AMD { void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - // TODO: Insert TrtionAMDGPU dialect patterns. + populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index be009af4d1f0..37bdb8fe99ca 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -133,7 +133,7 @@ Type BufferEmitter::getBufferOpType(Type type) { // will be bitcast-able to the original type. So if the types // ended up different, we simply have to emit a `bitcastOp` to convert Type bufferType = type; - if (bufferVecSize != vecSize) + if (bufferVecSize != vecSize || bufferElementType != elementType) bufferType = VectorType::get(bufferVecSize, bufferElementType); if (bufferVecSize == 1) bufferType = getElementTypeOrSelf(bufferType); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h index ad6d46ff78a0..0bef7a644729 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H -#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ #include "TargetInfo.h" #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" @@ -90,4 +90,4 @@ struct BufferEmitter { } // namespace mlir::LLVM::AMD -#endif // TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 18364b67e1bd..de92fa01441a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -6,26 +6,23 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -namespace mlir { -namespace triton { +namespace mlir::triton { #define GEN_PASS_DEF_CONVERTBUILTINFUNCTOLLVM #include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir +} // namespace mlir::triton using namespace mlir; namespace { -class CallOpConversion : public mlir::RewritePattern { +class CallOpConversion : public OpRewritePattern { public: - CallOpConversion(mlir::MLIRContext *context) - : mlir::RewritePattern(LLVM::CallOp::getOperationName(), 1, context) {} + CallOpConversion(mlir::MLIRContext *context, bool ftz) + : OpRewritePattern(context, 1), ftz(ftz) {} LogicalResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(LLVM::CallOp callOp, mlir::PatternRewriter &rewriter) const override { - auto callOp = cast(op); if (isPredicatedLoad(callOp)) { return convertPredicatedLoad(callOp, rewriter); } else if (isPredicatedStore(callOp)) { @@ -102,12 +99,10 @@ class CallOpConversion : public mlir::RewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, afterStore); rewriter.setInsertionPointToStart(trueBlock); - /* - | vialatile | non-tmp | gcn instr gfx94 - LLVM::StoreOp | 0 | 0 | (cg) global store - | 0 | 1 | (cs) global store nt - | 1 | 0/1 | (wt) global store sc0 sc1 - */ + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::StoreOp | 0 | 0 | (cg) global store + // | 0 | 1 | (cs) global store nt + // | 1 | 0/1 | (wt) global store sc0 sc1 bool vialatileFlag = isPredicatedStoreWT(callOp); bool nonTmpFlag = isPredicatedStoreCS(callOp); auto storeOp = rewriter.create( @@ -139,12 +134,10 @@ class CallOpConversion : public mlir::RewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, falseBlock); rewriter.setInsertionPointToStart(trueBlock); - /* - | vialatile | non-tmp | gcn instr gfx94 - LLVM::LoadOp | 0 | 0 | (ca) global load - | 0/1 | 1 | (cg) global load nt - | 1 | 0 | (cv) flat load sc0 sc1 - */ + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::LoadOp | 0 | 0 | (ca) global load + // | 0/1 | 1 | (cg) global load nt + // | 1 | 0 | (cv) flat load sc0 sc1 bool vialatileFlag = isPredicatedLoadCV(callOp); bool nonTmpFlag = isPredicatedLoadCG(callOp); auto loadOp = rewriter.create( @@ -195,6 +188,18 @@ class CallOpConversion : public mlir::RewritePattern { LLVM::FastmathFlagsAttr defaultFlags{}; replacementOp = rewriter.create( loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + } else if (calleeName == "__triton_hip_fast_expf") { + assert(operands.size() == 1); + assert(operands[0].getType().getIntOrFloatBitWidth() == 32); + const double log2e = 1.4426950408889634; + LLVM::FastmathFlagsAttr defaultFlags{}; + auto mulOp = rewriter.create( + loc, rewriter.getF32Type(), operands[0], + LLVM::createConstantF32(loc, rewriter, log2e), defaultFlags); + const char *intrinsic = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + + replacementOp = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsic, returnType, mulOp->getResult(0)); } if (replacementOp) { @@ -204,11 +209,16 @@ class CallOpConversion : public mlir::RewritePattern { return mlir::failure(); } + +private: + bool ftz; }; struct ConvertBuiltinFuncToLLVM : public triton::impl::ConvertBuiltinFuncToLLVMBase< ConvertBuiltinFuncToLLVM> { + explicit ConvertBuiltinFuncToLLVM(bool ftz) { this->ftz = ftz; } + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -217,7 +227,7 @@ struct ConvertBuiltinFuncToLLVM config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context, this->ftz); if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config) .failed()) { @@ -226,14 +236,13 @@ struct ConvertBuiltinFuncToLLVM } }; -} // anonymous namespace +} // namespace -namespace mlir { -namespace triton { +namespace mlir::triton { -std::unique_ptr> createConvertBuiltinFuncToLLVMPass() { - return std::make_unique(); +std::unique_ptr> +createConvertBuiltinFuncToLLVMPass(bool ftz) { + return std::make_unique(ftz); } -} // namespace triton -} // namespace mlir +} // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index b6a514f450cc..abd86dc03301 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp SchedInstructions.cpp + UpcastMXFPToLLVM.cpp DEPENDS TritonAMDGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 953b01dab08a..208483beb8fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -9,6 +9,7 @@ using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandMFMA { @@ -50,7 +51,11 @@ struct LocalLoadOpConversion } private: - // shared -> dot_operand if the result layout is mfma + /// Lower ttg.local_load in dot operand layout if the operand parent layout is + /// MFMA or WMMA. + /// + /// \returns value with packed loaded values or empty value if this local_load + /// is not supproted. Value lowerSharedToDotOperandMMA( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, const LLVMTypeConverter *typeConverter, @@ -104,61 +109,13 @@ struct LocalLoadOpConversion isOuter = K == 1; Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, dotOperandLayout, isOuter); + if (!res) + return failure(); rewriter.replaceOp(op, res); return success(); } }; -struct ConvertLayoutOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = op.getSrc(); - Value dst = op.getResult(); - auto srcTy = cast(src.getType()); - auto dstTy = cast(dst.getType()); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (isa(srcLayout) && - isa(dstLayout)) { - return lowerMfmaToDotOperand(op, adaptor, rewriter); - } - return failure(); - } - -private: - LogicalResult - lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - if (isMfmaToDotShortcut(srcTy, dstTy)) { - // vecSize is an number of sequential elements stored by one thread - // - For MFMA encoding (encoding of the result tensor of dot - // operation) it is 4 - // - For MFMA operand encoding it is - // dotOperandEncoding::kWidth, - // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack - // = 1) - // - // For cases where these two values are equal MFMA and MFMA operand - // layouts are the same. - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - Value view = - packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } - return failure(); - } -}; } // namespace namespace mlir::triton::AMD { @@ -166,7 +123,6 @@ void populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 03b7c56b7e6b..46d60e2c5da3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -82,17 +82,15 @@ bool isKMajor(llvm::ArrayRef order, int opIdx) { return order[0] == kdim; } -/** - * @brief checks that swizzle pattern fits into one warp block - * and block size is a multiple of swizzle size along non-K dimension - * - * @param sharedLayout - * @param opIdx operand id 0 or 1 - * @param reps number of repetitions: [non-k, k] or [batch, non-k, k] - * @param elemsPerInstr one instruction size - * @param warpsPerBlockNonK number of warps along non-k Dim - * @return bool - */ +/// Checks that swizzle pattern fits into one warp block +/// and block size is a multiple of swizzle size along non-K dimension +/// +/// \param sharedLayout +/// \param opIdx operand id 0 or 1 +/// \param reps number of repetitions: [non-k, k] or [batch, non-k, k] +/// \param elemsPerInstr one instruction size +/// \param warpsPerBlockNonK number of warps along non-k Dim +/// \returns bool bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout, int opIdx, const ArrayRef reps, const ArrayRef elemsPerInstr, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h index 0db193e1c102..1b0e3b2df003 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_SHARED_TO_DOT_OPERAND_MATRIXCORE_H -#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_SHARED_TO_DOT_OPERAND_MATRIXCORE_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ #include "Utility.h" @@ -13,18 +13,16 @@ Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc, bool isSwizzled(gpu::SharedEncodingAttr layout); -/** - * @brief swizzling tensor element indexes according pattern encoded in - * SharedEncodingAttr - * - * @param rewriter - * @param loc - * @param row row of target tensor element related to the start of smemObj - * @param col col of target tensor element related to the start of smemObj - * @param smemObj shared memory object, contains info about tensor in LDS - * @param attr layout attribute, contains swizzling info - * @return swizzled row, col indexes in tensor notation - */ +/// Swizzling tensor element indexes according pattern encoded in +/// SharedEncodingAttr +/// +/// \param rewriter +/// \param loc +/// \param row row of target tensor element related to the start of smemObj +/// \param col col of target tensor element related to the start of smemObj +/// \param smemObj shared memory object, contains info about tensor in LDS +/// \param attr layout attribute, contains swizzling info +/// \returns swizzled row, col indexes in tensor notation std::pair swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, Value col, SharedMemoryObject smemObj, @@ -61,4 +59,4 @@ llvm::SmallVector computeOffsetsBType( } // namespace mlir::triton::AMD -#endif +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index c8df2ac99355..e55d87cb9434 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -32,43 +33,41 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandMFMA { -/** - * @brief This function maps particular load of mfma dot operand to element - * indexes(row, col) - * - * Whole tensor is broken into "blocks" of warps along "non-K" axis. - * One block could be processed by multiple warps. - * One warp works on a piece of tensor size elemsPerInstr[0] x K. - * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x - * elemsPerInstr[1]. - * - * Total offset of element is a sum of following values: - * 1. Offset of warp-block in tensor - * 2. Offset of warp inside one warp-block - * 3. Offset of tile in one warp - * 4. Offset of one lane data in a tile - * 5. Offset of particular element of tensor processed by one lane - * - * This function computes these offsets for axies independently - * Note that this function returns the offsets of elements in the first - * warp-block. The offsets of elements in later warp-blocks can be computed - * by adding a constant stride to the xor-ed offsets of elements in the - * first warp-block. - * - * @param rewriter - * @param loc - * @param elemsPerInstr operand tile shape consumed by one MFMA instruction - * @param warpId id component of 2d warp grid along non-K axis - * @param laneId lane id in warp [0..63] - * @param numOfElems number of elements accessed by thread per repetition - * @param reps number of instructions repetition to fully cover dot operand - * @param smemStrides strides in LDS tensor - * @param loadVecSize number of elements loaded by one operation - * @param iNonKDim non-K dimension size of one MFMA instruction - * @param iKDim K dimension size of one MFMA instruction - * @return vector (i-th element corresponds to i-th load instruction) of - * 2-element vectors(tensor row and col). - */ +/// This function maps particular load of mfma dot operand to element +/// indexes(row, col) +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp-block in tensor +/// 2. Offset of warp inside one warp-block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axies independently +/// Note that this function returns the offsets of elements in the first +/// warp-block. The offsets of elements in later warp-blocks can be computed +/// by adding a constant stride to the xor-ed offsets of elements in the +/// first warp-block. +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one MFMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension size of one MFMA instruction +/// \param iKDim K dimension size of one MFMA instruction +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). llvm::SmallVector> computeTensorElemMappingInBlock( ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, @@ -92,9 +91,9 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value laneVOffset = urem(laneId, nonKDim); Value laneHOffset; - if (iNonKDim == 32) + if (iNonKDim == 32) { laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); - else { + } else { // In this configuration warp contains 16 copies of same data if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { laneHOffset = i32_val(0); @@ -126,17 +125,18 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { return srcEncoding.getMaxPhase() > 1; } -// Computes offsets for operand B or transposed operand A -// @param rewriter -// @param loc -// @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA -// instruction -// @param warpId warp id for the "non K" axis -// @param laneId lane id in warp [0..63] -// @param warpsPerBlock number of warps per horizontal axis -// @param numOfElems number of elements accessed by threads per repetition -// @param reps number of instructions repretition to fully cover dot operand -// @param cSwizzleOffset +/// Computes offsets for operand B or transposed operand A +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA +/// instruction +/// \param warpId warp id for the "non K" axis +/// \param laneId lane id in warp [0..63] +/// \param warpsPerBlock number of warps per horizontal axis +/// \param numOfElems number of elements accessed by threads per repetition +/// \param reps number of instructions repretition to fully cover dot operand +/// \param cSwizzleOffset llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, @@ -197,7 +197,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread) { assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); - auto aTensorTy = cast(tensor.getType()); + auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); auto rank = shape.size(); int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; @@ -231,6 +231,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, mfmaInstrK = elemsPerInstr[kDimIdx]; } + if (mfmaInstrNonK > shape[nonKDimIdx] || mfmaInstrK > shape[kDimIdx]) { + // This pattern does not support cases tensor shape is smaller than + // one instruction size, it will be processed by LinearLayout converter + return Value(); + } + auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx); auto numRepNonK = numReps[nonKDimIdx]; auto numRepK = numReps[kDimIdx]; @@ -330,6 +336,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int elemsPerLoad = numOfElems / loadsPerThread; assert(numOfElems % loadsPerThread == 0); + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -340,7 +347,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; @@ -357,6 +363,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = mfmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index b60c86e1a3a5..8d5bc669e1eb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -32,39 +33,37 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandWMMA { -/** - * @brief Following functions maps particular load of wmma dot operand to - * element indexes(row, col). For each WMMA generation separate function is - * used. - * - * Whole tensor is broken into "blocks" of warps along "non-K" axis. - * One block could be processed by multiple warps. - * One warp works on a piece of tensor size elemsPerInstr[0] x K. - * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x - * elemsPerInstr[1]. - * - * Total offset of element is a sum of following values: - * 1. Offset of warp block in tensor - * 2. Offset of warp inside one warp block - * 3. Offset of tile in one warp - * 4. Offset of one lane data in a tile - * 5. Offset of particular element of tensor processed by one lane - * - * This function computes these offsets for axes independently - * - * @param rewriter - * @param loc - * @param elemsPerInstr operand tile shape consumed by one WMMA instruction - * @param warpId id component of 2d warp grid along non-K axis - * @param laneId lane id in warp [0..63] - * @param numOfElems number of elements accessed by thread per repetition - * @param reps number of instructions repetition to fully cover dot operand - * @param smemStrides strides in LDS tensor - * @param loadVecSize number of elements loaded by one operation - * @param iNonKDim non-K dimension of dot operand - * @return vector (i-th element corresponds to i-th load instruction) of - * 2-element vectors(tensor row and col). - */ +/// Following functions maps particular load of wmma dot operand to +/// element indexes(row, col). For each WMMA generation separate function is +/// used. +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp block in tensor +/// 2. Offset of warp inside one warp block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axes independently +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one WMMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension of dot operand +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). llvm::SmallVector> computeTensorElemMappingInBlockWmma1( ConversionPatternRewriter &rewriter, Location loc, @@ -151,7 +150,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, assert(wmmaLayout.getMNKDimPerInstr()[nonKDimIdx] == 16); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); - auto aTensorTy = cast(tensor.getType()); + auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); auto sharedLayout = cast(aTensorTy.getEncoding()); auto order = sharedLayout.getOrder(); @@ -212,6 +211,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -221,7 +221,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); Value valVec = undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; loadOffset = add(loadOffset, batchOffset); @@ -237,6 +236,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = wmmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index cece47227ea0..bbacde54b041 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -11,12 +11,10 @@ #include using namespace mlir; -namespace mlir { -namespace triton { +namespace mlir::triton { #define GEN_PASS_DEF_DECOMPOSEUNSUPPORTEDAMDCONVERSIONS #include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir +} // namespace mlir::triton namespace { @@ -38,12 +36,12 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMfmaToDotShortcut); + auto isShortcut = + mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); + + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); - /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` - /* -------------------------------- */ mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { OpBuilder builder(cvtOp); auto srcType = cvtOp.getSrc().getType(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 204d54894d3b..54e3c6ac8527 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "Utility.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" using namespace mlir; @@ -261,16 +261,22 @@ struct DotOpMFMAConversionHelper { Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + Type elemtTy = elemTyA; + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(), + maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(), + elemtTy); + rewriter.replaceOp(op, res); return success(); } - /** - * @brief extract vector from rawElems based on kWidth and kBase - * rawElems is a vector of kWidth elements. We need to prepare vector(s) of - * kBase elements for each mfma instruction - */ + /// Extract vector from rawElems based on kWidth and kBase + /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of + /// kBase elements for each mfma instruction SmallVector extractOperands(Value rawElems, int kWidth, int kBase, Type type) const { int kpack = kWidth / kBase; @@ -286,8 +292,9 @@ struct DotOpMFMAConversionHelper { // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type auto cast = bitcast(val, i16_ty); vec = insert_element(vecTy, vec, cast, i32_val(elemId)); - } else + } else { vec = insert_element(vecTy, vec, val, i32_val(elemId)); + } } if (type.getIntOrFloatBitWidth() == 8) { if (4 == kBase) @@ -295,16 +302,15 @@ struct DotOpMFMAConversionHelper { results.push_back(bitcast(vec, i32_ty)); if (8 == kBase) results.push_back(bitcast(vec, i64_ty)); - } else + } else { results.push_back(vec); + } } return results; } - /** - * @brief Converts dot operand structure to value table and converts types - * appropriate for mfma instructions - */ + /// Converts dot operand structure to value table and converts types + /// appropriate for mfma instructions SmallVector getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type) const { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 5a003f768833..0042cf89e93b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -22,6 +22,7 @@ */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -325,6 +326,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, Type structTy = LLVM::LLVMStructType::getLiteral( wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 47d5fbb3550d..716a93865ddd 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -21,8 +21,14 @@ typedef std::function(Location, ConversionPatternRewriter &, ConverterT; namespace { -// ROCM utility functions for data type conversion -/* ----- FP8E5M2 ------ */ +//===-------------------------------------------===// +/// ROCM utility functions for data type conversion +//===-------------------------------------------===// + +//===----------------===// +/// FP8E5M2 +//===----------------===// + // This data-type is the standard FP8E5M2 format // NVIDIA GPU supports it natively but we don't have hardware native // support on MI300. @@ -221,6 +227,7 @@ Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, assert(v.size() == 2); return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8"); } + // Depend on whether we focus more on performance, we may skip // the processing of submornal values static Value Fp16_to_Fp8E5M2FNUZ_oneValue(Location loc, @@ -537,7 +544,47 @@ static SmallVector Bf16_to_Fp8E5M2(Location loc, extract_element(i8_ty, fp8x4Vec, i32_val(3))}; } -// ROCM type conversion between fp8 and bf16 +//===-----------------------------------------===// +/// ROCM type conversion between fp8 and bf16 +//===-----------------------------------------===// + +// fp8e4m3fn to bf16 +static SmallVector Fp8E4M3FN_to_Bf16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value a0 = undef(fp8x4VecTy); + a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0)); + a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1)); + a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(2)); + a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3)); + a0 = bitcast(a0, i32_ty); + + Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff)); + b0 = lshr(i32_ty, b0, i32_val(4)); + + Value c0 = shl(i32_ty, b0, i32_val(16)); + Value c1 = and_(i32_ty, b0, i32_val(0xFFFF0000)); + c0 = bitcast(c0, f32_ty); + c1 = bitcast(c1, f32_ty); + + Value d0 = fmul(f32_ty, c0, f32_val(0x1p+120)); // bias 2**(127-7) + Value d1 = fmul(f32_ty, c1, f32_val(0x1p+120)); + d0 = bitcast(d0, i32_ty); + d1 = bitcast(d1, i32_ty); + + Value out0 = or_(i32_ty, lshr(i32_ty, d0, i32_val(16)), d1); + Value sign0 = and_(i32_ty, a0, i32_val(0x80008000)); + out0 = or_(i32_ty, out0, sign0); + + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + out0 = bitcast(out0, bf16x2VecTy); + return {extract_element(bf16_ty, out0, i32_val(0)), + extract_element(bf16_ty, out0, i32_val(1))}; +} + +/****************************************************************************/ + // fp8e4m3fnuz to bf16 static SmallVector Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, @@ -880,6 +927,7 @@ struct FpToFpOpConversion // F8 -> BF16 {{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16}, {{F8E5M2FNUZTyID, BF16TyID, undefRounding}, Fp8E5M2FNUZ_to_Bf16}, + {{F8E4M3FNTyID, BF16TyID, undefRounding}, Fp8E4M3FN_to_Bf16}, {{F8E4M3FNUZTyID, BF16TyID, undefRounding}, Fp8E4M3FNUZ_to_Bf16}, // BF16 -> F8 {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE}, Bf16_to_Fp8E5M2}, @@ -887,7 +935,6 @@ struct FpToFpOpConversion Bf16_to_Fp8E5M2FNUZ}, {{BF16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, Bf16_to_Fp8E4M3FNUZ}, - // F32 <-> F8 {{F32TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FNUZ}, @@ -936,9 +983,9 @@ struct FpToFpOpConversion } return outVals; } - size_t numElements = 4; - if (srcElementType.isFloat8E4M3FNUZ() || + if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() || + srcElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E4M3FNUZ() || srcElementType.isFloat8E5M2FNUZ() || dstElementType.isFloat8E5M2FNUZ()) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp index e06fa664e893..b83707ee145f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp @@ -5,8 +5,7 @@ #include "llvm/Support/raw_ostream.h" #include // unify to llvm::raw_string_ostream ? -namespace mlir { -namespace triton { +namespace mlir::triton { GCNInstr::Operand * GCNBuilder::newOperand(mlir::Value value, StringRef constraint, @@ -187,5 +186,4 @@ GCNInstrExecution::getArgList() const { return args; } -} // namespace triton -} // namespace mlir +} // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index f7dc8755faa3..825697e0e911 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,6 +1,7 @@ #include "BufferOpsEmitter.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -39,15 +40,27 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto sizePerThread = triton::gpu::getSizePerThread(layout); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto threadOrder = triton::gpu::getThreadOrder(layout); + SmallVector warpOrder(rank); + if (auto enc = dyn_cast(layout)) { + warpOrder = + triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1); + } else { + warpOrder = triton::gpu::getWarpOrder(layout); + } + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); + // TODO: [DOT LL] + // The delinearize function is not entirely correct for certain layouts, + // such as wgmma. The correct approach is to convert a legacy layout to its + // corresponding linear layout and use the linear layout's + // getFreeVariableMasks to identify redundant elements. SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension if (shape[dim] >= shapePerCTATile[dim]) @@ -165,7 +178,7 @@ struct LoadStoreConversionBase { // Get alignment from the pointer. Since this is a scalar pointer // we should not take the pointer contiguity to consider alignment auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); - auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxMultipleBytes = axisInfo->getDivisibility(0); auto elemNumBits = triton::getPointeeBitWidth(tensorTy); auto elemNumBytes = std::max(elemNumBits / 8, 1); auto align = std::max(maxMultipleBytes / elemNumBytes, 1); @@ -276,6 +289,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; @@ -286,8 +300,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, assert(wordNElems * nWords * numVecs == numElems); Value pred = mask ? maskElems[vecStart] : int_val(1, 1); - auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); - Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); + Value ptr = ptrElems[vecStart]; Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); // If we need to mask the loaded value with other elements @@ -309,6 +322,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -391,6 +407,10 @@ struct BufferLoadOpConversion Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + const int numVecs = numElems / vec; + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -457,7 +477,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, SmallVector> asmArgs; Value elem = valueElems[vecStart]; - Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); + Value ptr = ptrElems[vecStart]; // Create the store val Value storeVal = packElementRangeIntoVector( @@ -694,6 +714,32 @@ struct AtomicCASOpConversion } }; +bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { + return isaFamily == triton::AMD::ISAFamily::CDNA1 || + isaFamily == triton::AMD::ISAFamily::CDNA2 || + isaFamily == triton::AMD::ISAFamily::CDNA3; +} + +Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { + assert(val.getType().isInteger(32)); + auto loc = val.getLoc(); + Value old = i32_val(0); + int rowMask = 0b1111; // enable all rows + int bankMask = 0b1111; // enable all banks + bool boundCtrl = false; + auto dppMovOp = rewriter.create( + loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); + return dppMovOp.getResult(); +} + +Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane +} + +Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane +} + struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -765,10 +811,36 @@ struct AtomicRMWOpConversion // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; + Type packF16Ty = vec_ty(valueElemTy, 2); + + // In the case of unpaired f16 elements utilize dpp instructions to + // accelerate atomics. Here is an algorithm of lowering + // tt::atomicRmwOp(%ptr, %val, %mask): + // 0. Group thread by pairs. Master thread is (tid % 2 == 0); + // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so + // all the masters recieve value from secondary threads; + // 2. Take into account parity in the %mask value, build control flow + // structures according to it; + // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value; + // 4. All the threads send result of generated operation to (tid + 1) thread + // via dppUpdateOp shl, so all secondary thread also recieve their + // result. + // + // This approach enables us to use half the active threads committing atomic + // requests to avoid generating of code providing unified access to f16 + // element and reduce contantion. + bool useDppForPackedF16 = false; // tensor if (tensorTy) { auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); + unsigned availableVecSize = isF16Ty ? 2 : 1; + vec = std::min(vec, availableVecSize); + // Force F16 packing in the case it's not comming in as packed, but the + // ISA can support packed atomic instructions. + useDppForPackedF16 = + supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && + vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD; // mask numElems = tensorTy.getNumElements(); } @@ -776,20 +848,49 @@ struct AtomicRMWOpConversion auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + if (useDppForPackedF16) + mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; + retType = useDppForPackedF16 ? packF16Ty : retType; SmallVector resultVals(elemsPerThread); - const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + Value operand; + if (useDppForPackedF16) { + // Move %val to left neighbour to proceed packed atomic further. + Value packedVal = null(packF16Ty); + packedVal = + insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); + // Pack to i32 type to simplify transaction + packedVal = bitcast(packedVal, i32_ty); + Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); + // Unpack results back + Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); + operand = undef(packF16Ty); + operand = + insert_element(packF16Ty, operand, valElements[i], i32_val(0)); + operand = insert_element( + packF16Ty, operand, + extract_element(valueElemTy, unpackedDppRes, i32_val(0)), + i32_val(1)); + } else if (vec == 1) { + operand = valElements[i]; + } else { + operand = undef(vecTy); + for (size_t ii = 0; ii < vec; ++ii) + operand = + insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); + } + Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); @@ -806,25 +907,11 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. - Value atom = rewriter - .create( - loc, *maybeKind, rmwPtr, valElements[i], - atomicMemOrdering, StringRef("agent")) - .getResult(); - - // NV for the f16v2 case generates one packed instruction. We have to - // create two separate instructions since LLVM::AtomicRMWOp doesn't - // support this. Can be optimized out with rocdl.raw.buffer.atomic. - if (f16v2) { - Value atom2 = - rewriter - .create( - loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], - atomicMemOrdering, StringRef("agent")) - .getResult(); - auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); - atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); - } + Value atom = + rewriter + .create(loc, *maybeKind, rmwPtr, operand, + atomicMemOrdering, StringRef("agent")) + .getResult(); if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = @@ -837,10 +924,25 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); + if (useDppForPackedF16) { + // Return packed to i32 result after atomic operation back from master + // lane. + auto packedRet = bitcast(retVal, i32_ty); + Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); + // Unpack results back + Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); + retVal = insert_element( + packF16Ty, retVal, + extract_element(valueElemTy, unpackedDppRes, i32_val(1)), + i32_val(1)); + resultVals[i] = + extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); + } else { + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); + } } } else { if (!atomicNeedsSharedMemory(op.getResult())) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp index db3223f119da..4a0a7fed22b0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -96,7 +96,8 @@ class OptimizeAMDLDSUsage auto dstEnc = dstType.getEncoding(); auto ctx = srcEnc.getContext(); - auto rank = srcType.getShape().size(); + auto rank = srcType.getRank(); + unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); auto warpSize = triton::gpu::getWarpSize(srcEnc); @@ -109,11 +110,20 @@ class OptimizeAMDLDSUsage // Create a list of temporary layouts SmallVector elemsPerThread(rank, 1); SmallVector threadsPerWarp(rank, 1); - threadsPerWarp[rank - 1] = warpSize / 8; - threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + + // Special case for rank == 1 + if (rank == 1) { + threadsPerWarp[0] = warpSize; + } else { + assert(rank > 1); + threadsPerWarp[rank - 1] = warpSize / 8; + threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + } + auto layoutCTA = triton::gpu::getCTALayout(srcEnc); auto order = triton::gpu::getOrder(srcEnc); SmallVector dummyWarpsPerCTA(rank, 1); + auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get( ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order, layoutCTA); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp index dfa8e06e247c..fb0bfb656ef4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -68,9 +68,13 @@ Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA) { ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA), src.getKWidth()); } - if (auto src = dyn_cast(layout)) + if (auto src = dyn_cast(layout)) { + // TODO: think of a way to construct slice layouts based on warpsPerCTA + // argument + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); return triton::gpu::SliceEncodingAttr::get( - ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA)); + ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA)); + } assert("Encountered unsupported layout"); return Attribute(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h index 2bd2a977fa7d..6b902b303c81 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H -#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -11,27 +11,24 @@ int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op); std::vector> factorizePowerOf2(int n, int rank); -/** - * @brief Copy given layout with different warpsPerCTA parameter - * @param layout original layout - * @param warpsPerCTA new warpsPerCTA - * @return create layout - */ +/// Copy given layout with different warpsPerCTA parameter +/// +/// \param layout original layout +/// \param warpsPerCTA new warpsPerCTA +/// \returns create layout Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA); -/** - * Creates two chained convert layout operations - * - * %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation - * -> - * %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first - * %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second - * - * @param builder - * @param cvtOp original operation - * @param tmpLayout - * @return pair of created operations - */ +/// Creates two chained convert layout operations +/// +/// %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation +/// -> +/// %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first +/// %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second +/// +/// \param builder +/// \param cvtOp original operation +/// \param tmpLayout +/// \returns pair of created operations std::pair createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, Attribute tmpLayout); @@ -47,4 +44,4 @@ estimateResourcesForReplacement(OpBuilder builder, } // namespace mlir::triton::AMD -#endif // TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 764f31a610e1..b217fc495643 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONAMDPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H -#define TRITON_CONVERSION_TRITONAMDPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ #include "TargetInfo.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + } // namespace mlir::triton::AMD -#endif +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 9bed87961966..d93f2ca6c6ec 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,108 +1,430 @@ +#include "SchedInstructions.h" #include "TritonAMDGPUToLLVM/Passes.h" - +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" -#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { -#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS -#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton +#undef DEBUG_TYPE +#define DEBUG_TYPE "lower-insert-instruction-sched-hints" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + using namespace mlir; -namespace { +// TODO: The following passes/algorithms are applicable only for a single +// `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block. +// Note, we need to relax this assumption in the future and extend the current +// implementation. -// The bitmask that encodes kinds of the instructions from AMD ISA. -// The bitmask is used for providing instruction scheduling hints. -enum InstructionKindMask { - NONE = 0x0000000, - ALL_ALU = 0x00000001, - VALU = 0x00000002, - SALU = 0x00000004, - MFMA = 0x00000008, - ALL_VMEM = 0x00000010, - VMEM_READ = 0x00000020, - VMEM_WRITE = 0x00000040, - ALL_DS = 0x00000080, - DS_READ = 0x00000100, - DS_WRITE = 0x00000200 -}; +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + const bool isBufferLoadOp = + std::is_same_v; + if (opIdxAttr.getValue() == 0) { + schedHint.setNumGlobalLoadsAAttr(counterAttr); + schedHint.setIsBufferLoadsAEnabled(isBufferLoadOp); + } else { + schedHint.setNumGlobalLoadsBAttr(counterAttr); + schedHint.setIsBufferLoadsBEnabled(isBufferLoadOp); + } + } + }); +} +template void setNumGeneratedGlobalLoads(triton::amdgpu::BufferLoadOp op, + size_t globalLoadsCount, Type type); +template void setNumGeneratedGlobalLoads(triton::LoadOp op, + size_t globalLoadsCount, Type type); + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + } + }); +} + +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) { + triton::DotOp dotOp = nullptr; + size_t dotCounter = 0; + forOp->walk( + [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); + + return (dotCounter == 1) ? dotOp : nullptr; +} +} // namespace mlir::triton + +namespace { // Create an intrinsic to control how different instruction kinds should // interleave for better ILP. void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, - InstructionKindMask maskValue, int sizeValue, - int groupIdValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.group.barrier"; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - Value size = - LLVM::createConstantI32(loc, rewriter, static_cast(sizeValue)); - Value groupId = LLVM::createConstantI32(loc, rewriter, - static_cast(groupIdValue)); - - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{}, - ValueRange{mask, size, groupId}); + mlir::amdgpu::sched_barrier_opt_enum maskValue, + int sizeValue, int groupIdValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + IntegerAttr size = + rewriter.getI32IntegerAttr(static_cast(sizeValue)); + IntegerAttr groupId = + rewriter.getI32IntegerAttr(static_cast(groupIdValue)); + rewriter.create(loc, mask, size, groupId); } // Insert intrinsic that controls the types of instructions that may be -// allowed to cross the intrinsic during instruction scheduling +// allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, - int64_t maskValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.barrier"; - LLVM::FastmathFlagsAttr defaultFlags{}; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{mask}); + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); } // Insert an experimental intrinsic for instruction group level parallelism. // The intrinsic takes a value that specifies the strategy. Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.iglp.opt"; - LLVM::FastmathFlagsAttr defaultFlags{}; - Value iglpValue = - LLVM::createConstantI32(loc, rewriter, static_cast(value)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{iglpValue}); + IntegerAttr iglpValue = + rewriter.getI32IntegerAttr(static_cast(value)); + return rewriter.create(loc, iglpValue); +} + +// The following structs represent in-source database regarding a target +// machine. It provides instructions execution and issue cycles needed for +// scheduling. +struct MachineDescr { + virtual ~MachineDescr() = default; + virtual uint32_t getDsReadIssueCycle(uint32_t instrWidth) = 0; + virtual FailureOr getMmaExecCycle(llvm::ArrayRef dims) = 0; + virtual uint32_t getMmaIssueCycle() = 0; + virtual uint32_t getNumLdsDataPaths() = 0; + static std::unique_ptr get(StringRef arch); +}; + +template struct MachineDescrImpl : MachineDescr { + uint32_t getDsReadIssueCycle(uint32_t instrWidth) final { + return instrWidth == 16 ? 8 : 4; + } + + FailureOr getMmaExecCycle(llvm::ArrayRef dims) final { + if (dims.size() != 3) + return failure(); + auto it = + Derived::mmaTable.find(std::make_tuple(dims[0], dims[1], dims[2])); + if (it != Derived::mmaTable.end()) + return it->second; + return failure(); + } + + uint32_t getMmaIssueCycle() final { return Derived::mmaIssueCycle; }; + uint32_t getNumLdsDataPaths() final { return Derived::numLdsDataPaths; } + + using MmaTable = + llvm::DenseMap, uint32_t>; +}; + +struct CDNA2Kind : public MachineDescrImpl { + static const inline MmaTable mmaTable{{{32, 32, 8}, 64}, {{16, 16, 16}, 32}}; + static const inline uint32_t mmaIssueCycle{4}; + static const inline uint32_t numLdsDataPaths{2}; +}; + +struct CDNA3Kind : public MachineDescrImpl { + static const inline MmaTable mmaTable{{{32, 32, 8}, 32}, {{16, 16, 16}, 16}}; + static const inline uint32_t mmaIssueCycle{4}; + static const inline uint32_t numLdsDataPaths{2}; +}; + +std::unique_ptr MachineDescr::get(StringRef arch) { + AMD::ISAFamily family = AMD::deduceISAFamily(arch); + switch (family) { + case AMD::ISAFamily::CDNA3: { + return std::make_unique>(); + } + case AMD::ISAFamily::CDNA2: { + return std::make_unique>(); + } + default: { + return nullptr; + } + } + return nullptr; } struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) - : OpRewritePattern(ctx) { + InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch, + int32_t numStages, std::string variant) + : OpRewritePattern(ctx), numStages(numStages) { + + this->machineDescr = MachineDescr::get(arch); std::transform(variant.begin(), variant.end(), variant.begin(), [](unsigned char c) { return std::tolower(c); }); - this->schedulingType = llvm::StringSwitch(variant) - .Case("default", SchedulingType::NONE) - .Case("iglp0", SchedulingType::IGLP0) - .Case("iglp1", SchedulingType::IGLP1) - .Default(SchedulingType::UNKNOWN); + this->schedulingType = + llvm::StringSwitch(variant) + .Case("none", SchedulingType::NONE) + .Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0) + .Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1) + .Case("local-prefetch", SchedulingType::LOCAL_PREFETCH) + .Default(SchedulingType::UNKNOWN); + + if (this->numStages < 2) { + this->schedulingType = SchedulingType::NONE; + LDBG("ignoring instruction scheduling due to a very low num. " + "stages value. Must be >= 2"); + } } - enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + enum class SchedulingType : uint32_t { + NONE = 0, + LLVM_IGLP_0, + LLVM_IGLP_1, + LOCAL_PREFETCH, + UNKNOWN + }; + + // The following is inspired by ROCm Composable Kernel library's V3 pipelining + // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBM to registers) data prefetching. + void createLocalPrefetchSchedule( + PatternRewriter &rewriter, Location loc, + triton::amdgpu::InstructionSchedHint schedHint) const { + + if (!(schedHint.getIsBufferLoadsAEnabled() && + schedHint.getIsBufferLoadsBEnabled())) { + LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` " + "instructions"); + return; + } + + if (!machineDescr) { + schedHint.emitError("unknown target architecture detected"); + return; + } + + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + if (numBufferLoadInstA == 0) { + schedHint.emitError("buffer load count for tile A must be initialized"); + return; + } + + if (numBufferLoadInstB == 0) { + schedHint.emitError("buffer load count for tile B must be initialized"); + return; + } + + const uint32_t numMmaInst = schedHint.getNumMMAs().getValue(); + + auto mmaType = cast(schedHint.getNumMMAs().getType()); + auto maybeMmaExecCycle = machineDescr->getMmaExecCycle(mmaType.getShape()); + if (llvm::failed(maybeMmaExecCycle)) { + schedHint.emitError("unknown mma instruction type"); + return; + } + const uint32_t mmaExecCycle = maybeMmaExecCycle.value(); + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = + machineDescr->getDsReadIssueCycle(dsReadsAType.getShape()[0]); + const uint32_t dsReadBIssueCycle = + machineDescr->getDsReadIssueCycle(dsReadsBType.getShape()[0]); + + const uint32_t mmaIssueCycle = this->machineDescr->getMmaIssueCycle(); + const uint32_t numLdsDataPaths = this->machineDescr->getNumLdsDataPaths(); + + const auto dsReadAMmaRate = (mmaExecCycle - mmaIssueCycle + + numLdsDataPaths * dsReadAIssueCycle - 1) / + (numLdsDataPaths * dsReadAIssueCycle); + const auto dsReadBMmaRate = (mmaExecCycle - mmaIssueCycle + + numLdsDataPaths * dsReadBIssueCycle - 1) / + (numLdsDataPaths * dsReadBIssueCycle); + + const auto numDsreadAMma = + (numDsReadInstA + dsReadAMmaRate - 1) / dsReadAMmaRate; + const auto numDsreadBMma = + (numDsReadInstB + dsReadBMmaRate - 1) / dsReadBMmaRate; + + // stage 1 + const auto numMmaStage1 = numMmaInst - (numDsreadAMma + numDsreadBMma); + const auto numMmaPerIssue = + numMmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMmaPerIssue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMmaPerIssue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMmaRate) >= dsReadAMmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadAMmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstA - (numDsreadAMma - 1) * dsReadAMmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMmaRate) >= dsReadBMmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadBMmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstB - (numDsreadBMma - 1) * dsReadBMmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + // The AMDGPU compiler backend can fold consecutive `ds_read/ds_write` + // instructions into wider variants as a part of its load/store optimization + // during the instruction selection pass. If it happens, then it means that + // we are overestimated these types of instructions at the current level of + // the IR. In this scenario, the inserted `sched.group.barriers` will result + // in "fooling" the scheduling solver which can mess up the final assembly. + // To avoid this, we switch off the backend load/store folding optimization + // which is going to prevent instructions folding. In this case, the + // instruction widths of `ds_read/ds_write` instructions are going to match + // their LLVM representations. This is implemented as follows. + + // TODO: The current implementation disables `ds_read/ds_write` folding for + // all basic blocks in the currently processed function. We should try to + // avoid it. The compiler backend team proposed to play we the load/store + // alignment values within the currently processed basic block as an + // alternative solution. + auto funcOp = schedHint->getParentOfType(); + MLIRContext *ctx = schedHint->getContext(); + llvm::SmallVector targetFeatures; + if (auto attr = funcOp.getTargetFeatures()) { + llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures)); + } + targetFeatures.push_back(str_attr("-load-store-opt")); + funcOp.setTargetFeaturesAttr( + ::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures)); + } LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { + if (this->schedulingType == SchedulingType::NONE) { + rewriter.eraseOp(instructionSchedHint); + return success(); + } if (this->schedulingType == SchedulingType::UNKNOWN) { - llvm::dbgs() - << "[" << getDebugName() << "]: " - << "unknown instruction scheduling variant has been provided\n"; - return mlir::failure(); + instructionSchedHint.emitError( + "unknown instruction scheduling variant has been provided"); + return failure(); } // The switch controls whether instructions are allowed to cross the basic @@ -110,48 +432,56 @@ struct InstructionSchedHintsRewriter // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. const bool limitSchedulingRange = - !(schedulingType == SchedulingType::IGLP0 || - schedulingType == SchedulingType::IGLP1); + !(schedulingType == SchedulingType::NONE || + schedulingType == SchedulingType::LLVM_IGLP_0 || + schedulingType == SchedulingType::LLVM_IGLP_1); Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); if (limitSchedulingRange) { rewriter.setInsertionPointToStart(block); - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); } rewriter.setInsertionPoint(block, std::prev(block->end())); switch (schedulingType) { - case SchedulingType::IGLP0: - [[fallthrough]]; - case SchedulingType::IGLP1: { + case SchedulingType::LLVM_IGLP_0: + case SchedulingType::LLVM_IGLP_1: createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); break; - } + case SchedulingType::LOCAL_PREFETCH: + createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint); + break; case SchedulingType::NONE: - [[fallthrough]]; - default: { + default: break; } - } if (limitSchedulingRange) - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); rewriter.eraseOp(instructionSchedHint); - return mlir::success(); + return success(); } private: + int32_t numStages; SchedulingType schedulingType; + std::unique_ptr machineDescr; }; -struct LowerInstructionSchedHints - : public triton::impl::LowerInstructionSchedHintsBase< - LowerInstructionSchedHints> { +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { - explicit LowerInstructionSchedHints(std::string variant) { - this->variant = variant; + explicit TritonAMDGPULowerInstructionSchedHints(StringRef arch, + int32_t numStages, + StringRef variant) { + this->arch = std::move(arch.str()); + this->numStages = numStages; + this->variant = std::move(variant.str()); } void runOnOperation() override { @@ -161,29 +491,39 @@ struct LowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(ctx); - patterns.add(ctx, this->variant); + + patterns.add(ctx, this->arch, + this->numStages, this->variant); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); } } }; -struct InsertInstructionSchedHints - : public triton::impl::InsertInstructionSchedHintsBase< - InsertInstructionSchedHints> { +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { + void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); - mod->walk([ctx](triton::DotOp dot) { - if (dyn_cast(dot->getParentOp())) { - mlir::OpBuilder rewriter(ctx); - rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + mod.walk([this, ctx](scf::ForOp forOp) { + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. + if (auto dotOp = getSingleDotOpIfExists(forOp)) { + OpBuilder rewriter(ctx); + rewriter.setInsertionPointAfter(dotOp); + rewriter.create(dotOp->getLoc()); } }); } @@ -192,12 +532,15 @@ struct InsertInstructionSchedHints namespace mlir::triton { std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch, + int32_t numStages, + StringRef variant) { + return std::make_unique( + arch, numStages, variant); } std::unique_ptr> -createInsertInstructionSchedHintsPass() { - return std::make_unique(); +createTritonAMDGPUInsertInstructionSchedHintsPass() { + return std::make_unique(); } } // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 000000000000..b1836026032a --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,26 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ + +#include "mlir/IR/Types.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during to LLVM conversion/lowering to facilitate instruction scheduling +// controls. +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp); +} // namespace mlir::triton + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index c96ddbbe8961..9a0098790057 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +using mlir::triton::AMD::DppCtrl; namespace mlir::triton::AMD { namespace { @@ -84,6 +85,15 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + // AMD does not support stmatrix + return false; +} + void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { llvm::report_fatal_error("AMDGPU does not support stmatrix"); @@ -103,22 +113,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleXor(loc, rewriter, val, i); + return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleUp(loc, rewriter, val, i); + return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::programId(RewriterBase &rewriter, Location loc, @@ -126,11 +136,182 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } +// Cast and sext values into specific-length int to meet the requirements of +// instructions like UpdateDpp or readlane if necessary. +static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, + Value &val, Type fromType, + unsigned toBits) { + unsigned originalBits = fromType.getIntOrFloatBitWidth(); + Type toType = fromType; + + if (!fromType.isIntOrIndex()) { + val = bitcast(val, int_ty(originalBits)); + toType = int_ty(originalBits); + } + + if (originalBits < toBits) { + val = sext(int_ty(toBits), val); + toType = int_ty(toBits); + } + + return toType; +} + +// Trunc the value to specific length and then cast it to given type if +// necessary. This function is typically used in conjunction with +// castToAndSExtInt. +static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, + Value val, Type valType, + unsigned fromBits) { + unsigned originalBits = valType.getIntOrFloatBitWidth(); + Value toVal = val; + + if (originalBits < fromBits) { + toVal = trunc(int_ty(originalBits), toVal); + } + + if (!valType.isIntOrIndex()) { + toVal = bitcast(toVal, valType); + } + + return toVal; +} + bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - return false; + if (numLaneToReduce != 64) + return false; + + if (auto family = getISAFamily(); + family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) { + return false; + } + + Operation *reduxOp = op.getSingleCombiner(); + if (!reduxOp) + return false; + + auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, + uint32_t dppCtrl, int rowMask, + int bankMask) -> Value { + // DPP has limited support for data types, so here we need to + // cast non-integer types or integer types shorter than 32 bits + // to int32, except for fp32. + Type actualType = valType; + if (!valType.isF32()) { + actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); + } + + Value dppResult = + rewriter + .create(loc, actualType, src, src, + rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(true)) + .getRes(); + + if (!valType.isF32()) { + src = truncAndCastFromInt(rewriter, loc, src, valType, 32); + dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); + } + + IRMapping mapping; + mapping.map(reduxOp->getOperand(0), src); + mapping.map(reduxOp->getOperand(1), dppResult); + return rewriter.clone(*reduxOp, mapping)->getResult(0); + }; + + for (int i = 0; i < acc.size(); i++) { + Value buf; + auto valType = acc[i].getType(); + + // Here's the implementation of full-wavefront reduction using dpp. + // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + // + // Each step has a v_mov_dpp instruction following the redux op. In + // some cases, the lower-level compiler could merge them into single + // instruction. For example, v_mov_dpp + max => v_max_dpp. + // + // For gfx9, we have 64 threads per warp. These 64 threads are arranged + // into 4 rows, with each row being 16 threads. Each 16 threads are arranged + // further into 4 banks, with each bank being 4 threads. Overall it's in a + // (row, bank, thread) structure. When shuffling, we use row/bank mask to + // indicate which row/bank to participate. Then modifier like row_shr and + // row_bcast means exact data movement schemes. In the following + // instructions, taking row 0 as an example: + // + // Step 1: Right shift for 8 lanes. + // lane 8-15 = redux(lane 0-7, lane 8-15) + // + // Step 2: Right shift for 4 lanes. + // lane 12-15 = redux(lane 8-11, lane 12-15) + // + // Step 3: Right shift for 2 lanes. + // lane 14-15 = redux(lane 12-13, lane 14-15) + // + // Step 4: Right shift for 1 lane. + // lane 15 = redux(lane 14, lane 15) + // + // Step 5: Broadcast lane 15 of each row to all the lanes of its next row. + // lane 16-31 = redux(lane 15, lane 16-31) + // + // Step 6: Broadcast lane 31 to lane 32-63. + // lane 32-63 = redux(lane 31, lane 32-63) + // + // Now the reduction result is stored in lane 63. + // + // Step 7: Read the reduction result from lane 63 and broadcast with + // readlane. + + const int allRows = 0xf; + const int allBanks = 0xf; + + const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); + + // row_shr:8 + buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:4 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:2 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:1 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, + allRows, allBanks); + + // row_bcast:15 row_mask:0xa + buf = createDppReduxOpWithBoundCtrl( + valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); + + // row_bcast:31 + buf = createDppReduxOpWithBoundCtrl(valType, buf, + static_cast(DppCtrl::BCAST31), + allRows, allBanks); + + // Similarly, we need to cast data types for readlane instruction. + Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); + + // Get reduction result from lane 63 + std::string intrinsic = "llvm.amdgcn.readlane"; + Value result = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, + ValueRange{buf, i32_val(63)}) + ->getResult(0); + + result = truncAndCastFromInt(rewriter, loc, result, valType, 16); + + acc[i] = result; + } + + return true; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, @@ -245,4 +426,10 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, int TargetInfo::getSharedAddressSpace() const { return 3; } +bool TargetInfo::supportVectorizedAtomics() const { + // Note: not currently tested or used, but AMD generally supports vectorized + // atomics. + return true; +} + } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index eabb5d6715ac..31fa09e5198e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOAMD_H -#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOAMD_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ #include "TritonAMDGPUToLLVM/TargetUtils.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" @@ -27,6 +27,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type elemTy, Value pred) const override; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const override; @@ -58,6 +63,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef file, StringRef func, int line) const override; int getSharedAddressSpace() const override; + bool supportVectorizedAtomics() const override; + private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, RewriterBase &rewriter, bool useStdErr) const; @@ -66,4 +73,4 @@ class TargetInfo : public mlir::triton::TargetInfoBase { }; } // namespace mlir::triton::AMD -#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOAMD_H +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index aa71c92666f7..31df3a8a60d2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -9,7 +10,6 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -24,12 +24,10 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -namespace mlir { -namespace triton { +namespace mlir::triton { #define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM #include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir +} // namespace mlir::triton using namespace mlir; @@ -39,7 +37,6 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { public: explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalDialect(); @@ -72,8 +69,9 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -97,9 +95,9 @@ struct ConvertTritonAMDGPUToLLVM int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); // Hack: WSMaterialization may have changed the effective number of warps, - // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to + // in a way that isn't reflected in ttg.num-warps. If so, we have to // respect that here. - if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { + if (Attribute attr = mod->getAttr("ttg.num-warp-groups-per-cta")) { numWarps *= cast(attr).getInt(); } @@ -193,8 +191,14 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, commonBenefit); + populatePatterns7(mlir::triton::populateGatherOpToLLVMPatterns, + commonBenefit); + + mlir::triton::BackendCallbacks callbacks; + callbacks.localStoreOpConversion = storeOpConversionCallback; + + mlir::triton::populateMemoryOpToLLVMPattern( + typeConverter, targetInfo, patterns, commonBenefit, callbacks); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, @@ -207,6 +211,8 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, patterns, AMDBenefit); + mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, + targetInfo, AMDBenefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns @@ -223,6 +229,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } @@ -249,15 +256,13 @@ struct ConvertTritonAMDGPUToLLVM } }; -} // anonymous namespace +} // namespace -namespace mlir { -namespace triton { +namespace mlir::triton { std::unique_ptr> createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) { return std::make_unique(targetArch, ftz); } -} // namespace triton -} // namespace mlir +} // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000000..4126f4cc4ad3 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,141 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto fpType = op.getFpType(); + bool isPacked = fpType == ScaleDotElemType::E2M1; + if (!(isPacked || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) + return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); + + Location loc = op.getLoc(); + auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); + LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); + LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); + + // When we lower scaled dot op, we made sure to distribute K only on one + // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values + // along the K dimension. So in total each thread should read 32x main + // element values. + if (xVals.size() != scaleVals.size() * (isPacked ? 16 : 32)) + return rewriter.notifyMatchFailure(op, "unsupported problem size"); + + auto dotEncoding = + cast(op.getSrc().getType().getEncoding()); + auto mfmaEncoding = dyn_cast(dotEncoding.getParent()); + if (!mfmaEncoding) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); + LDBG("mfma: " << mfmaEncoding); + + int mDim = mfmaEncoding.getMDim(); + if (mDim != 32 && mDim != 16) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); + + int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + Value warpSize = i32_val(numThreads); + Value tid = tid_val(); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + if (isPacked) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + + // Given that MFMA layout for the A tensor arranges thread in a column-major + // manner, for the current tid, it's at row (tid % mDim). When we set up + // blocked layout for the A scale tensor, we made sure that it has a + // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values + // for the current thread starts at ((tid % mDim) * (64 / mDim)). + Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); + + if (mDim == 32) { + // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we + // tile, the same warp owns the whole K dim. Inside a warp, each thread + // only holds 4 consecutive elements along K--a 1x4 vector. We need to + // tile the warp 4 times to cover 32 values along K. So for a thread, the + // first 4 1x4 vectors it holds shares the first scale value at row (tid % + // mDim). the second 4 1x4 vectors shares the second scale value at row + // (tid % mDim); and so forth. + std::array scaleThreads = {offset, add(offset, i32_val(1))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + std::array si = { + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); + } + } + } else { + assert(mDim == 16); + // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we + // need to tile the warp 2 times to cover 32 valeus. So for a thread, the + // first 2 1x4 vectors shares the first scale value at row (tid % mDim). + std::array scaleThreads = {offset, add(offset, i32_val(1)), + add(offset, i32_val(2)), + add(offset, i32_val(3))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); + } + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 542b1ecbb7fb..0bd401f1993a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,6 +8,8 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::AMD::DppCtrl; +using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -71,8 +73,9 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, - Value i, int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -84,7 +87,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, if (bits < 32) val = sext(i32_ty, val); - val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); + val = + shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -98,8 +102,10 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); - val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); + val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, + clamp); + val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, + clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -134,13 +140,83 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value stride = i32_val(32); Value lineId = xor_(threadId, stride); return bpermute(lineId); - } else { - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); + } else if (strideInt == 16) { + Value offset = i32_val(0x401F); return rewriter.create(loc, valType, val, offset); + } else { + if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { + // DPP is only supportted for CDNA2 and CDNA3 right now, so we fallback + // to ds_swizzle for other archs. + // + // This map facilates the butterfly shuffle pattern for a stride less + // than 16. The pattern stride is the key of the map. + DenseMap masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = i32_val(masks[strideInt]); + return rewriter.create(loc, valType, val, offset); + } + + auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, + uint32_t dppCtrl, uint32_t rowMask, + uint32_t bankMask) { + return rewriter.create( + loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); + }; + + const int allRows = 0xf; + const int allBanks = 0xf; + + switch (strideInt) { + case 1: { + // quad_perm: 1, 0, 3, 2 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {1, 0, 3, 2}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 2: { + // quad_perm: 2, 3, 0, 1 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {2, 3, 0, 1}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 4: { + // row_shr:4 bank_mask: 0xa + auto ret = createDppOpWithoutBoundCtrl( + val, val, 4 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xa) + .getRes(); + + // row_shl:4 bank_mask: 0x5 + return createDppOpWithoutBoundCtrl( + ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x5); + } + case 8: { + // row_shr:8 bank_mask: 0xc + auto ret = createDppOpWithoutBoundCtrl( + val, val, 8 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xc) + .getRes(); + + // row_shl:8 bank_mask: 0x3 + return createDppOpWithoutBoundCtrl( + ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x3); + } + default: + assert(false && + "bfly shfl with stride >= 16 should not be handled by dpp."); + } } break; case ShflKind::up: { @@ -158,22 +234,27 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, return Value(); } -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, - i32_val(0x1f)); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, - i32_val(0x0)); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleIdx(loc, rewriter, val, i32_val(i)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { - return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, + i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 123234fd4824..cba2db5a896b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -1,13 +1,15 @@ -#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H -#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; @@ -19,10 +21,18 @@ const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS"; const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); @@ -39,4 +49,4 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, triton::CacheModifier cm = triton::CacheModifier::NONE); } // namespace mlir::LLVM::AMD -#endif +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index a26a18ed96bc..f9cd0f14382e 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -5,6 +5,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) { return 0; } -SmallVector warpsPerTile(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) return {(unsigned)numWarps, 1, 1}; - auto filter = [&dotOp](Operation *op) { + auto filter = [dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; ForwardSliceOptions fwdOpt; @@ -55,43 +56,118 @@ SmallVector warpsPerTile(tt::DotOp dotOp, bwdOpt.filter = filter; auto slices = getSlice(dotOp, bwdOpt, fwdOpt); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (op->hasTrait() && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; - SmallVector ret = {1, 1}; + SmallVector ret = {1, 1}; do { if (ret[0] * ret[1] >= numWarps) break; - if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { ret[0] *= 2; - } else + } else { ret[1] *= 2; + } } else { ret[1] *= 2; } } while (true); - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + if (ret[1] * shapePerWarp.second > tensorShape[1]) { return {ret[1], ret[0]}; } return ret; } -SmallVector -warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); } -SmallVector -warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { - return warpsPerTile(dotOp, shape, numWarps, - {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0], - ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]}); +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to chose a MFMA with matching M/N dim. +FailureOr chooseMfmaInstruction(RankedTensorType cType, + Type aElemType, Type bElemType, + int inputKSize, int mfmaVersion, + int enforcedNonKDim) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + if (minSize < 16) { + if (M < 16 && N >= 64) { + mDim = 4; + nDim = 64; + } else if (M >= 64 && N < 16) { + mDim = 64; + nDim = 4; + } else { + assert(inputKSize >= 64 && + "k should be at least 64 to use this layout"); + mDim = 4; + nDim = 4; + } + } + } + assert(mDim != 0 && nDim != 0); + + auto maybeMfmaInsn = + MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion); + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaInsn->getKDim(); + assert(kDim != 0); + assert(M % mDim == 0 && N % nDim == 0); + assert(inputKSize % kDim == 0); + return maybeMfmaInsn; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim) { + RankedTensorType aType = dot.getA().getType(); + return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), + aType.getShape().back(), mfmaVersion, nonKDim); +} + +FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion, + int nonKDim) { + // For scaled dot, we handle it with bf16 emulation for now. + Type bf16Type = Builder(dot.getContext()).getBF16Type(); + return chooseMfmaInstruction( + dot.getC().getType(), /*aElemType=*/bf16Type, /*bElemType=*/bf16Type, + dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim); } using OperandTypesVector = SmallVector; @@ -189,23 +265,23 @@ OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter, return selectMatrixCoreOperandTypes(dot, applicableTypes); } -/** - * @brief Convert layout and cast element type of a given tensor - * - * If old element type is different from new element type, this function - * creates two new operations: - * 1. %converted_value = layout_convert %value, newEncoding - * 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType - * - * If old element type is same as new element type, this function creates only - * one operation: %converted_value = layout_convert %value, newEncoding - * - * @param rewriter - * @param value original tensor value, which we need to convert and cast - * @param newEncoding new encoding for the tenosr - * @param newElemType new element type for the tensor - * @return converted and optionaly casted tensor value - */ +//===---------------------------------------------------------------------===// +// @brief Convert layout and cast element type of a given tensor +// +// If old element type is different from new element type, this function +// creates two new operations: +// 1. %converted_value = layout_convert %value, newEncoding +// 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType +// +// If old element type is same as new element type, this function creates only +// one operation: %converted_value = layout_convert %value, newEncoding +// +// @param rewriter +// @param value original tensor value, which we need to convert and cast +// @param newEncoding new encoding for the tenosr +// @param newElemType new element type for the tensor +// @return converted and optionaly casted tensor value +//===---------------------------------------------------------------------===// Value convertAndCastTensor(PatternRewriter &rewriter, Value value, Attribute newEncoding, Type newElemType) { assert(newElemType.isIntOrFloat()); @@ -259,15 +335,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value, return castedTensor; } -class BlockedToMFMA : public RewritePattern { +class BlockedToMFMA : public OpRewritePattern { int mfmaVersion; - int enforcedNonKDim; + int nonKDim; int kPack; public: - BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { @@ -285,75 +362,15 @@ class BlockedToMFMA : public RewritePattern { return false; } - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @return MfmaInsn or failure - FailureOr chooseMfmaInstruction(tt::DotOp dot) const { - // number of matrix elements along k dim per one MFMA intruction - unsigned kDim = 0; - auto opType = cast(dot.getA().getType()); - auto dataTypeA = opType.getElementType(); - auto dataTypeB = - cast(dot.getB().getType()).getElementType(); - - auto resType = cast(dot.getD().getType()); - auto resShape = resType.getShape(); - auto rank = resShape.size(); - auto M = resShape[rank - 2]; - auto N = resShape[rank - 1]; - - unsigned mDim = 0; - unsigned nDim = 0; - if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; - } else { - int minSize = std::min(M, N); - if (minSize >= 32) { - mDim = 32; - nDim = 32; - } - if (minSize >= 16 && minSize < 32) { - mDim = 16; - nDim = 16; - } - if (minSize < 16) { - if (M < 16 && N >= 64) { - mDim = 4; - nDim = 64; - } else if (M >= 64 && N < 16) { - mDim = 64; - nDim = 4; - } else { - assert(opType.getShape()[rank - 1] >= 64 && - "k should be at least 64 to use this layout"); - mDim = 4; - nDim = 4; - } - } - } - assert(mDim != 0 && nDim != 0); - - auto maybeMfmaInsn = - MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion); - if (failed(maybeMfmaInsn)) - llvm::report_fatal_error("No match found in MFMA database\n"); - - kDim = maybeMfmaInsn->getKDim(); - assert(kDim != 0); - assert(M % mDim == 0 && N % nDim == 0); - assert(opType.getShape()[rank - 1] % kDim == 0); - return maybeMfmaInsn; - } - - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || !isa(oldRetType.getEncoding())) return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); if (!supportMFMA(dotOp)) return failure(); @@ -362,7 +379,7 @@ class BlockedToMFMA : public RewritePattern { // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands @@ -374,7 +391,7 @@ class BlockedToMFMA : public RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto mfmaInstr = chooseMfmaInstruction(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); auto mDim = mfmaInstr.value().getMDim(); auto nDim = mfmaInstr.value().getNDim(); auto kDim = mfmaInstr.value().getKDim(); @@ -397,7 +414,7 @@ class BlockedToMFMA : public RewritePattern { mfmaAccType = rewriter.getF32Type(); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); // Here is a brief explanation of kWidth, kBase, and kDim @@ -456,11 +473,166 @@ class BlockedToMFMA : public RewritePattern { convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); + + return success(); + } +}; + +class ScaledBlockedToMFMA final : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + int kPack; + +public: + ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, + int kPack, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult matchAndRewrite(triton::DotScaledOp dotOp, + PatternRewriter &rewriter) const override { + using TensorValue = TypedValue; + + RankedTensorType oldRetType = dotOp.getType(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + unsigned rank = oldRetType.getRank(); + if (rank == 3) + return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case"); + + TensorValue a = dotOp.getLhs(); + TensorValue b = dotOp.getRhs(); + TensorValue aScale = dotOp.getLhsScale(); + TensorValue bScale = dotOp.getRhsScale(); + if (aScale && bScale) + return rewriter.notifyMatchFailure(dotOp, "NYI: both LHS and RHS scale"); + + ScaleDotElemType aElemType = dotOp.getLhsType(); + ScaleDotElemType bElemType = dotOp.getRhsType(); + auto supportsTypes = [](ScaleDotElemType elemType) { + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2 || + elemType == ScaleDotElemType::BF16; + }; + if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) + return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand"); + + MLIRContext *ctx = dotOp.getContext(); + auto moduleOp = dotOp->getParentOfType(); + + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); + int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); + + // Choose a suitable MFMA instruction for this scaled dot op. + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + if (failed(mfmaInstr)) + return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic"); + + unsigned mDim = mfmaInstr.value().getMDim(); + unsigned nDim = mfmaInstr.value().getNDim(); + unsigned kDim = mfmaInstr.value().getKDim(); + unsigned kBase = mfmaInstr.value().getKBase(); + + // For mxfp4 A/B tensor, we pack every two values into one int8 value there. + // For such cases, we have different initial kWidth for LHS and RHS, which + // will be "fixed" later by using upcast_mxfp to convert LHS to unpacked + // values. For such packed cases, we cannot support flexible kPack choices + // from the developer--it just does not apply here. So mandate the choice + // here. + bool isAPacked = aElemType == ScaleDotElemType::E2M1; + bool isBPacked = bElemType == ScaleDotElemType::E2M1; + bool isPacked = isAPacked || isBPacked; + unsigned kWdiths[] = {isPacked ? (isAPacked ? 4 : 8) : kBase * kPack, + isPacked ? (isAPacked ? 8 : 4) : kBase * kPack}; + + // For A/B tensor, 32 consecutive elements along K dim share the same scale. + // We'd like to keep the scale values together with the base values in the + // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA + // = 1 along the N/M dimension for the mxfp A/B case. We achieve that by + // setting the M/N dimension as numWarps. + SmallVector mfmaWarpsPerCTA(rank, 1); + mfmaWarpsPerCTA[aScale ? 0 : 1] = numWarps; + + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions. + auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, mfmaWarpsPerCTA, + /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout); + + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc); + + auto newAcc = rewriter.create( + dotOp.getC().getLoc(), newRetType, dotOp.getC()); + + auto toMMABf16 = [&](TensorValue v, int idx, + ScaleDotElemType type) -> TensorValue { + auto vType = v.getType(); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), kWdiths[idx]); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + v = rewriter.create(v.getLoc(), newVType, v); + // Don't need to covert int8 holding mxfp4--the upcast_mxfp op can + // take int8 tensor as input. + if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::E2M1) + return v; + + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return cast( + rewriter.create(v.getLoc(), vTypeBf16, v).getResult()); + }; + a = toMMABf16(a, 0, aElemType); + b = toMMABf16(b, 1, bElemType); + + // We need to have "matching" encoding between the main tensor and scale + // tensor to make sure the scale values needed is in the same warp. So we + // adopt the same CTA layout and warps per CTA. The warp dimensions needs to + // match along M/N dimension too. With in a warp, we have 64 threads. We let + // each thread read in one scale value. So we need a threadsPerWarp = + // mDim/nDim along M/N dimension. Note that For MFMA intrinsics, mDim is + // always the same as nDim. And for scaled dot scale tensor, we always have + // K as the innermost dimension. So we have the same threadsPerWarp in the + // below no matter A or B scale. Similarly for warpsPerCTA, the non-K + // dimension is always at index 0. + assert(mDim == nDim); + SmallVector threadsPerWarp = {mDim, numThreads / mDim}; + SmallVector blockWarpsPerCTA(rank, 1); + blockWarpsPerCTA[0] = numWarps; + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout); + + auto upcastMXFP = [&](TensorValue main, TensorValue scale, + ScaleDotElemType elemType) -> Value { + if (!scale) + return main; + + auto newScaleType = RankedTensorType::get( + scale.getType().getShape(), scale.getType().getElementType(), + newScaleEncoding); + auto convOp = rewriter.create(scale.getLoc(), + newScaleType, scale); + + return rewriter.create(dotOp.getLoc(), main, + convOp, elemType); + }; + Value scaledA = upcastMXFP(a, aScale, dotOp.getLhsType()); + Value scaledB = upcastMXFP(b, bScale, dotOp.getRhsType()); + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, scaledA, + scaledB, newAcc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot); return success(); } }; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -566,18 +738,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { }); } -class BlockedToWMMA : public RewritePattern { +class BlockedToWMMA : public OpRewritePattern { int wmmaVersion; public: - BlockedToWMMA(MLIRContext *context, int wmmaVersion) - : RewritePattern(tt::DotOp::getOperationName(), 2, context), - wmmaVersion(wmmaVersion) {} + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(tt::DotOp dotOp, PatternRewriter &rewriter) const override { - auto ctx = op->getContext(); - auto dotOp = cast(op); + auto ctx = dotOp->getContext(); Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -603,7 +774,7 @@ class BlockedToWMMA : public RewritePattern { if (wmmaVersion == 2 && llvm::isa(oldAType) && oldAType.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure(op, "not supported yet"); + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); } // get operand types @@ -612,7 +783,7 @@ class BlockedToWMMA : public RewritePattern { return failure(); // get WMMA encoding for the given number of warps - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); ttg::AMDWmmaEncodingAttr wmmaEnc; @@ -626,7 +797,7 @@ class BlockedToWMMA : public RewritePattern { auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); @@ -653,7 +824,7 @@ class BlockedToWMMA : public RewritePattern { Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; @@ -683,8 +854,8 @@ class TritonAMDGPUAccelerateMatmulPass case ISAFamily::CDNA1: case ISAFamily::CDNA2: case ISAFamily::CDNA3: - patterns.add<::BlockedToMFMA>(context, getMfmaVersion(isaFamily), - matrixInstructionSize, kPack); + patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>( + context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack); break; case ISAFamily::RDNA3: patterns.add<::BlockedToWMMA>(context, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 414e4a329fdb..aef5886b11d8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -1,12 +1,14 @@ add_triton_library(TritonAMDGPUTransforms AccelerateAMDMatmul.cpp CanonicalizePointers.cpp + ConvertToBufferOps.cpp OptimizeEpilogue.cpp ReorderInstructions.cpp - StreamPipelineV2.cpp + StreamPipeline.cpp MfmaGroup.cpp DEPENDS + TritonAMDGPUIR TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index f8c497968201..fa7d54e0fbee 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -73,6 +73,17 @@ using namespace mlir; // `%fat_ptr = tt.addptr(%t_ptr, %fatPointers[ptr].offset)` // `%data = tt.load(%fat_ptr)` // +// Please note that `%offset` might be a 32bit or 64bit integer. If +// we can, we would like to use 32 bit integers. This can happen under +// certain conditions: +// +// a) We can determine that the offset cannot overflow. In this case, we can +// downcast the pointer just before emitting the load +// b) We know that the underlying memory size can be expressed as a 32 bit +// value. In this case we can simply start with a 32bit offset and downcast +// if we ever meet 64 bit operations (because we know that the offset can be +// contained in 32 bits) +// class PointerCanonicalizer { public: explicit PointerCanonicalizer(ModuleOp moduleOp) @@ -96,7 +107,7 @@ class PointerCanonicalizer { // Utility copy functions FatPtr copy(Value newBasePtr, Value newOffset) { return FatPtr{newBasePtr, newOffset, canNarrow}; - }; + } FatPtr copyWithBase(Value newOffset) { return FatPtr{basePtr, newOffset, canNarrow}; } @@ -571,12 +582,16 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, bool propagateAtrs = true; if (!isZeroConst(nonUniformOffset)) { Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); + Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); - // If we the incoming offset is 32 bits, then we have to cast to 64 - if (addPtrOffsetType.isInteger(32)) + // Upcast or downcast the offset accordingly + if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) nonUniformOffset = extend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); + else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) + nonUniformOffset = + narrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); newOffset = rewriter.create(curLoc, nonUniformOffset, fatPtrOffset); @@ -958,14 +973,18 @@ LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) { LogicalResult PointerCanonicalizer::rewriteFunction(triton::FuncOp funcOp) { Region ®ion = funcOp.getRegion(); - for (Value arg : region.getArguments()) { + for (auto [idx, arg] : llvm::enumerate(region.getArguments())) { // The pointer argument needs to be a scalar if (!isa(arg.getType())) continue; + int64_t bitness = 64; + if (IntegerAttr pointerRangeAttr = + funcOp.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); rewriter.setInsertionPointToStart(®ion.front()); Value zeroOffset = - rewriter.create(region.getLoc(), 0, 64); + rewriter.create(region.getLoc(), 0, bitness); // Start the rewrite clearFunctionState(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp new file mode 100644 index 000000000000..e66a2feb57fe --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -0,0 +1,273 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/TypeSwitch.h" +#include +#include + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace ttg = mlir::triton::gpu; +namespace tt = mlir::triton; + +namespace { +bool verifyNonNegativeByAssumption(Value expr, + const DenseSet &assumptions) { + for (Value assume : assumptions) { + LDBG("Assumption:" << assume); + if (auto cmpOp = assume.getDefiningOp()) { + bool isGreaterThan = (cmpOp.getPredicate() == arith::CmpIPredicate::sge || + cmpOp.getPredicate() == arith::CmpIPredicate::sgt); + APInt cst; + if (isGreaterThan && (cmpOp.getLhs() == expr) && + matchPattern(cmpOp.getRhs(), m_ConstantInt(&cst))) { + return cst.isNonNegative(); + } + } + } + return false; +} + +bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions) { + + // Check if the expression is contained in any assumption + if (verifyNonNegativeByAssumption(expr, assumptions)) { + LDBG("Non negative by assumption"); + return true; + } + + // Recurse if the operation is defined + Operation *op = expr.getDefiningOp(); + if (!op) + return false; + + bool nonNegative = + llvm::TypeSwitch(expr.getDefiningOp()) + .Case([&](auto broadcastOp) { + return verifyNonNegativeExpr(broadcastOp.getSrc(), assumptions); + }) + .Case([&](auto expandOp) { + return verifyNonNegativeExpr(expandOp.getSrc(), assumptions); + }) + .Case([&](auto splatOp) { + return verifyNonNegativeExpr(splatOp.getSrc(), assumptions); + }) + .Case([&](auto makeRangeOp) { + return makeRangeOp.getStart() >= 0 && makeRangeOp.getEnd() >= 0; + }) + .Case( + [&](auto constIntOp) { return constIntOp.value() >= 0; }) + .Case([&](arith::ConstantOp constOp) { + Value val = constOp.getResult(); + DenseIntElementsAttr constVal; + if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) + return constVal.getSplatValue().isNonNegative(); + return false; + }) + .Case([&](auto pidOp) { return true; }) + .Case([&](auto maxOp) { + // max(a,b) >= 0 iff a>=0 || b>=0 + bool nnLhs = verifyNonNegativeExpr(maxOp.getLhs(), assumptions); + bool nnRhs = verifyNonNegativeExpr(maxOp.getRhs(), assumptions); + return nnLhs || nnRhs; + }) + .Case([&](auto remsiOp) { + // a % b >= 0 iff a>=0 + return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions); + }) + .Case([&](Operation *unaryOp) { + // a = OP b >= 0 iff b >= 0 + return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions); + }) + .Case( + // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when + // OP != sub + [&](Operation *binOp) { + bool nnLhs = + verifyNonNegativeExpr(binOp->getOperand(0), assumptions); + bool nnRhs = + verifyNonNegativeExpr(binOp->getOperand(1), assumptions); + return nnLhs && nnRhs; + }) + .Default([&](Operation *op) { + // Conservatively assume that the expression is negative + return false; + }); + return nonNegative; +} + +// Quick analysis on the Triton IR to decide if we can safely use +// buffer operations +bool canUseBufferOps(Value ptr, const DenseSet &assumptions) { + // 1. Check if the pointer is uniform: i.e., if it comes from a uniform + // pointer(splatted) and non-uniform offset addition + + LDBG("Buffer op checks for: " << ptr); + auto addPtrOp = ptr.getDefiningOp(); + if (!addPtrOp) + return false; + + auto maybeSplatOp = addPtrOp.getPtr().getDefiningOp(); + if (!maybeSplatOp) + return false; + LDBG("Pattern matched"); + + // 2. Check if the offset is a 32-bit tensor + Value offset = addPtrOp.getOffset(); + if (cast(offset.getType()).getElementTypeBitWidth() != 32) + return false; + LDBG("32 bit offset"); + + // 3. Check if the offset is non-negative + if (!verifyNonNegativeExpr(offset, assumptions)) + return false; + + LDBG("Non-negative"); + return true; +} +} // namespace + +struct ConvertTritonLoadToBufferLoad + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonLoadToBufferLoad(mlir::MLIRContext *context, + DenseSet &assumptions) + : mlir::OpRewritePattern(context), + assumptions(assumptions) {} + + mlir::LogicalResult + matchAndRewrite(triton::LoadOp op, PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + + if (op.getCache() != triton::CacheModifier::NONE) + return failure(); + + if (canUseBufferOps(ptr, assumptions)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeOther{}; + if (op.getOther() && !isZeroConst(op.getOther())) + maybeOther = op.getOther(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + + auto bufferLoadOp = rewriter.create( + op->getLoc(), op.getType(), basePtr, tensorOffset, maybeMask, + maybeOther); + + // Propagate `OpIdxAttr` if the currently processed `tt.LoadOp` was + // labeled it. The attribute needs to be preserved for custom instruction + // scheduling. + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + bufferLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); + } + rewriter.replaceOp(op, bufferLoadOp); + + return success(); + } + LDBG("Failed to convert: " << op); + return failure(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; +}; + +struct ConvertTritonStoreToBufferStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonStoreToBufferStore(mlir::MLIRContext *context, + DenseSet &assumptions) + : mlir::OpRewritePattern(context), + assumptions(assumptions) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + + if (op.getCache() != triton::CacheModifier::NONE) + return failure(); + + if (canUseBufferOps(ptr, assumptions)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + rewriter.replaceOpWithNewOp( + op, op.getValue(), basePtr, tensorOffset, maybeMask); + return success(); + } + LDBG("Failed to convert: " << op); + return failure(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; +}; + +class TritonAMDGPUConvertToBufferOpsPass + : public TritonAMDGPUConvertToBufferOpsBase< + TritonAMDGPUConvertToBufferOpsPass> { + +public: + TritonAMDGPUConvertToBufferOpsPass() = default; + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + // Collect assumptions in the function + DenseSet assumptions; + m.walk([&](LLVM::AssumeOp op) { + if (op->getOperand(0).getDefiningOp()) + assumptions.insert(op->getOperand(0)); + }); + LDBG("Number of assumptions found: " << assumptions.size()); + + patterns.add(context, assumptions); + patterns.add(context, assumptions); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::createTritonAMDGPUConvertToBufferOpsPass() { + return std::make_unique(); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index d3b2b70f858c..9fce18e21f3c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -235,8 +235,9 @@ std::pair TypesFromMfmaId(mlir::MLIRContext *ctx, return {f8e5m2fnuz, f8e4m3fnuz}; case MfmaTypeId::Bf8Bf8TyId: return {f8e5m2fnuz, f8e5m2fnuz}; + default: + llvm_unreachable("unsupported MfmaTypeId!"); } - assert(false && "unsupported MfmaTypeId"); } FailureOr MfmaInsn::selectMfma(unsigned mDim, unsigned nDim, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index e122f15fd901..6981bd31bdb3 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -5,23 +5,30 @@ #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; namespace ttg = mlir::triton::gpu; -namespace tt = mlir::triton; - -static bool isLocalLoadOrDotLayoutConversion(Operation *op) { - if (isa(op)) - return true; - if (auto cvt = dyn_cast(op)) - return isa(cvt.getType().getEncoding()); - return false; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +// Return true if the given moduleOp contains a pure matmul problem; i.e., +// single dot in the main loop. +static bool isPureMatmulProblem(triton::FuncOp funcOp) { + bool isMatmul = true; + bool foundLoop = false; + funcOp.walk([&](scf::ForOp forOp) -> void { + int counter = 0; + forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); + isMatmul = (isMatmul && (counter == 1)); + foundLoop = true; + }); + return foundLoop && isMatmul; } // Search through block to find earliest insertion point for move op. This can @@ -61,194 +68,323 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Return the first user in the same block of the given op. If the user is in a +// nested block then return the op owning the block. Return nullptr if not +// existing. +static Operation *getFirstUseInSameBlock(Operation *op) { + SmallVector usersInSameBlock; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + usersInSameBlock.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; +} + // Check if the operation opInsideLoop is inside any scf::ForOp and // opOutsideLoop is not inside the same loop. -bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, - mlir::Operation *opOutsideLoop) { +static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { scf::ForOp parentForOp = opInsideLoop->getParentOfType(); return parentForOp && !parentForOp->isAncestor(opOutsideLoop); } -class TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { -public: - TritonAMDGPUReorderInstructionsPass() = default; - - Operation *getFirstUse(Operation *op) { - std::vector users; - for (auto user : op->getUsers()) { - if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) - users.push_back(ancestor); - } - auto minOpIt = std::min_element(users.begin(), users.end(), - [](mlir::Operation *a, mlir::Operation *b) { - return a->isBeforeInBlock(b); - }); - return minOpIt != users.end() ? *minOpIt : nullptr; - } +//===----------------------------------------------------------------------===// +// Reorder mechanisms +//===----------------------------------------------------------------------===// - void runOnOperation() override { - ModuleOp m = getOperation(); +// Sink dot layout conversions into loops to decrease register pressure when +// possible. +static void sinkDotConversion(triton::FuncOp funcOp) { + DenseMap opToMove; + funcOp.walk([&](ttg::ConvertLayoutOp op) { + Attribute encoding = op.getType().getEncoding(); + if (!isa_and_nonnull(encoding)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove[op] = user; + }); - // Sink shared memory loads and layout conversions into loops to decrease - // register pressure when possible. - DenseMap opToMove; - m.walk([&](Operation *op) { - if (!isLocalLoadOrDotLayoutConversion(op)) - return; - if (!op->hasOneUse()) - return; - Operation *user = *op->getUsers().begin(); - if (user->getParentOfType() == - op->getParentOfType()) + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); +} + +// Adjust the placement of shared memory writes and reads to immediately follow +// the definition of their operands in case where shared memory write is in the +// loop but its operand is not. +// +// This is a heuristic driven by optimizing fused attention by hoisting Q tensor +// shared memory read/write operations outside of the loop, as Q is a loop +// invariant and can be loaded once before entering the loop. But it should be +// generally applicable. +// +// There are two possible patterns for this adjustment depending on whether the +// write to shared memory is performed using an optional `local_alloc` argument +// or a `local_store` instruction. +// +// 1) %1 = some_op ... (typically a load or an operation that scales the tensor +// after loading) +// %2 = local_alloc %1 +// %3 = local_load %2 +// +// 2) %1 = some_op ... +// %2 = local_alloc +// %3 = local_store %1, %2 +// %4 = local_load %2 +static void hoistLocalLoad(triton::FuncOp funcOp) { + funcOp.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) + return; + + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) return; - opToMove.insert({op, user}); - }); - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); - opToMove.clear(); - - // Adjust the placement of LDS writes and reads to immediately follow the - // definition of their operands in case where LDS write is in the - // loop but it's operand is not. This is a heuristic for optimizing fused - // attention by hoisting Q tensor LDS read/write operations outside of the - // loop, as Q is a loop invariant and can be loaded once before entering the - // loop. - // There are two possible patterns for this adjustment depending on - // whether the write to LDS is performed using an optional `local_alloc` - // argument or a `local_store` instruction. - // - // clang-format off - // - // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) - // %2 = local_alloc %1 - // %3 = local_load %2 - // - // 2) %1 = some_op ... - // %2 = local_alloc - // %3 = local_store %1, %2 - // %4 = local_load %2 - // - // clang-format on - m.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) + + auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) return; - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) - return; + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } - auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { - return; - } + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); + // localStore comes before localLoad in block. + Operation *localStore = getFirstUseInSameBlock(localAlloc); + if (!isa(localStore)) + return; - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } - // localStore comes before localLoad in block. - Operation *localStore = getFirstUse(localAlloc); - if (!isa(localStore)) - return; + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); +} - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; +// Sink conversion after the last dealloc but before the first use in its block. +// This helps to avoid unnecessary shared memory allocation. +static void moveDownCoversion(triton::FuncOp funcOp) { + SmallVector convertOps; + funcOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); + + for (auto op : convertOps) { + Operation *user = getFirstUseInSameBlock(op); + for (auto it = Block::iterator(op), ie = op->getBlock()->end(); + it != ie && &*it != user; ++it) + if (isa(&*it)) + op->moveAfter(&*it); + } +} + +// Move transpositions just after their definition. +static void moveUpTranspose(triton::FuncOp funcOp) { + SmallVector transOps; + funcOp.walk([&](triton::TransposeOpInterface op) { transOps.push_back(op); }); + + for (auto op : transOps) + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); +} + +// Schedule global load and local store ops for better GEMM performance. +static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) { + SmallVector moveOps; + // Move local_stores early if dependence distance greater than one iteration. + // Best perf on GEMM when these precede global loads. + funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + bool dontReorder = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Don't hoist control flow as we don't track backtraces of ops within + // their regions. + if (isa(defOp)) { + dontReorder = true; + return false; } - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); - // Sink conversion after the last dealloc but before the first use ancestor - // in its block. This helps to avoid unnecessary shared memory allocation. - m.walk([&](triton::gpu::ConvertLayoutOp op) { - auto curr = mlir::Block::iterator(op); - for (; &*curr != getFirstUse(op); curr++) - if (isa(&*curr)) - op->moveAfter(&*curr); - }); + // If we found ops in the slice we don't want to hoist. + if (dontReorder) + continue; + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; - // Move transpositions just after their definition. - m.walk([&](triton::TransOp op) { - if (Operation *argOp = op.getSrc().getDefiningOp()) - op->moveAfter(argOp); - }); + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } + } +} + +//===-------------------------------------------------------------------===// +// Sched-load optimization for matmul kernels with large tile sizes +// The basic idea of sched-load optimization is to sink the 2nd tt.load +// after local_load so that global_load instructions can be interleaved with +// mfma's. This can help hide the issue latency of global_load instructions +// and improve performance on MI300X. +// +// It's assumed that the IR before this optimization has the following +// structure: +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// tileB = tt.load b_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// After this optimization, the IR is transformed to +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// tileB = tt.load b_ptr <-- 2nd tt.load is sinked here +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// For now, we don't have a perfect hueristic about when should this +// optimization be applied. Therefore, we implement a simple hueristic that +// this is applied when the tile size of A and B are large enough, i.e. +// nonKDim >= 128 and kDim >= 64. And also this is only applied for typical +// matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We +// are experimenting how to better control instruction scheduling and enable +// such optimizations. +//===-------------------------------------------------------------------===// +static void sinkSecondLoad(triton::FuncOp funcOp) { + funcOp.walk([&](scf::ForOp forOp) -> void { + SetVector loadOps; + triton::DotOp dotOp; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) + dotOp = curOp; + } + // Only apply the optimization when there are 2 load's in the loop + if (loadOps.size() != 2) + return; + // Only apply the optimization when tile size is large enough + // 1. nonKDim >= 128 + // 2. kDim >= 64 + auto ldAOp = loadOps[0]; + auto tileAShape = cast(ldAOp.getType()).getShape(); + auto ldBOp = loadOps[1]; + auto tileBShape = cast(ldBOp.getType()).getShape(); + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) + return; + // Only apply the optimization when the moving is legal + // 1. Make sure the 2nd loadOp is before the dot + // 2. Make sure the first user of the 2nd loadOp is after the dot. + bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); + auto firstUser = *ldBOp.getResult().getUsers().begin(); + bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); + if (isBeforeDotOp && firstUserAfterDotOp) + // move ldBOp right before tt.dot + ldBOp->moveBefore(dotOp); + }); +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { + void runOnOperation() override { + ModuleOp m = getOperation(); + for (auto funcOp : m.getOps()) { + hoistLocalLoad(funcOp); + + sinkDotConversion(funcOp); + moveDownCoversion(funcOp); + + moveUpTranspose(funcOp); - SmallVector moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); - // Move local_stores early if dependence distance greater than - // one iteration. - // Best perf on GEMM when these precede global loads. - m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - - for (auto op : llvm::reverse(moveOps)) { - // Gather use-def chain in block. - Block *block = op->getBlock(); - bool leadsToLoad = false; - SetVector backwardSet; - - BackwardSliceOptions options; - options.omitBlockArguments = true; - options.inclusive = false; - options.filter = [&](Operation *defOp) -> bool { - Block *defBlock = defOp->getBlock(); - if (!block->findAncestorOpInBlock(*defOp)) - return false; - // Check for a `load` dependent path. - leadsToLoad |= isa(defOp); - // Only move ops residing in the same block. - return defBlock == block; - }; - mlir::getBackwardSlice(op, &backwardSet, options); - backwardSet.insert(op); - - // Don't move a local_store if its source is a load from - // the same iteration. - if (isa(op) && leadsToLoad) - continue; - - auto ipoint = findEarlyInsertionPoint(block, op); - // Remove ops that already precede the insertion point. This is done - // before moves happen to avoid `Operation::isBeforeInBlock` N^2 - // complexity. - - SmallVector dfg = backwardSet.takeVector(); - if (ipoint != block->end()) { - // Move ops to insertion point. - llvm::erase_if( - dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveAfter(block, ipoint); - } else { - // Move ops to block begin. - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveBefore(block, block->begin()); + if (isPureMatmulProblem(funcOp)) { + scheduleGlobalLoadLocalStore(funcOp); + sinkSecondLoad(funcOp); } } } }; +} // namespace std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp similarity index 64% rename from third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp rename to third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 027f06652f20..d4a6eb09fd46 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1,6 +1,8 @@ #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -17,13 +19,13 @@ // modulo schedule and an expander that rewrites the loop and emits a prologue // and epilogue. This pass first calls a helper that will pre-process the IR // to create stream operations and create a modulo schedule. Then we call the -// expander to generate the prologue and new loop. +// expander to generate the prologue and new loop and epilogue. //===----------------------------------------------------------------------===// #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" -#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2" +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -54,42 +56,117 @@ static Operation *streamPredication(RewriterBase &rewriter, Operation *op, namespace { -// Encapsulate stream pipelining -// For each `scf.for` create a StreamPipeliner manager. +//===----------------------------------------------------------------------===// +// Software pipelining generally works by anchoring on global load ops in the +// main loop and rotating the loop to schedule global load ops for future loop +// iterations together with compute for the current iteration. In this way, we +// can 1) issue memory operations earlier to hide the latency and 2) break the +// strong dependency inside on loop iteration to give backends flexiblity to +// better interleave instructions for better instruction-level parallelism. +// +// This StreamPipeliner class creates the pipelining schedule and calls the +// PipelineExpander to rewrite the `scf.for` loop accordingly. A schedule +// consists of multiple stages, where ops from different stages can overlap +// executions because the dependencies are loop carried. +// +// The general flow of this process is: +// +// 1. The user provides a `num_stages` that specifies how many stages the +// pipeline will have. The number of stages must be larger than the distance +// from the first independent load to the compute in order to pipeline. +// 2. A schedule is created based on the distance between the global loads +// in the first stages and the compute that uses the loaded values in the +// last stage (num_stages - 1). Each operation will be clustered in the +// order to best overlap with other operations (see details below in the +// initSchedule method). +// 3. When the compute is a tt.dot, the scheduler will insert a shared +// memory allocation between the global load and tt.dot. The ttg.local_store +// will save the global load value to shared memory and the ttg.local_load +// will load the relevant tiles for the tt.dot. These operations will be +// scheduled according to various scheduling schemes outlined below in the +// initSchedule method (see details there). +// 4. Finally the schedule will be passed to the PipelineExpander to rewrite +// accordingly. The new implementation will consist of: +// a. Prologue: containing the ramp-up of num_stages-1 stages for +// iteratorions i=[0, num_stages-1). +// b. New loop: ordered by cluster and iterated on each operation by +// `i + (num_stages-op_stage)`. +// c. Epilogue: ramp-down of the last `num_stages-1` iterations for the +// ops in stages 1 to last_stage. This must consider that the loop +// bounds may be shorter than num_stages. In this case, the epilogue +// iterations must align with the prologue. +// class StreamPipeliner { public: - StreamPipeliner(scf::ForOp _forOp, int _numStages) - : forOp(_forOp), schedule(_numStages), numStages(_numStages), + StreamPipeliner(scf::ForOp _forOp, int _numStages, bool _prefetch) + : forOp(_forOp), prefetch(_prefetch), numStages(_numStages + prefetch), + schedule(numStages), axisInfoAnalysis(forOp->getParentOfType()) { options.supportDynamicLoops = true; options.peelEpilogue = true; options.predicateFn = streamPredication; } + LogicalResult pipelineLoop(); + +private: + void initSchedule(int maxIndirectionLevel); + void computeLoadOpsToIndirectionLevelAndUse(); void assignMemoryLayouts(); - void scheduleLoads(DenseSet &rootUsers); + LogicalResult scheduleLoads(DenseSet &rootUsers); void scheduleDependencies(); void scheduleDistanceOneDependencies(); - void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); + void scheduleRemainingToLastStage(); - bool preprocessLoopAndBuildSchedule(); - bool pipelineLoop(); + LogicalResult preprocessLoopAndBuildSchedule(); Value createAlloc(Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned numBuffers); - void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx, - tt::CoarseSchedule::Cluster prefetchCluster); + void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx); void createStreamOps(); + // Define categories of scheduling details per Operation types. + // The StreamPipeliner schedules 5 types of operations: + // 1. GLOBAL_LOAD: tt.load + // 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner) + // 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner) + // 4. COMPUTE: ops that use the loaded data + // 5. TAIL: everything else in the loop + enum SchedType { + SCHED_GLOBAL_LOAD, + SCHED_LOCAL_STORE, + SCHED_LOCAL_LOAD, + SCHED_COMPUTE, + SCHED_TAIL + }; + + void scheduleOp(Operation *op, SchedType type, int stage = -1) { + if (stage < 0) + stage = config[type].stage; + schedule.insert(op, stage, config[type].cluster); + } + private: + // Data members scf::ForOp forOp; - tt::CoarseSchedule schedule; + + // User settings + bool prefetch; int numStages; + // Scheduling clusters + tt::CoarseSchedule schedule; + + // ScheduleConfig lookup by SchedType to get the stage and cluster. + struct ScheduleConfig { + int stage; + tt::CoarseSchedule::Cluster cluster; + }; + SmallVector config; + // Mapping and indirection level for each `tt.load` to its use. - llvm::SmallVector> - loadOpToIndLevelAndUse; + SmallVector> loadOpToIndLevelAndUse; struct LoadInfo { // Shared layout is used for loads feeding into dot ops. @@ -114,9 +191,64 @@ class StreamPipeliner { } // namespace -void StreamPipeliner::createStreamCopy( - tt::LoadOp loadOp, Value alloc, Value extractIdx, - tt::CoarseSchedule::Cluster prefetchCluster) { +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. +void StreamPipeliner::initSchedule(int maxIndirectionLevel) { + int lastStage = numStages - 1; + config.resize(SCHED_TAIL + 1); + + bool isMultibuf = numStages > (2 + maxIndirectionLevel); + if (prefetch) { + // Prefetch Schema cluster order and staging. + // for i in (...): + // local_stores: stage=i+1 + // global_loads: stage=i+2 + // compute: stage=i + // local_load: stage=i+1 + // tail: stage=i + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + auto cluster1 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster1}; + config[SCHED_COMPUTE] = {lastStage, cluster1}; + config[SCHED_LOCAL_LOAD] = {lastStage - 1, schedule.clusters.newAtBack()}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } else if (isMultibuf) { + // Streaming Schema cluster order and staging for multi-buffer. + // for i in (...): + // local_stores: stage=i+1 + // global_loads: stage=i+2 + // local_load: stage=i + // compute: stage=i + // tail: stage=i + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + auto cluster1 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster1}; + config[SCHED_LOCAL_LOAD] = {lastStage, cluster1}; + config[SCHED_COMPUTE] = {lastStage, cluster1}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } else { + // Streaming Schema cluster order and staging for single-buffer. + // for i in (...): + // global_loads: stage=i+1 + // local_load: stage=i + // compute: stage=i + // local_stores: stage=i+1 + // tail: stage=i + auto cluster0 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster0}; + config[SCHED_LOCAL_LOAD] = {lastStage, schedule.clusters.newAtBack()}; + config[SCHED_COMPUTE] = {lastStage, cluster0}; + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } +} + +void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { OpBuilder builder(forOp); Value zero = builder.create(forOp.getLoc(), 0, 32); // Replace the load with insert/extract slice. @@ -124,8 +256,9 @@ void StreamPipeliner::createStreamCopy( Location loc = loadOp.getLoc(); Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); - tt::MemDescType allocTy = cast(alloc.getType()); + ttg::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); Operation *copy = builder.clone(*loadOp); @@ -138,44 +271,51 @@ void StreamPipeliner::createStreamCopy( loadOffsets[0] = extractIdx; auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); - auto subviewTy = tt::MemDescType::get( + auto subviewTy = ttg::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); - auto storeOp = - builder.create(loc, copy->getResult(0), viewLoad); // Clean up old local caches. SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { if (auto alloc = dyn_cast(user)) { - alloc.replaceAllUsesWith(viewLoad.getResult()); + triton::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); allocsToErase.push_back(alloc); } } for (auto alloc : allocsToErase) alloc.erase(); + // Prefetch load ahead of the dot stage if is used by the dot. + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + + // Create local load auto sharedLoad = builder.create(loc, loadOp.getType(), viewLoad); - auto result = sharedLoad->getResults(); - - // Create a select for non-zero other values. - Value other = loadOp.getOther(); - if (other && !isZeroConst(other)) { - auto select = builder.create( - loc, loadOp.getType(), mask, sharedLoad.getResult(), other); - result = select->getResults(); + Value result = sharedLoad.getResult(); + if (prefetch) + scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. + if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } - loadOp->replaceAllUsesWith(result); + loadOp->replaceAllUsesWith(ValueRange{result}); - // Prefetch load ahead of the dot stage if is used by the dot. - if (loadToInfo[loadOp].usedByDot) { - assert(numStages >= 2 && "requires num_stages=2 at least"); - schedule.insert(storeOp, numStages - 2, prefetchCluster); - schedule.insert(viewLoad, numStages - 2, prefetchCluster); + if (prefetch && result.hasOneUse()) { + if (auto cvt = dyn_cast(*result.getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); } + loadOp.erase(); } @@ -190,7 +330,7 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { if (user->getNumResults() != 1) return std::nullopt; if (auto memDesc = - dyn_cast(user->getResult(0).getType())) { + dyn_cast(user->getResult(0).getType())) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. tempAttr = cast(memDesc.getEncoding()); @@ -200,10 +340,11 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { if (!isa(user)) return std::nullopt; auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()).getEncoding()); + cast(user->getResult(0).getType()) + .getEncoding()); if (!dotOpEnc) return std::nullopt; - auto srcTy = cast(val.getType()); + auto srcTy = cast(val.getType()); auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); auto order = ttg::getOrder(srcTy.getEncoding()); unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); @@ -337,7 +478,7 @@ void StreamPipeliner::assignMemoryLayouts() { } } -void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { +LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Get all loads that are (transitively) used by dot ops and their distance // to the dot op. computeLoadOpsToIndirectionLevelAndUse(); @@ -350,12 +491,12 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { } }); if (loadOpToIndLevelAndUse.empty()) - return; + return failure(); // Check which loads are good for pipelining, and assign them memory layouts. assignMemoryLayouts(); if (loadToInfo.empty()) - return; + return failure(); // Filter out load ops that cannot be pipelined. int resize = 0; @@ -371,6 +512,12 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + LDBG("maxIndirectionLevel = " << maxIndirectionLevel); + if (maxIndirectionLevel >= numStages) + return failure(); + + initSchedule(maxIndirectionLevel); + // The stage gap between chained loads--this allows us to "spread" loads // with a non-one step in case the number of stages given by the user is // large. @@ -380,24 +527,18 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG("stagesBetweenLoads = " << stagesBetweenLoads); // Put the root uses of the loads in the last stage. - tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { - schedule.insert(use, numStages - 1, rootUsersCluster); + scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } } - // Create a cluster for load ops at each indirection level. - SmallVector loadsClusters; - for (int i = 0; i <= maxIndirectionLevel; i++) { - loadsClusters.push_back(schedule.clusters.newAtBack()); - } // Assign stages to the loads. for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - schedule.insert(loadOp, stage, loadsClusters[indLevel]); + scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); } // Calculate distance from the load to the use. @@ -413,6 +554,8 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG(" usedByDot: " << info.usedByDot); } }); + + return success(); } // Add dependencies of anchor ops to the coarse schedule. Schedule them to @@ -479,22 +622,23 @@ void StreamPipeliner::scheduleDistanceOneDependencies() { } } -void StreamPipeliner::scheduleRemainingToLastStage( - tt::CoarseSchedule::Cluster afterPrologue) { +void StreamPipeliner::scheduleRemainingToLastStage() { + int lastStage = numStages - 1; // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. DenseMap opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { - opToCluster[&op] = afterPrologue; + auto schedType = isa(op) ? SCHED_COMPUTE : SCHED_TAIL; + opToCluster[&op] = config[schedType].cluster; } } SmallVector queue; for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { // We really only care about the producers from the last stage. // Others will be scheduled before these ops anyway. - if (stage == numStages - 1) { + if (stage == lastStage) { queue.push_back(op); } } @@ -512,7 +656,7 @@ void StreamPipeliner::scheduleRemainingToLastStage( } } for (auto [op, cluster] : opToCluster) { - schedule.insert(op, numStages - 1, cluster); + schedule.insert(op, lastStage, cluster); } } @@ -526,11 +670,13 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, auto ty = cast(loadOp->getResultTypes()[0]); SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), numBuffers); - Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), - sharedEnc, sharedMemorySpace, - /*mutableMemory=*/true); - return builder.create(loadOp->getLoc(), memdescType, - Value()); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + auto alloc = + builder.create(loadOp->getLoc(), memdescType, Value()); + sharedMemAllocs.push_back(alloc); + return alloc; } // Convert load ops into shared memory allocation loads and apply @@ -538,19 +684,20 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, void StreamPipeliner::createStreamOps() { // Calculate the number of buffers needed for each load. // TODO: Use the precise number of buffers needed by the particular load. - int numBuffers = -1; - for (auto &[_, info] : loadToInfo) - numBuffers = std::max(numBuffers, info.distToUse); - LDBG("deduced shared memory buffer number = " << numBuffers); + int maxNumBuffers = -1; + for (auto &[_, info] : loadToInfo) { + int sharedBuffers = info.distToUse - (info.usedByDot ? prefetch : 0); + maxNumBuffers = std::max(maxNumBuffers, sharedBuffers); + } + LDBG("deduced max shared memory buffer number = " << maxNumBuffers); SmallVector> loadToAllocs; for (auto &[loadOp, info] : loadToInfo) { if (!info.sharedEncoding) continue; - Value alloc = createAlloc(loadOp, info.sharedEncoding, numBuffers); + Value alloc = createAlloc(loadOp, info.sharedEncoding, maxNumBuffers); assert(alloc && "Failed to create alloc for the async load."); - sharedMemAllocs.push_back(alloc); loadToAllocs.emplace_back(loadOp, alloc); } @@ -563,7 +710,7 @@ void StreamPipeliner::createStreamOps() { Value one = builder.create(loc, 1, 32); Value extractIdx = minusOne; Value numBuffersVal = - builder.create(loc, numBuffers, 32); + builder.create(loc, maxNumBuffers, 32); unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. @@ -582,24 +729,23 @@ void StreamPipeliner::createStreamOps() { extractIdx, numBuffersVal); extractIdx = builder.create(loc, cndExt, extractIdx, zero); - // Create a cluster for prefetching global reads for the dot. - tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); - + // Create stream copies. for (auto &[op, alloc] : loadToAllocs) { if (auto loadOp = dyn_cast(op)) - createStreamCopy(loadOp, alloc, extractIdx, prefetchCluster); + createStreamCopy(loadOp, alloc, extractIdx); } // Patch the yield with the updated counters. appendToForOpYield(forOp, {extractIdx}); } -bool StreamPipeliner::preprocessLoopAndBuildSchedule() { +LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; - scheduleLoads(rootUsers); + if (failed(scheduleLoads(rootUsers))) + return failure(); if (loadToInfo.empty()) - return false; + return failure(); LLVM_DEBUG({ LDBG("Coarse schedule loads only:"); @@ -609,13 +755,6 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { // Convert the loads into shared memory allocations and loads from them. createStreamOps(); - LLVM_DEBUG({ - LDBG("Coarse schedule with stream loads:"); - schedule.dump(); - }); - - tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); - scheduleDependencies(); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); @@ -628,7 +767,7 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { schedule.dump(); }); - scheduleRemainingToLastStage(afterPrologue); + scheduleRemainingToLastStage(); LLVM_DEBUG({ LDBG("Final coarse schedule:"); schedule.dump(); @@ -651,7 +790,18 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { // Explicitly deallocate created allocations. for (auto alloc : sharedMemAllocs) builder.create(forOp.getLoc(), alloc); - return true; + + return success(); +} + +LogicalResult StreamPipeliner::pipelineLoop() { + if (failed(preprocessLoopAndBuildSchedule())) + return failure(); + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return tt::pipelineForLoop(rewriter, forOp, options); } // Return true if the preconditions for pipelining the loop are met. @@ -671,35 +821,64 @@ static bool checkPrecondition(scf::ForOp forOp) { return !forOp->walk(hasNestedLoopInside).wasInterrupted(); } -bool StreamPipeliner::pipelineLoop() { - if (!checkPrecondition(forOp)) - return false; +namespace { +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. +template Operation *passPrevUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> Operation * { + if (auto defOp = value.getDefiningOp()) { + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + } + return nullptr; + }; - if (!preprocessLoopAndBuildSchedule()) - return false; - LDBG("Loop before sending to expander:\n" << *forOp); + auto unaryOp = getNextUnaryOps(value); + while (unaryOp) { + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + unaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return nullptr; +} - IRRewriter rewriter(forOp->getContext()); - rewriter.setInsertionPoint(forOp); - return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. +void labelLoadOpsForTritonDot(scf::ForOp forOp) { + mlir::MLIRContext *ctx = forOp->getContext(); + if (auto dotOp = triton::getSingleDotOpIfExists(forOp)) { + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + if (auto loadOp = passPrevUnaryOps(dotOperand)) { + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } } -namespace { -struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { +struct PipelinePass : public TritonAMDGPUStreamPipelineBase { PipelinePass() = default; - PipelinePass(int32_t numStages) { this->numStages = numStages; } + PipelinePass(int32_t numStages, int32_t prefetch) { + this->numStages = numStages; + this->prefetch = prefetch; + } void runOnOperation() override { SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { + labelLoadOpsForTritonDot(forOp); // Bail out for loops with num_stage <= 1. if (getNumStagesOrDefault(forOp) > 1) loops.push_back(forOp); }); for (scf::ForOp forOp : loops) { - StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp)); - sp.pipelineLoop(); + if (!checkPrecondition(forOp)) + continue; + StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp), prefetch); + if (failed(sp.pipelineLoop())) + continue; } } @@ -712,9 +891,9 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { return numStages; } }; -} // anonymous namespace +} // namespace -std::unique_ptr -mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { - return std::make_unique(numStages); +std::unique_ptr mlir::createTritonAMDGPUStreamPipelinePass(int numStages, + int prefetch) { + return std::make_unique(numStages, prefetch); } diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py new file mode 100644 index 000000000000..a9c7df4754b8 --- /dev/null +++ b/third_party/amd/python/test/test_extract_slice.py @@ -0,0 +1,115 @@ +import tempfile + +import numpy as np +import pytest +import torch + +import triton +import triton.language as tl + +from triton._internal_testing import is_hip + +num_ctas_list = [1] + +GPU_DIALECT = "ttg" + +if is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test extract slice +# ----------------------- + +extract_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("extract_layout", extract_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): + if not is_hip(): + pytest.skip("extract_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #extract_layout = {extract_layout} + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + extract_slice) + assert test_result diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a9f3a8ee2f60..8132773fc2a1 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -41,15 +41,17 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { [](mlir::PassManager &pm, const std::string &arch, bool ftz) { pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); }); - m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(createConvertBuiltinFuncToLLVMPass()); + m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm, bool ftz) { + pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); }); m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { - pm.addPass(createInsertInstructionSchedHintsPass()); + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", - [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createLowerInstructionSchedHintsPass(variant)); + [](mlir::PassManager &pm, const std::string &arch, int32_t numStages, + const std::string &variant) { + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass( + arch, numStages, variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) { @@ -66,10 +68,12 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUOptimizeEpiloguePass); ADD_PASS_WRAPPER_0("add_canonicalize_pointers", mlir::createTritonAMDGPUCanonicalizePointersPass); + ADD_PASS_WRAPPER_0("add_convert_to_buffer_ops", + mlir::createTritonAMDGPUConvertToBufferOpsPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_1("add_stream_pipelinev2", - mlir::createTritonAMDGPUStreamPipelineV2Pass, int); + ADD_PASS_WRAPPER_2("add_stream_pipeline", + mlir::createTritonAMDGPUStreamPipelinePass, int, int); } void addControlConstant(llvm::Module *module, const char *name, @@ -157,6 +161,24 @@ void init_triton_amd(py::module &&m) { module->eraseNamedMetadata(openclVersion); }); + m.def("disable_print_inline", [](llvm::Module *module) { + // List of functions name prefixes we want to forbid inline. + std::array prefixes = {"__ockl_fprintf", "__ockl_printf"}; + + for (llvm::Function &f : module->functions()) { + if (!f.hasName()) + continue; + llvm::StringRef name = f.getName(); + + auto isNamePrefixed = [&name](const char *prefix) { + return name.starts_with(prefix); + }; + + if (llvm::any_of(prefixes, isNamePrefixed)) + f.addFnAttr(llvm::Attribute::NoInline); + } + }); + m.def( "assemble_amdgcn", [](const std::string &assembly, const std::string &arch, diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt new file mode 100644 index 000000000000..ef558f0edae1 --- /dev/null +++ b/third_party/cpu/CMakeLists.txt @@ -0,0 +1,34 @@ +# libxsmm +include(xsmm) +message (STATUS "LIBXSMM Include dir: ${XSMM_INCLUDE_DIRS}") + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms MLIRAMXToLLVMIRTranslation TritonCPUXsmm) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms PRIVATE Python3::Module pybind11::headers) +endif() + +add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) +target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) + +add_library(TritonCPUXsmmRuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/lib/Xsmm/runtime/XsmmRunnerUtils.cpp) +target_link_libraries(TritonCPUXsmmRuntime PRIVATE xsmm) +set_property(TARGET TritonCPUXsmmRuntime PROPERTY CXX_STANDARD 11) +target_compile_definitions(TritonCPUXsmmRuntime PRIVATE mlir_c_runner_utils_EXPORTS) +target_include_directories(TritonCPUXsmmRuntime + PUBLIC + $ +) + +# Build and link sleef +set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE) +set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE) +set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE) +set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE) +set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will not be built." FORCE) +add_subdirectory("${CMAKE_SOURCE_DIR}/third_party/sleef" sleef) +# Override sleef's output directory with our own +set_target_properties(sleef PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/third_party/cpu/backend/__init__.py b/third_party/cpu/backend/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py new file mode 100644 index 000000000000..d3b69918d41c --- /dev/null +++ b/third_party/cpu/backend/compiler.py @@ -0,0 +1,315 @@ +import functools +import hashlib +import os +import tempfile +from pathlib import Path + +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Dict, Optional, Tuple + +from triton._C.libtriton import cpu, ir, llvm, passes +from triton.backends.compiler import BaseBackend, GPUTarget +from triton.runtime.build import _build +import triton.backends.cpu.driver as cpu_driver + + +def min_dot_size(target: GPUTarget): + # Other architectures will only support 16,16,16 + return lambda lhsType, rhsType: (4, 4, 4) + + +VecLib = cpu.passes.ttcpuir.VecLib + + +@dataclass(frozen=True) +class CPUOptions: + # GPU-specific options are used in several places. + # For now, we just provide dummy values. + backend_name: str = "cpu" + # These options provide compatibility with GPU kernel calls. + # All of them are ignored. + num_warps: int = 0 + num_stages: int = 0 + num_ctas: int = 0 + # Max number of threads to be used for a kernel call. + # Zero value is used to utilize all available CPU cores. + num_threads: int = 0 + cluster_dims: tuple = (1, 1, 1) + extern_libs: dict = None + debug: bool = False + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv") + deprecated_fp8_dtypes: Tuple[str] = () + allowed_dot_input_precisions: Tuple[str] = ("ieee", "tf32", "tf32x3") + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + enable_fp_fusion: bool = True + max_num_imprecise_acc_default: int = 0 + enable_fast_math: bool = True + enable_vector_xsmm: bool = False + enable_triton_xsmm: bool = False + enable_loop_brgemm_xsmm: bool = False + enable_raise_block_pointer: bool = False + vec_lib: Optional[str] = 'libsleef' + # TODO: Try to enable it. + sanitize_overflow: bool = False + + # TODO: We may introduce CPU-specific options like # of cores. + + def __post_init__(self): + pass + + def hash(self): + hash_dict = dict(self.__dict__) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def get_vec_lib(self) -> VecLib: + if self.vec_lib is None: + return None + # Parse enum from str here (instead of in parse_options()) because the options have to be JSON-serializable, + # and pybind enums are not serializable. + vec_lib = VecLib.__members__.get(self.vec_lib, None) + if vec_lib is None: + raise ValueError( + f"Unexpected value for vec_lib: {self.vec_lib}, should be one of {{{', '.join(VecLib.__members__.keys())}}}" + ) + return vec_lib + + +class CPUBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "cpu" + + def __init__(self, target: tuple) -> None: + super().__init__(target) + self.binary_ext = "so" + self.cpu_arch = llvm.get_cpu_tripple().split("-")[0] + self.cpu_name = llvm.get_cpu_name() + self.cpu_features = llvm.get_cpu_features() + if 'amx-tile' in self.cpu_features: + if not cpu.enable_amx(): + import warnings + warnings.warn("Warning! Couldn't enable AMX for the process. AMX optimizations are disabled.") + self.cpu_features.discard('amx-tile') + self.cpu_features.discard('amx-int8') + self.cpu_features.discard('amx-fp16') + self.cpu_features.discard('amx-bf16') + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} + if "enable_fast_math" not in args: + args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "1") != "0" + if "supported_fp8_dtypes" not in args: + supported_fp8_dtypes = set(CPUOptions.supported_fp8_dtypes) + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + if "enable_vector_xsmm" not in args: + args["enable_vector_xsmm"] = os.getenv("TRITON_CPU_VECTOR_XSMM", "0") != "0" + if "enable_triton_xsmm" not in args: + args["enable_triton_xsmm"] = os.getenv("TRITON_CPU_TRITON_XSMM", "0") != "0" + if "enable_loop_brgemm_xsmm" not in args: + args["enable_loop_brgemm_xsmm"] = os.getenv("TRITON_CPU_LOOP_BRGEMM_XSMM", "0") != "0" + if "enable_raise_block_pointer" not in args: + args["enable_raise_block_pointer"] = os.getenv("TRITON_CPU_RAISE_BLOCK_POINTER", "0") != "0" + return CPUOptions(**args) + + def pack_metadata(self, metadata): + return metadata + + def get_codegen_implementation(self): + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + return codegen_fns + + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.cpu import libdevice + return {"triton.language.extra.libdevice": libdevice} + + def load_dialects(self, ctx): + cpu.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + # This is the same as the Nvidia backend. + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttcir(mod, metadata, opt): + # TTIR -> TTCIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if opt.enable_raise_block_pointer: + cpu.passes.ttcpuir.add_raise_block_pointer(pm) + if opt.enable_loop_brgemm_xsmm: + cpu.passes.ttcpuir.add_loop_to_brgemm_xsmm(pm) + passes.common.add_canonicalizer(pm) + if opt.enable_triton_xsmm: + cpu.passes.ttcpuir.add_convert_triton_to_xsmm(pm) + passes.common.add_canonicalizer(pm) + cpu.passes.ttcpuir.add_scalarize(pm, True) + cpu.passes.ttcpuir.add_convert_memory_ops(pm, True) + cpu.passes.ttcpuir.add_convert_ptr_ops(pm) + cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) + cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) + cpu.passes.ttcpuir.add_convert_dot_op(pm) + cpu.passes.ttcpuir.add_convert_histogram_op(pm) + cpu.passes.ttcpuir.add_convert_reduction_op(pm, True, False) + cpu.passes.ttcpuir.add_convert_scan_op(pm) + cpu.passes.ttcpuir.add_convert_cf_ops(pm) + cpu.passes.ttcpuir.add_convert_atomic_ops(pm) + cpu.passes.ttcpuir.add_convert_debug_ops(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) + return mod + + def make_tttcir(self, mod, metadata, opt): + # TTCIR -> Target TTCIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + cpu.passes.ttcpuir.add_triton_cpu_canonicalizer(pm) + cpu.passes.ttcpuir.add_optimize_masks(pm) + passes.common.add_canonicalizer(pm) + convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8") + and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features) + if convert_bf16_dot_product: + use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1" + cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum) + if 'amx-tile' in self.cpu_features: + amx_int8 = 'amx-int8' in self.cpu_features + # amx_fp16 = 'amx-fp16' in self.cpu_features + # FP16 support is not in AMX dialect yet + amx_fp16 = False + amx_bf16 = 'amx-bf16' in self.cpu_features + cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, amx_int8, amx_fp16, amx_bf16) + if 'avx512f' in self.cpu_features: + cpu.passes.ttcpuir.add_convert_dot_to_fma(pm) + cpu.passes.ttcpuir.add_convert_dot_generic(pm) + promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features + # We don't have any lowering for mixed precision matmuls, so always use casts for now + convert_mixed_precision_matmul = True + # We don't have math lib functions for FP8, FP16, BF16. Promote such operations to FP32. + promote_lib_math_to_fp32 = True + cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul, + promote_lib_math_to_fp32) + decompose_bf16_conv = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features + decompose_fp8_conv = True + cpu.passes.ttcpuir.add_decompose_fp_conversions(pm, decompose_bf16_conv, decompose_fp8_conv) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + def make_llir(self, src, metadata, options): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + metadata["threads_per_warp"] = 1 + mod = src + # TritonCPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if options.enable_vector_xsmm: + cpu.passes.ttcpuir.add_convert_vector_to_xsmm(pm) + cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) + cpu.passes.ttcpuir.add_expand_strided_metadata(pm) + cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) + cpu.passes.ttcpuir.add_lower_affine(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + cpu.passes.ttcpuir.add_func_op_to_llvmir(pm) + cpu.passes.ttcpuir.add_program_id_to_llvmir(pm) + cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) + cpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) + cpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) + + vec_lib_requirements = { + VecLib.libsleef: {"neon", "sse", "avx"}, + VecLib.libmvec: {"avx512f"}, + } + if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features: + cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, self.cpu_features) + + passes.convert.add_math_to_llvmir(pm) + cpu.passes.ttcpuir.add_math_to_libm(pm) + cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) + cpu.passes.ttcpuir.add_memref_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + cpu.passes.ttcpuir.add_func_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + + # Find kernel fn + kernel_names = cpu.find_kernel_names(mod) + assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" + + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + if llvm_mod is None: + raise RuntimeError("Failed to convert to LLVM IR") + llvm.set_host_target(llvm_mod) + #if options.extern_libs: + # paths = [path for (name, path) in options.extern_libs] + # llvm.link_extern_libs(llvm_mod, paths) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + # Get some metadata + metadata["shared"] = 0 + metadata["name"] = kernel_names[0] + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_asm(src, metadata, options): + return llvm.translate_to_host_asm(src, options.enable_fp_fusion, options.enable_fast_math) + + @staticmethod + def make_so(src, metadata, options): + with tempfile.TemporaryDirectory() as tmpdir: + asm_path = os.path.join(tmpdir, "kernel.s") + Path(asm_path).write_text(src) + lib_dirs = cpu_driver.library_dirs + libs = ["gcc", "m", "TritonCPURuntime", "sleef"] + if options.enable_vector_xsmm or options.enable_triton_xsmm or options.enable_loop_brgemm_xsmm: + libs.extend(["xsmm", "TritonCPUXsmmRuntime"]) + so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) + with open(so, "rb") as f: + return f.read() + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) + stages["tttcir"] = lambda src, metadata: self.make_tttcir(src, metadata, options) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) + stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) + stages["so"] = lambda src, metadata: self.make_so(src, metadata, options) + + @functools.lru_cache() + def hash(self): + # TODO: Get more detailed CPU info like raw brand name with supported ISAs. + # Right now it would only return a simple string like "x86_64" or "aarch64". + import platform + + return f"{platform.machine()}" diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py new file mode 100644 index 000000000000..3308fd23c680 --- /dev/null +++ b/third_party/cpu/backend/driver.py @@ -0,0 +1,469 @@ +import os +import hashlib +import importlib +import importlib.resources +import tempfile +import time + +import triton +import triton._C +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase +from triton.backends.compiler import GPUTarget + +from pathlib import Path +from triton._C.libtriton import llvm + +_dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") +# for locating libTritonCPURuntime +try: + _triton_C_dir = importlib.resources.files(triton).joinpath("_C") +except AttributeError: + # resources.files() doesn't exist for Python < 3.9 + _triton_C_dir = importlib.resources.path(triton, "_C").__enter__() + +include_dirs = [] +library_dirs = [_triton_C_dir] +libraries = ["stdc++"] + +# Skip non-existent paths +sys_include_dir = os.path.join(_dirname, "include") +if os.path.exists(sys_include_dir): + include_dirs.append(sys_include_dir) + +sys_lib_dir = os.path.join(_dirname, "lib") +if os.path.exists(sys_lib_dir): + library_dirs.append(sys_lib_dir) + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dirs, include_dirs, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class CPUUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CPUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + pass + + def load_binary(self, name, kernel, shared_mem, device): + with tempfile.NamedTemporaryFile(mode="wb", suffix=".so") as f: + f.write(kernel) + f.flush() + import ctypes + lib = ctypes.cdll.LoadLibrary(f.name) + fn_ptr = getattr(lib, name) + fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value + return (lib, fn_ptr_as_void_p, 0, 0) + + def get_device_properties(self, *args): + return {"max_shared_mem": 0} + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOKOOOO" + args_format + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) + kernel_fn_args = [i for i in signature.keys() if i not in constants] + kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) + kernel_fn_arg_types = ', '.join([f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args] + ["uint32_t"] * 6) + + # generate glue code + src = f""" +#include +#include +#include +#include +#include +#include +#ifdef _OPENMP +#include +#endif // _OPENMP +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +inline bool getBoolEnv(const std::string &env) {{ + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) {{ return std::tolower(c); }}); + return str == "on" || str == "true" || str == "1"; +}} + +inline std::optional getIntEnv(const std::string &env) {{ + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return std::nullopt; + + char *endptr; + long int result = std::strtol(cstr, &endptr, 10); + if (endptr == cstr) + assert(false && "invalid integer"); + return result; +}} + +using kernel_ptr_t = void(*)({kernel_fn_arg_types}); + +typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static std::unique_ptr get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{ + std::unique_ptr grids(new uint32_t[gridX * gridY * gridZ][3]); + // TODO: which order would be more effective for cache locality? + for (uint32_t z = 0; z < gridZ; ++z) {{ + for (uint32_t y = 0; y < gridY; ++y) {{ + for (uint32_t x = 0; x < gridX; ++x) {{ + grids[z * gridY * gridX + y * gridX + x][0] = x; + grids[z * gridY * gridX + y * gridX + x][1] = y; + grids[z * gridY * gridX + y * gridX + x][2] = z; + }} + }} + }} + return grids; +}} + +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_threads, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: Consider using omp collapse(3) clause for simplicity? + size_t N = gridX * gridY * gridZ; + if (N == 1) {{ + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} 0, 0, 0, 1, 1, 1); + return; + }} + + auto all_grids = get_all_grids(gridX, gridY, gridZ); + int omp_max_threads = 1; + #ifdef _OPENMP + omp_max_threads = omp_get_max_threads(); + #endif // _OPENMP + int max_threads = (num_threads > 0) ? num_threads : omp_max_threads; + + // Don't pay OMP overhead price when a single thread is used. + if (max_threads == 1) {{ + for (size_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); + }} + return; + }} + + // For now, use the default chunk size, total iterations / max_threads. +#ifdef _OPENMP +#pragma omp parallel for schedule(static) num_threads(max_threads) +#endif // _OPENMP + for (size_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *py_obj_stream; + void* pKrnl; + + {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ + return NULL; + }} + + void *pStream = PyLong_AsVoidPtr(py_obj_stream); + kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); + + // Extract num_threads metadata. + int num_threads = 0; + PyObject *num_threads_attr = PyObject_GetAttrString(kernel_metadata, "num_threads"); + if (num_threads_attr && PyLong_Check(num_threads_attr)) + num_threads = PyLong_AsLong(num_threads_attr); + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + run_omp_kernels(gridX, gridY, gridZ, num_threads, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_cpu_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class CPULauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_cpu_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class CPUDeviceInterface: + + class HooksTimeAccessor: + + def __init__(self, di): + self.di = di + self.record_idx = 0 + + def elapsed_time(self, end_event) -> float: + total_time = 0 + for i in range(self.record_idx, end_event.record_idx): + total_time += self.di.kernel_times[i] + return total_time * 1000 + + def record(self): + self.record_idx = len(self.di.kernel_times) + + class TimerEvent: + + def __init__(self): + self.timer = 0 + + def elapsed_time(self, end_event) -> float: + return (end_event.timer - self.timer) * 1000 + + def record(self): + self.timer = time.perf_counter() + + def __init__(self): + self.kernel_times = [] + self.last_start = 0 + self.use_hooks = False + triton.compiler.CompiledKernel.launch_enter_hook = None + triton.compiler.CompiledKernel.launch_exit_hook = None + + def enable_hook_timing(self): + self.use_hooks = True + triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook() + triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook() + + def synchronize(self): + pass + + def _enter_hook(self): + self.last_start = time.perf_counter() + + def _exit_hook(self): + self.kernel_times.append(time.perf_counter() - self.last_start) + + def Event(self, enable_timing=True): + if self.use_hooks: + return CPUDeviceInterface.HooksTimeAccessor(self) + return CPUDeviceInterface.TimerEvent() + + +class CPUDriver(DriverBase): + + def __init__(self): + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + super().__init__() + + def get_current_device(self): + return 0 + + def get_current_stream(self, device): + return 0 + + def get_current_target(self): + # Capability and warp size are zeros for CPU. + # TODO: GPUTarget naming isn't obviously good. + cpu_arch = llvm.get_cpu_tripple().split("-")[0] + return GPUTarget("cpu", cpu_arch, 0) + + def get_device_interface(self): + return CPUDeviceInterface() + + @staticmethod + def is_active(): + return True + + def get_benchmarker(self): + from triton.testing import do_bench + + def do_bench_cpu(*args, **kwargs): + if not 'measure_time_with_hooks' in kwargs: + kwargs['measure_time_with_hooks'] = True + return do_bench(*args, **kwargs) + + return do_bench_cpu + + def get_empty_cache_for_benchmark(self): + import torch + + # A typical LLC size for high-end server CPUs are ~400MB. + cache_size = 512 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cpu') diff --git a/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h new file mode 100644 index 000000000000..838ecebb6add --- /dev/null +++ b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h @@ -0,0 +1,107 @@ +#ifndef TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H +#define TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton::cpu { + +// Lattice value to hold a shape and strides for a tensor pointer. +// If multiple size or stride values are possible for some dimension +// then ShapedType::kDynamic is used for that dimension. +class TensorPtrShapeInfo { +public: + TensorPtrShapeInfo() = default; + + TensorPtrShapeInfo(ArrayRef shape, ArrayRef strides) + : shape(shape), strides(strides) { + assert(shape.size() == strides.size()); + } + + ArrayRef getShape() const { return shape; } + ArrayRef getStrides() const { return strides; } + + int64_t getRank() const { return static_cast(shape.size()); } + int64_t getSize(int64_t dim) const { return shape[dim]; } + int64_t getStride(int64_t dim) const { return strides[dim]; } + + bool operator==(const TensorPtrShapeInfo &other) const { + return shape == other.shape && strides == other.strides; + } + + static TensorPtrShapeInfo join(const TensorPtrShapeInfo &lhs, + const TensorPtrShapeInfo &rhs); + + static TensorPtrShapeInfo getPessimisticValueState(Value value); + + void print(raw_ostream &os) const { + os << "shape = ["; + llvm::interleaveComma(shape, os); + os << "], strides = ["; + llvm::interleaveComma(strides, os); + os << "]"; + } + +private: + SmallVector shape; + SmallVector strides; +}; + +using TensorPtrShapeInfoMapT = DenseMap; +class ModuleTensorPtrShapeInfoAnalysis + : public CallGraph { +public: + explicit ModuleTensorPtrShapeInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, TensorPtrShapeInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = dyn_cast(callOp.resolveCallable()); + update(callOp, callee); + }); + } + } + + TensorPtrShapeInfo *getPtrShapeInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton::cpu + +#endif // TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt new file mode 100644 index 000000000000..e8f4b7574f1d --- /dev/null +++ b/third_party/cpu/include/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(ScalarizePass) +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) +add_subdirectory(TritonRaiseBlockPointer) +add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/include/ScalarizePass/CMakeLists.txt b/third_party/cpu/include/ScalarizePass/CMakeLists.txt new file mode 100644 index 000000000000..4af0f9490fb9 --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS ScalarizeInterface.td) +mlir_tablegen(ScalarizeInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(ScalarizeInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ScalarizeInterfaceIncGen) diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h new file mode 100644 index 000000000000..1b16ff935540 --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h @@ -0,0 +1,33 @@ +#ifndef MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ +#define MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/IR/OpDefinition.h" + +/// Include the ODS generated interface header files. +#include "cpu/include/ScalarizePass/ScalarizeInterface.h.inc" + +namespace mlir { +namespace triton { +namespace cpu { + +mlir::Value computeScalarValue(mlir::Operation *scalarizationOp, + mlir::Value vals, + mlir::ArrayRef indices, + mlir::PatternRewriter &rewriter); + +mlir::Value computeScalarValue(mlir::Operation *scalarizationOp, + mlir::Value vals, mlir::ValueRange indices, + mlir::PatternRewriter &rewriter); + +bool canComputeScalarValue(mlir::Value vals); +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif // MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td new file mode 100644 index 000000000000..7e6c4acecbcb --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td @@ -0,0 +1,52 @@ +#ifndef MLIR_SCALARIZEINTERFACE +#define MLIR_SCALARIZEINTERFACE + +include "mlir/IR/OpBase.td" + +def ScalarizeInterface : OpInterface<"ScalarizeInterface"> { + let description = [{ + Interface for allowing operations to expose information needed to + scalarize them or in simpler terms inserts SCF loops to reduce amount of + generated ir. Similar with checking operands of specific operations for + constancy - to understand is it possible to put it inside of loop's body. + }]; + let cppNamespace = "mlir::triton::cpu"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Checks operand and is ScalarizeInterface registered for this operation. + }], + /*retType=*/"bool", + /*methodName=*/"canComputeScalarValue", + /*args=*/(ins + "mlir::Value ":$vals) + >, + InterfaceMethod< + /*desc=*/[{ + Returns value that can be put inside of generated cycle and creates required constants. + Can go throught operands to check type of passed values. Implementation for static indeces. + }], + /*retType=*/"mlir::Value", + /*methodName=*/"computeScalarValue", + /*args=*/(ins + "mlir::Value ":$vals, + "mlir::ArrayRef ":$indices, + "mlir::PatternRewriter &":$rewriter) + >, + InterfaceMethod< + /*desc=*/[{ + Returns value that can be put inside of generated cycle and creates required constants. + Can go throught operands to check type of passed values. Implementation for dynamic indices + which is in common used in loops to iterate with Inductional Variable. + }], + /*retType=*/"mlir::Value", + /*methodName=*/"computeScalarValueForLoop", + /*args=*/(ins + "mlir::Value ":$vals, + "mlir::ValueRange ":$indices, + "mlir::PatternRewriter &":$rewriter) + > + ]; +} + +#endif // MLIR_SCALARIZEINTERFACE diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h b/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h new file mode 100644 index 000000000000..ab2730d9acf8 --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h @@ -0,0 +1,16 @@ +#ifndef MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H +#define MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace triton { +namespace cpu { + +void registerTritonOpScalarizeExternalModels(DialectRegistry ®istry); + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif // MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..0936dff12d91 --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) +add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h new file mode 100644 index 000000000000..cc29821c580c --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H +#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +enum class VecLib { + Mvec, + Sleef, +}; + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +std::unique_ptr> createFuncOpToLLVMPass(); +std::unique_ptr> createMemoryOpToLLVMPass(); +std::unique_ptr> createGetProgramIdOpToLLVMPass(); +std::unique_ptr> createLowerMultiReductionPass(); +std::unique_ptr> createAtomicOpsToLLVMPass(); +std::unique_ptr> createDebugOpsToLLVMPass(); +std::unique_ptr> +createMathToVecLibPass(VecLib lib = VecLib::Sleef, + std::set cpu_features = {}); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td new file mode 100644 index 000000000000..3ee08d9968b2 --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -0,0 +1,88 @@ +#ifndef TRITONCPU_CONVERSION_PASSES +#define TRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert FuncOp to LLVM for CPU."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton memory operations to LLVM for CPU."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton GetProgramId to LLVM for CPU."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect"]; +} + +def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton::FuncOp"> { + let summary = "Convert multi-dimensional reductions."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createLowerMultiReductionPass()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + +def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton atomic operations to LLVM."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + +def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton debug operations (prints and asserts) to LLVM."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createDebugOpsToLLVMPass()"; + + let dependentDialects = ["mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + +def MathToVecLib : Pass<"triton-cpu-math-to-vec-lib", "mlir::ModuleOp"> { + let summary = "Convert vector math operations to vector libm or sleef calls."; + let description = [{ + }]; + let constructor = "mlir::triton::cpu::createMathToVecLibPass()"; + + let options = [ + Option<"lib", "lib", + "mlir::triton::cpu::VecLib", /*default*/"mlir::triton::cpu::VecLib::Sleef", + "Library to use for vector math (libsleef or libmvec).">, + ]; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect", + "mlir::func::FuncDialect", + "mlir::LLVM::LLVMDialect"]; +} + +#endif diff --git a/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..cb2cb234172d --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUTransforms) +add_public_tablegen_target(TritonCPUTransformsPassIncGen) diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h new file mode 100644 index 000000000000..c3fe3973ce0b --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -0,0 +1,185 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H +#define TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace triton { +namespace cpu { + +inline Type getElemTyOrTy(Type ty) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.getElementType(); + return ty; +} + +inline bool isTyOrVectorOf(Type ty, Type elemTy) { + return getElemTyOrTy(ty) == elemTy; +} + +inline bool isBf16(Type ty) { + return isTyOrVectorOf(ty, BFloat16Type::get(ty.getContext())); +} + +inline bool isFp16(Type ty) { + return isTyOrVectorOf(ty, Float16Type::get(ty.getContext())); +} + +inline bool isFp32(Type ty) { + return isTyOrVectorOf(ty, Float32Type::get(ty.getContext())); +} + +inline bool isFp8(Type ty) { + Type elemTy = getElemTyOrTy(ty); + if (elemTy.isIntOrFloat() && !elemTy.isInteger()) + return elemTy.getIntOrFloatBitWidth() == 8; + return false; +} + +inline Type toTyOrVectorOf(Type ty, Type elemTy) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.cloneWith(std::nullopt, elemTy); + return elemTy; +} + +inline Type toInt8(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 8)); +} + +inline Type toInt16(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 16)); +} + +inline Type toInt32(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 32)); +} + +inline Type toInt64(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 64)); +} + +inline Type toFp8E5M2(Type ty) { + return toTyOrVectorOf(ty, Float8E5M2Type::get(ty.getContext())); +} + +inline Type toFp16(Type ty) { + return toTyOrVectorOf(ty, Float16Type::get(ty.getContext())); +} + +inline Type toBf16(Type ty) { + return toTyOrVectorOf(ty, BFloat16Type::get(ty.getContext())); +} + +inline Type toFp32(Type ty) { + return toTyOrVectorOf(ty, Float32Type::get(ty.getContext())); +} + +inline Value intCst(Location loc, Type ty, int64_t val, + PatternRewriter &rewriter) { + TypedAttr valAttr = IntegerAttr::get(getElemTyOrTy(ty), val); + if (auto vecTy = dyn_cast(ty)) + valAttr = SplatElementsAttr::get(vecTy, valAttr); + return rewriter.create(loc, valAttr); +} + +inline Value fpCst(Location loc, Type ty, double val, + PatternRewriter &rewriter) { + TypedAttr valAttr = FloatAttr::get(getElemTyOrTy(ty), val); + if (auto vecTy = dyn_cast(ty)) + valAttr = SplatElementsAttr::get(vecTy, valAttr); + return rewriter.create(loc, valAttr); +} + +template ::value, bool> = true> +Value cstLike(Location loc, Value tySrc, T val, PatternRewriter &rewriter) { + return intCst(loc, tySrc.getType(), val, rewriter); +} + +template ::value, bool> = true> +Value cstLike(Location loc, Value tySrc, T val, PatternRewriter &rewriter) { + return fpCst(loc, tySrc.getType(), val, rewriter); +} + +inline Value shapeCast(Location loc, Value in, VectorType outTy, + PatternRewriter &rewriter) { + VectorType inTy = cast(in.getType()); + assert(outTy.getElementType() == inTy.getElementType()); + assert(outTy.getNumElements() == inTy.getNumElements()); + return rewriter.create(loc, outTy, in); +} + +inline Value shapeCast(Location loc, Value in, + std::initializer_list shapes, + PatternRewriter &rewriter) { + VectorType inTy = cast(in.getType()); + VectorType outTy = VectorType::get(shapes, inTy.getElementType()); + return shapeCast(loc, in, outTy, rewriter); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#define int_cst(ty, val) intCst(loc, ty, val, rewriter) +#define index_cst(val) rewriter.create(loc, val) +#define cst_like(src, val) cstLike(loc, src, val, rewriter) + +#define op_addi(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_addf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_subi(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_subf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_muli(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_mulf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_bitcast(ty, val) rewriter.create(loc, ty, val) +#define op_lshr(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_shl(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_trunci(ty, val) rewriter.create(loc, ty, val) +#define op_zext(ty, val) rewriter.create(loc, ty, val) +#define op_sext(ty, val) rewriter.create(loc, ty, val) +#define op_and(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_or(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_minui(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_maxui(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_select(cond, val, other) \ + rewriter.create(loc, cond, val, other) +#define op_sitofp(ty, val) rewriter.create(loc, ty, val) +#define op_fptosi(ty, val) rewriter.create(loc, ty, val) +#define op_read(ty, memRef, indices) \ + rewriter.create( \ + loc, ty, memRef, indices, SmallVector(ty.getRank(), true)) +#define op_write(val, memRef, indices) \ + rewriter.create( \ + loc, val, memRef, indices, \ + SmallVector(cast(val.getType()).getRank(), true)) +#define op_interleave(lhs, rhs) \ + rewriter.create(loc, lhs, rhs) +#define op_extract(vec, idx) rewriter.create(loc, vec, idx) +#define op_store(val, mem, idx) \ + rewriter.create(loc, val, mem, idx) + +#define op_icmp_eq(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::eq, lhs, rhs) +#define op_icmp_ne(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ne, lhs, rhs) +#define op_icmp_ugt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ugt, lhs, rhs) +#define op_icmp_uge(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::uge, lhs, rhs) +#define op_icmp_ult(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ult, lhs, rhs) +#define op_icmp_ule(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ule, lhs, rhs) +#define op_icmp_sgt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sgt, lhs, rhs) +#define op_icmp_sge(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sge, lhs, rhs) +#define op_icmp_slt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::slt, lhs, rhs) +#define op_icmp_sle(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sle, lhs, rhs) + +#endif diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h new file mode 100644 index 000000000000..f0c7a777e5fa --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -0,0 +1,51 @@ +#ifndef TritonCPUTransforms_CONVERSION_PASSES_H +#define TritonCPUTransforms_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +std::unique_ptr> createConvertUnsupportedOps(); +std::unique_ptr> +createConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32); +std::unique_ptr> createDecomposeFpConversions(); +std::unique_ptr> +createDecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions); +std::unique_ptr> createOptimizeMasks(); + +std::unique_ptr> createConvertDotProduct(); +std::unique_ptr> +createConvertDotProduct(bool useHorizontalSum); + +std::unique_ptr> createConvertDotToAMX(); +std::unique_ptr> +createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16); +std::unique_ptr> createConvertDotToFMA(); +std::unique_ptr> createConvertDotGeneric(); +std::unique_ptr> createCanonicalize(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td new file mode 100644 index 000000000000..00c01a4725ce --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -0,0 +1,165 @@ +#ifndef TRITONCPUOPT_CONVERSION_PASSES +#define TRITONCPUOPT_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "mlir::ModuleOp"> { + let summary = "Convert operations on unsupported types."; + let description = [{ + This pass converts various operations on data types that are not supported + by the target natively. Operations are converted to a supported data type + with casts added for inputs and the result. + }]; + + let options = [ + Option<"promoteBf16ToFp32", "promote-bf16-to-fp32", + "bool", /*default*/"false", + "Convert BF16 operations to FP32.">, + Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul", + "bool", /*default*/"false", + "Convert inputs of a mixed-precision matmul to a destination type.">, + Option<"promoteLibMathToFp32", "promote-lib-math-to-fp32", + "bool", /*default*/"true", + "Promote FP8, FP16, BF16 math operations mapped to libm function to FP32.">, + ]; + + let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def DecomposeFpConversions : Pass<"triton-cpu-decompose-fp-conversions", "mlir::ModuleOp"> { + let summary = "Decompose fp conversion ops."; + let description = [{ + This pass is used for targets lacking native instructions to convert FP + vectors. By default, LLVM would decompose them using scalar FP conversion + intrinsics. This pass transforms such conversions into vector code + instead. + }]; + + let options = [ + Option<"decomposeBf16Conversions", "decompose-bf16-conversions", + "bool", /*default*/"false", + "Lower BF16 conversions to arith operations.">, + Option<"decomposeFp8Conversions", "decompose-fp8-conversions", + "bool", /*default*/"false", + "Lower FP8 conversions to arith operations.">, + ]; + + let constructor = "mlir::triton::cpu::createDecomposeFpConversions()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def OptimizeMasks : Pass<"triton-cpu-optimize-masks", "mlir::ModuleOp"> { + let summary = "Optimize masked memory accesses."; + let description = [{ + This pass tries to detect masked memory accesses with mask values that + can be proven to be all-ones or all-zeros. + }]; + + let options = [ + ]; + + let constructor = "mlir::triton::cpu::createOptimizeMasks()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotProduct : Pass<"triton-cpu-convert-dot-product", "mlir::ModuleOp"> { + let summary = "Convert dot product op."; + let description = [{ + This pass is used for indentifying dot product pattern + (for example, elementwise mul followed by a sum) and + converting it to dot product intrinsics like bfdot. + }]; + + let options = [ + Option<"useHorizontalSum", "use-horizontal-sum", + "bool", /*default*/"true", + "Use Horizontal Sum kernel for the dot product (gemv). Otherwise use a kernel with packing.">, + ]; + + let constructor = "mlir::triton::cpu::createConvertDotProduct()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotToAMX : Pass<"triton-cpu-convert-dot-to-amx", "mlir::ModuleOp"> { + let summary = "Convert dot product op to AMX dialect."; + let description = [{ + This pass is used to lower matmul operations to amx dialect. + }]; + + let options = [ + Option<"convertInt8", "convert-i8", + "bool", /*default*/"false", + "Use AMX extensions for int8 type.">, + Option<"convertFp16", "convert-fp16", + "bool", /*default*/"false", + "Use AMX extensions for ifp16 type.">, + Option<"convertBf16", "convert-bf16", + "bool", /*default*/"false", + "Use AMX extensions for bf16 type.">, + ]; + + let constructor = "mlir::triton::cpu::createConvertDotToAMX()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::amx::AMXDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotToFMA : Pass<"triton-cpu-convert-dot-to-fma", "mlir::ModuleOp"> { + let summary = "Decompose dot product op to a series of FMA operations."; + let description = [{ }]; + + let constructor = "mlir::triton::cpu::createConvertDotToFMA()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotGeneric : Pass<"triton-cpu-convert-dot-generic", "mlir::ModuleOp"> { + let summary = "Generic convertion of dot product op."; + let description = [{ + This pass is used to lower matmul operations to generic vector code. + }]; + + let constructor = "mlir::triton::cpu::createConvertDotGeneric()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def Canonicalize : Pass<"triton-cpu-canonicalize", "mlir::ModuleOp"> { + let summary = "Canonicalization pass."; + let description = [{ + This pass applies various foldings to simplify analysis and transformations + in optimization passes. + }]; + + let constructor = "mlir::triton::cpu::createCanonicalize()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +#endif diff --git a/third_party/cpu/include/TritonRaiseBlockPointer/CMakeLists.txt b/third_party/cpu/include/TritonRaiseBlockPointer/CMakeLists.txt new file mode 100644 index 000000000000..86bab632f114 --- /dev/null +++ b/third_party/cpu/include/TritonRaiseBlockPointer/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonRaiseBlockPointer) +add_public_tablegen_target(TritonRaiseBlockPointerPassIncGen) diff --git a/third_party/cpu/include/TritonRaiseBlockPointer/Passes.h b/third_party/cpu/include/TritonRaiseBlockPointer/Passes.h new file mode 100644 index 000000000000..06d2d7b4fc8a --- /dev/null +++ b/third_party/cpu/include/TritonRaiseBlockPointer/Passes.h @@ -0,0 +1,25 @@ +//===- Passes.h - Triton to Block Pointer Pass ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_RAISE_BLOCK_POINTER_PASSES_H +#define TRITON_RAISE_BLOCK_POINTER_PASSES_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::cpu { +#define GEN_PASS_DECL +#include "cpu/include/TritonRaiseBlockPointer/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonRaiseBlockPointer/Passes.h.inc" +} // namespace mlir::triton::intel + +#endif // TRITON_RAISE_BLOCK_POINTER_PASSES_H diff --git a/third_party/cpu/include/TritonRaiseBlockPointer/Passes.td b/third_party/cpu/include/TritonRaiseBlockPointer/Passes.td new file mode 100644 index 000000000000..c8e429672a27 --- /dev/null +++ b/third_party/cpu/include/TritonRaiseBlockPointer/Passes.td @@ -0,0 +1,30 @@ +//===-- Passes.td - Triton to Block Pointer Passes ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_RAISE_BLOCK_POINTER_PASSES +#define TRITON_RAISE_BLOCK_POINTER_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonRaiseBlockPointer + : Pass<"triton-cpu-raise-block-pointer", "mlir::ModuleOp"> { + let summary = "Convert Triton non-block pointer to block pointer"; + let description = [{ + Pass to raise different patterns operating on pointers to block pointers. + The basic idea of this pass is to convert code operating on tensors of + pointers (`tensor<...x!tt.ptr>`) to pointers to tensors + (`!tt.ptr>`) using `tt.make_block_ptr` operations. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + ]; +} + +#endif // TRITON_RAISE_BLOCK_POINTER_PASSES diff --git a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..56e231273ed6 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) +add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h new file mode 100644 index 000000000000..cd0babee3de4 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -0,0 +1,128 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H +#define TRITONTOTRITONCPU_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/AxisInfo.h" +#include "llvm/ADT/TypeSwitch.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertElemManipOps(); +std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> +createConvertMemoryOps(bool useGatherScatter); +std::unique_ptr> createConvertPtrOps(); +std::unique_ptr> createConvertDotOp(); +std::unique_ptr> createConvertControlFlowOps(); +std::unique_ptr> createConvertHistogramOp(); +std::unique_ptr> createConvertReductionOp(); +std::unique_ptr> +createConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp); +std::unique_ptr> createConvertScanOp(); +std::unique_ptr> createConvertAtomicOps(); +std::unique_ptr> createConvertDebugOps(); + +std::unique_ptr> createScalarizeUsingForOpPass(); +std::unique_ptr> +createScalarizeUsingForOpPass(bool skipGatherScatter); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +template +constexpr bool is_one_of_v = (std::is_same_v || ...); + +template +constexpr bool is_memory_op_v = + is_one_of_v; + +inline mlir::Type getMemoryOpType(triton::LoadOp operation) { + return operation.getType(); +} + +inline mlir::Type getMemoryOpType(triton::StoreOp operation) { + return operation.getValue().getType(); +} + +inline ArrayRef getShape(mlir::Type type) { + return llvm::TypeSwitch>(type) + .Case([](ShapedType t) { return t.getShape(); }) + .Case([](RankedTensorType t) { return t.getShape(); }) + .Default([](Type t) { + llvm::errs() << "Attempt to getShape from unknow type: " << t << "\n"; + llvm_unreachable("Unsupported type in getShape"); + return ArrayRef(); + }); +} + +inline bool hasShape(mlir::Type type) { + return isa(type); +} + +template , bool> = true> +bool isContiguousRowMajorAccess(AxisInfo *axisInfo, OpTy op) { + if (!axisInfo) + return false; + + mlir::Type type = getMemoryOpType(op); + if (!hasShape(type)) { + return false; + } + auto shape = getShape(type); + auto contiguity = axisInfo->getContiguity(); + return (shape.back() > 1 && shape.back() == contiguity.back()); +} + +// Get the base pointer and offset of a memory operation if the pointer is +// defined by a SplatOp and an AddPtrOp. +template , bool> = true> +std::tuple getMemoryBaseOffset(OpTy op) { + Value ptr = op.getPtr(); + + auto addPtrOp = ptr.getDefiningOp(); + if (!addPtrOp) + return std::make_tuple(nullptr, nullptr); + + Value basePtr = nullptr; + Value offset = nullptr; + + if (auto splatOp = addPtrOp->getOperand(0).getDefiningOp()) { + if (isa(splatOp.getOperand().getType())) { + basePtr = splatOp.getOperand(); + offset = addPtrOp.getOperand(1); + } + } + + if (auto splatOp = addPtrOp->getOperand(1).getDefiningOp()) { + if (!basePtr && isa(splatOp.getOperand().getType())) { + basePtr = splatOp.getOperand(); + offset = addPtrOp.getOperand(0); + } + } + + return std::make_tuple(basePtr, offset); +} + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td new file mode 100644 index 000000000000..8def195cc220 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -0,0 +1,197 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES +#define TRITONTOTRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton memory ops."; + let description = [{ + + }]; + + let options = [ + Option<"useGatherScatter", "use-gather-scatter", + "bool", /*default*/"false", + "Use Gather or Scatter to lower memory ops.">, + ]; + + let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { + let summary = "Convert elementwise ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertElemManipOps : Pass<"triton-cpu-convert-elem-manip-ops", "mlir::ModuleOp"> { + let summary = "Convert elements manipulation ops (transpose, shuffle, etc.)."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElemManipOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton ops related to pointer arithmetics."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertPtrOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotOp : Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDotOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertControlFlowOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertHistogramOp : Pass<"triton-cpu-convert-histogram-op", "mlir::ModuleOp"> { + let summary = "Convert Triton HistogramOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertHistogramOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect"]; +} + +def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> { + let summary = "Convert Triton ReduceOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertReductionOp()"; + + let options = [ + Option<"useMultiDimReductionOp", "use-multidim-reduction-op", + "bool", /*default*/"false", + "Use vector::MultiDimReductionOp and its default lowering when possible.">, + + Option<"useReductionOp", "use-reduction-op", + "bool", /*default*/"false", + "Use vector::ReductionOp and its default lowering when possible.">, + ]; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> { + let summary = "Convert Triton ScanOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertScanOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton atomic operations."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertAtomicOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDebugOps : Pass<"triton-cpu-convert-debug-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton debug operations."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDebugOps()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ScalarizeUsingForOp : Pass<"triton-cpu-scalarize", "mlir::ModuleOp"> { + let summary = "Insert Loops for ops, that are not vectorizable"; + let description = [{ + This pass is used to reduce compile time by generating loops for + operations that cannot be handled as vectors, and simply increases + the amount of IR without any further optimization. + }]; + + let options = [ + Option<"skipGatherScatter", "skip-gather-scatter", + "bool", /*default*/"false", + "Skip scalarizing gather/scatter ops.">, + ]; + + let constructor = "mlir::triton::cpu::createScalarizeUsingForOpPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + + +#endif diff --git a/third_party/cpu/include/Xsmm/CMakeLists.txt b/third_party/cpu/include/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..ede68918f9a6 --- /dev/null +++ b/third_party/cpu/include/Xsmm/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUXsmm) +add_public_tablegen_target(TritonCPUXsmmPassIncGen) + +set(LLVM_TARGET_DEFINITIONS XsmmEnum.td) +mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls) +mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonCPUXsmmAttrDefIncGen) diff --git a/third_party/cpu/include/Xsmm/Passes.h b/third_party/cpu/include/Xsmm/Passes.h new file mode 100644 index 000000000000..2195a2feb1ec --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.h @@ -0,0 +1,75 @@ +#ifndef TritonCPUXsmm_CONVERSION_PASSES_H +#define TritonCPUXsmm_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; + +namespace affine { +class AffineDialect; +} // namespace affine + +namespace arith { +class ArithDialect; +} // namespace arith + +namespace func { +class FuncOp; +class FuncDialect; +} // namespace func + +namespace linalg { +class LinalgDialect; +} // namespace linalg + +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + +namespace math { +class MathDialect; +} // namespace math + +namespace memref { +class MemRefDialect; +} // namespace memref + +namespace scf { +class SCFDialect; +} // namespace scf + +namespace tensor { +class TensorDialect; +} // namespace tensor + +namespace vector { +class VectorDialect; +} // namespace vector + +namespace triton { +class TritonDialect; + +namespace cpu { +class TritonCPUDialect; +} // namespace cpu +} // namespace triton + +} // namespace mlir + +namespace mlir { +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/Xsmm/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "cpu/include/Xsmm/Passes.h.inc" + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/Xsmm/Passes.td b/third_party/cpu/include/Xsmm/Passes.td new file mode 100644 index 000000000000..08ecdf76c3cf --- /dev/null +++ b/third_party/cpu/include/Xsmm/Passes.td @@ -0,0 +1,42 @@ +#ifndef TRITONCPU_XSMM_PASSES +#define TRITONCPU_XSMM_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> { + let summary = "Convert vector to xsmm"; + let description = [{ + Convert vector operations to XSMM operations. + }]; + let dependentDialects = ["arith::ArithDialect", + "func::FuncDialect", + "memref::MemRefDialect", + "vector::VectorDialect", + "LLVM::LLVMDialect"]; +} + +def ConvertTritonToXsmm : Pass<"triton-cpu-convert-triton-to-xsmm", "mlir::ModuleOp"> { + let summary = "Convert triton to xsmm"; + let description = [{ + Convert triton operations to XSMM operations. + }]; + let dependentDialects = ["arith::ArithDialect", + "func::FuncDialect", + "memref::MemRefDialect", + "triton::cpu::TritonCPUDialect", + "LLVM::LLVMDialect"]; +} + +def LoopToBrgemmXsmm : Pass<"triton-cpu-loop-to-brgemm-xsmm", "mlir::ModuleOp"> { + let summary = "Redution loop GEMM to BRGEMM"; + let description = [{ + Collapse reduction loop over GEMM to XSMM BRGEMM kernel. + }]; + let dependentDialects = ["arith::ArithDialect", + "func::FuncDialect", + "memref::MemRefDialect", + "triton::cpu::TritonCPUDialect", + "LLVM::LLVMDialect"]; +} + +#endif diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.h b/third_party/cpu/include/Xsmm/XsmmEnum.h new file mode 100644 index 000000000000..19bfad8b16ba --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.h @@ -0,0 +1,18 @@ +//===- XsmmEnum.h -----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMENUM_H +#define TPP_DIALECT_XSMM_XSMMENUM_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "cpu/include/Xsmm/XsmmEnum.h.inc" + +#endif // TPP_DIALECT_XSMM_XSMMENUM_H diff --git a/third_party/cpu/include/Xsmm/XsmmEnum.td b/third_party/cpu/include/Xsmm/XsmmEnum.td new file mode 100644 index 000000000000..d1edae6d0beb --- /dev/null +++ b/third_party/cpu/include/Xsmm/XsmmEnum.td @@ -0,0 +1,84 @@ +//===- XsmmEnum --------------------------------------------*- Tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" + +def Xsmm_DataType: I64EnumAttr< + "DataType", "see: libxsmm_datatype", + [ + I64EnumAttrCase<"F32", 1, "f32">, + I64EnumAttrCase<"BF16", 2, "bf16">, + I64EnumAttrCase<"BF8", 4, "bf8"> + ]>{ + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryKind : I64EnumAttr< + "BinaryKind", "see: libxsmm_meltw_binary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"ADD", 1, "add">, + I64EnumAttrCase<"MUL", 2, "mul">, + I64EnumAttrCase<"SUB", 3, "sub">, + I64EnumAttrCase<"DIV", 4, "div"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryKind : I64EnumAttr< + "UnaryKind", "see: libxsmm_meltw_unary_type", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"IDENTITY", 1, "identity">, + I64EnumAttrCase<"ZERO", 2, "zero">, + I64EnumAttrCase<"RELU", 5, "relu">, + I64EnumAttrCase<"VNNI2", 28, "vnni_2">, + I64EnumAttrCase<"TRANSPOSE", 29, "transpose"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_UnaryFlags : I64EnumAttr< + "UnaryFlags", "see: libxsmm_meltw_unary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW", 2, "bcast_row">, + I64EnumAttrCase<"BCAST_COL", 4, "bcast_col">, + I64EnumAttrCase<"BCAST_SCALAR", 8, "bcast_scalar"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_BinaryFlags : I64EnumAttr< + "BinaryFlags", "see: libxsmm_meltw_binary_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BCAST_ROW_IN_0", 1, "bcast_row_in0">, + I64EnumAttrCase<"BCAST_ROW_IN_1", 2, "bcast_row_in1">, + I64EnumAttrCase<"BCAST_COL_IN_0", 4, "bcast_col_in0">, + I64EnumAttrCase<"BCAST_COL_IN_1", 8, "bcast_col_in1">, + I64EnumAttrCase<"BCAST_SCALAR_IN_0", 16, "bcast_scalar_in0">, + I64EnumAttrCase<"BCAST_SCALAR_IN_1", 32, "bcast_scalar_in1"> + ]> { + let cppNamespace = "mlir::xsmm"; +} + +def Xsmm_GemmFlags : I64EnumAttr< + "GemmFlags", "see: libxsmm_gemm_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BETA_0", 4, "beta_0">, + I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">, + I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">, + I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">, + I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">, + I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig"> + ]> { + let cppNamespace = "mlir::xsmm"; +} diff --git a/third_party/cpu/language/cpu/__init__.py b/third_party/cpu/language/cpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/third_party/cpu/language/cpu/libdevice.py b/third_party/cpu/language/cpu/libdevice.py new file mode 100644 index 000000000000..438f49cacf51 --- /dev/null +++ b/third_party/cpu/language/cpu/libdevice.py @@ -0,0 +1,222 @@ +import triton.language as tl +from triton.language import core +from triton.language.core import builtin +from triton import jit + + +@core.extern +def acos(arg0, _builder=None): + return core.tensor(_builder.create_acos(arg0.handle), arg0.type) + + +@core.extern +def acosh(arg0, _builder=None): + return core.tensor(_builder.create_acosh(arg0.handle), arg0.type) + + +@core.extern +def asin(arg0, _builder=None): + return core.tensor(_builder.create_asin(arg0.handle), arg0.type) + + +@core.extern +def asinh(arg0, _builder=None): + return core.tensor(_builder.create_asinh(arg0.handle), arg0.type) + + +@core.extern +def atan(arg0, _builder=None): + return core.tensor(_builder.create_atan(arg0.handle), arg0.type) + + +@core.extern +def atanh(arg0, _builder=None): + return core.tensor(_builder.create_atanh(arg0.handle), arg0.type) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.tensor(_builder.create_cbrt(arg0.handle), arg0.type) + + +@core.extern +def cos(arg0, _builder=None): + return core.tensor(_builder.create_cos(arg0.handle), arg0.type) + + +@core.extern +def cosh(arg0, _builder=None): + return core.tensor(_builder.create_cosh(arg0.handle), arg0.type) + + +@core.extern +def erf(arg0, _builder=None): + return core.tensor(_builder.create_erf(arg0.handle), arg0.type) + + +@core.extern +def exp(arg0, _builder=None): + return core.tensor(_builder.create_exp(arg0.handle), arg0.type) + + +@core.extern +def exp2(arg0, _builder=None): + return core.tensor(_builder.create_exp2(arg0.handle), arg0.type) + + +@core.extern +def expm1(arg0, _builder=None): + return core.tensor(_builder.create_expm1(arg0.handle), arg0.type) + + +@core.extern +def floor(arg0, _builder=None): + return core.tensor(_builder.create_floor(arg0.handle), arg0.type) + + +@core.extern +def log(arg0, _builder=None): + return core.tensor(_builder.create_log(arg0.handle), arg0.type) + + +@core.extern +def log2(arg0, _builder=None): + return core.tensor(_builder.create_log2(arg0.handle), arg0.type) + + +@core.extern +def log10(arg0, _builder=None): + return core.tensor(_builder.create_log10(arg0.handle), arg0.type) + + +@core.extern +def log1p(arg0, _builder=None): + return core.tensor(_builder.create_log1p(arg0.handle), arg0.type) + + +@core.extern +def sin(arg0, _builder=None): + return core.tensor(_builder.create_sin(arg0.handle), arg0.type) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.tensor(_builder.create_rsqrt(arg0.handle), arg0.type) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.tensor(_builder.create_sqrt(arg0.handle), arg0.type) + + +@core.extern +def sinh(arg0, _builder=None): + return core.tensor(_builder.create_sinh(arg0.handle), arg0.type) + + +@core.extern +def tan(arg0, _builder=None): + return core.tensor(_builder.create_tan(arg0.handle), arg0.type) + + +@core.extern +def tanh(arg0, _builder=None): + return core.tensor(_builder.create_tanh(arg0.handle), arg0.type) + + +@core.extern +def trunc(arg0, _builder=None): + return core.tensor(_builder.create_trunc(arg0.handle), arg0.type) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Sleef_ceilf%(numel)", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Sleef_ceild%(numel)", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Sleef_powf%(numel)_u10", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Sleef_powd%(numel)_u10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Sleef_fmodf%(numel)", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Sleef_fmodd%(numel)", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@jit +def _const(v, dtype): + """ + Create a tensor with a single value of type :dtype. + """ + return tl.full((1, ), v, dtype) + + +@jit +def _is_special_float(arg0, uint_dtype, kind: tl.constexpr): + # By default, Triton assumes constexprs are int32. Thus, when we do operations with constants, + # we end up auto-promoting smaller integer types to int32, which is undesirable. Thus we + # explicitly cast them to our desired type here. + one = _const(1, uint_dtype) + zero = _const(0, uint_dtype) + + bitwidth: tl.constexpr = arg0.dtype.primitive_bitwidth + exponent_width: tl.constexpr = bitwidth - 1 - arg0.dtype.fp_mantissa_width + mantissa_width: tl.constexpr = arg0.dtype.fp_mantissa_width + + uintval = arg0.to(uint_dtype, bitcast=True) + exponent = uintval << one >> _const(mantissa_width, uint_dtype) + one + exp_is_all_ones = exponent == (one << _const(exponent_width, uint_dtype)) - one + shifted_mantissa = uintval << _const(exponent_width, uint_dtype) + one + + if kind == "nan": + return exp_is_all_ones & (shifted_mantissa != zero) + elif kind == "inf": + return exp_is_all_ones & (shifted_mantissa == zero) + else: + raise ValueError(f"Unexpected kind {kind}") + + +@builtin +def isnan(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("isnan expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_is_special_float, (arg0, uint_dtype, "nan"), kwargs={}) + + +@builtin +def isinf(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("isinf expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_is_special_float, (arg0, uint_dtype, "inf"), kwargs={}) + + +@jit +def _signbit(arg0, uint_dtype: tl.constexpr): + bitwidth: tl.constexpr = arg0.dtype.primitive_bitwidth + return arg0.to(uint_dtype, bitcast=True) >> (bitwidth - 1) + + +@builtin +def signbit(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("signbit expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_signbit, (arg0, uint_dtype), kwargs={}) diff --git a/third_party/cpu/lib/Analysis/CMakeLists.txt b/third_party/cpu/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000000..d0ac08b9daf0 --- /dev/null +++ b/third_party/cpu/lib/Analysis/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonCPUAnalysis + TensorPtrShapeInfo.cpp + + DEPENDS + TritonCPUTableGen + + LINK_LIBS PUBLIC + MLIRAnalysis + TritonIR + TritonCPUIR +) diff --git a/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp new file mode 100644 index 000000000000..bd3959e051f0 --- /dev/null +++ b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp @@ -0,0 +1,219 @@ +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::cpu { + +TensorPtrShapeInfo TensorPtrShapeInfo::join(const TensorPtrShapeInfo &lhs, + const TensorPtrShapeInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + assert(lhs.getRank() == rhs.getRank()); + + SmallVector shape(lhs.getShape()); + SmallVector strides(lhs.getStrides()); + for (int64_t i = 0; i < lhs.getRank(); ++i) { + if (shape[i] != rhs.getSize(i)) + shape[i] = ShapedType::kDynamic; + if (strides[i] != rhs.getStride(i)) + strides[i] = ShapedType::kDynamic; + } + return TensorPtrShapeInfo(shape, strides); +} + +namespace { + +template +void initPessimisticStateFromFunc(int argNumber, T funcOp, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + auto loadFromAttr = [&](std::string_view attrName, + SmallVectorImpl &out) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + out = SmallVector(vals.begin(), vals.end()); + } + }; + loadFromAttr("tt.shape", shape); + loadFromAttr("tt.strides", strides); +} + +TensorPtrShapeInfo getPessimisticValueState(Value value) { + int rank = 0; + if (triton::isTensorPointerType(value.getType())) + rank = cast(getPointeeType(value.getType())).getRank(); + + SmallVector shape; + SmallVector strides; + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape, + strides); + // llvm codegen check alignment to generate vector load/store + // would be nice if this wasn't the case + else if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape, + strides); + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state. + } else { + // Other operations are conservatively initialized with dynamic + // shape and strides unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.shape")) { + auto vals = cast(attr).getValues(); + shape = SmallVector(vals.begin(), vals.end()); + } else { + shape.insert(shape.end(), rank, ShapedType::kDynamic); + } + if (Attribute attr = op->getDiscardableAttr("tt.strides")) { + auto vals = cast(attr).getValues(); + strides = SmallVector(vals.begin(), vals.end()); + } else { + strides.insert(strides.end(), rank, ShapedType::kDynamic); + } + } + } + + return TensorPtrShapeInfo(shape, strides); +} + +class ShapeInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + void + setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join(getPessimisticValueState(lattice->getAnchor()))); + } + +public: + ShapeInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncShapeInfoMapT = DenseMap; + + LogicalResult visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +ShapeInfoAnalysis::ShapeInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>(solver) {} + +SmallVector copyConstOrDynamic(OperandRange ops) { + SmallVector res; + for (auto op : ops) { + if (auto cstOp = op.getDefiningOp()) { + auto intAttr = dyn_cast(cstOp.getValue()); + assert(intAttr); + res.push_back(intAttr.getInt()); + } else { + res.push_back(ShapedType::kDynamic); + } + } + return res; +} + +LogicalResult ShapeInfoAnalysis::visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + + TensorPtrShapeInfo res; + // Tensor pointers are only produced by MakeTensorPtrOp which has + // shape and strides as its args, and AdvanceOp which preserves + // shape and strides of the input pointer. + if (auto makePtrOp = dyn_cast(op)) { + SmallVector shape = copyConstOrDynamic(makePtrOp.getShape()); + SmallVector strides = copyConstOrDynamic(makePtrOp.getStrides()); + res = TensorPtrShapeInfo(shape, strides); + } else if (auto advOp = dyn_cast(op)) { + res = operands[0]->getValue(); + } + + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(res)); + + return success(); +} + +} // namespace + +void ModuleTensorPtrShapeInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + ShapeInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *shapeInfoMap = getFuncData(funcOp); + auto updateShapeInfoMap = [&](Value value) { + auto shapeInfo = analysis->getLatticeElement(value)->getValue(); + TensorPtrShapeInfo curShapeInfo; + if (shapeInfoMap->count(value)) { + curShapeInfo = + TensorPtrShapeInfo::join(shapeInfo, shapeInfoMap->lookup(value)); + } else { + curShapeInfo = shapeInfo; + } + (*shapeInfoMap)[value] = curShapeInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateShapeInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateShapeInfoMap(value); + } + }); +} + +void ModuleTensorPtrShapeInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *shapeInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, ArrayRef value) { + SmallVector curValue(value); + if (auto attr = + callee.getArgAttrOfType(index, attrName)) { + auto oldValue = cast(attr).getValues(); + assert(oldValue.size() == curValue.size()); + for (size_t i = 0; i < curValue.size(); ++i) + if (curValue[i] != oldValue[i]) + curValue[i] = ShapedType::kDynamic; + } + auto attr = DenseElementsAttr::get( + VectorType::get(curValue.size(), + IntegerType::get(callee.getContext(), 64)), + ArrayRef(curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto shapeInfo = shapeInfoMap->lookup(value); + if (shapeInfo.getRank()) { + setAttrFn("tt.shape", shapeInfo.getShape()); + setAttrFn("tt.strides", shapeInfo.getStrides()); + } + } +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt new file mode 100644 index 000000000000..a68251f38a5c --- /dev/null +++ b/third_party/cpu/lib/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(Analysis) +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) +add_subdirectory(TritonRaiseBlockPointer) +add_subdirectory(TritonToTritonCPU) +add_subdirectory(Xsmm) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp new file mode 100644 index 000000000000..9a2c183e1c4c --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp @@ -0,0 +1,154 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_ATOMICOPSTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +LLVM::AtomicOrdering getOrdering(MemSemantic sem) { + switch (sem) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + llvm_unreachable("Unexpected atomic mem semantic"); + } +} + +// TODO: use enums to access struct fields. +struct AtomicRMWOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto opKind = getAtomicBinOp(op.getAtomicRmwOp(), op.getType()); + auto ptr = rewriter.getRemappedValue(op.getPtr()); + auto val = rewriter.getRemappedValue(op.getVal()); + auto ordering = getOrdering(op.getSem()); + rewriter.replaceOpWithNewOp(op, opKind, ptr, val, + ordering); + return success(); + } + + LLVM::AtomicBinOp getAtomicBinOp(RMWOp op, Type type) const { + switch (op) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return type.isIntOrIndex() ? LLVM::AtomicBinOp::max + : LLVM::AtomicBinOp::fmax; + case RMWOp::MIN: + return type.isIntOrIndex() ? LLVM::AtomicBinOp::min + : LLVM::AtomicBinOp::fmin; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + llvm_unreachable("Unexpected atomic op"); + } + } +}; + +struct AtomicCASOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptr = rewriter.getRemappedValue(op.getPtr()); + auto cmp = rewriter.getRemappedValue(op.getCmp()); + auto val = rewriter.getRemappedValue(op.getVal()); + auto ordering = getOrdering(op.getSem()); + auto failureOrdering = ordering != LLVM::AtomicOrdering::monotonic + ? LLVM::AtomicOrdering::acquire + : ordering; + Value cmpXchg = rewriter.create( + loc, ptr, cmp, val, ordering, failureOrdering); + Value oldVal = rewriter.create(loc, cmpXchg, 0); + rewriter.replaceOp(op, oldVal); + return success(); + } +}; + +struct AtomicOpsToLLVM + : public triton::impl::AtomicOpsToLLVMBase { + using AtomicOpsToLLVMBase::AtomicOpsToLLVMBase; + + AtomicOpsToLLVM() : AtomicOpsToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createAtomicOpsToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..5448d81937f4 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(TritonCPUToLLVM + AtomicOpsToLLVM.cpp + DebugOpsToLLVM.cpp + FuncOpToLLVM.cpp + GetProgramIdOpToLLVM.cpp + LowerMultiReduction.cpp + MathToVecLib.cpp + MemoryOpToLLVM.cpp + TypeConverter.cpp + Utility.cpp + + DEPENDS + TritonCPUToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRVectorToLLVMPass +) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp new file mode 100644 index 000000000000..33e1753e31b2 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -0,0 +1,374 @@ +#include "TypeConverter.h" +#include "Utility.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DEBUGOPSTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: This code is the same as the GPU-backend code. Consider refactoring. +std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt, + bool isSigned = false) { + Type type = value.getType(); + // If the `value` is a pointer, just return %p. + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + (isSigned ? "lli" : "llu"); + else + return prefix + (isSigned ? "i" : "u"); + } + assert(false && "not supported type"); + return ""; +} + +// For printf, need to extend int32 or float64. +Value printfPromoteValue(RewriterBase &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + auto loc = UnknownLoc::get(context); + + bool isUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + if (isUnsigned) { + return zext(ui32_ty, value); + } else { + return sext(i32_ty, value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + return fpext(f64_ty, value); + } + + return value; +} + +LLVM::LLVMFuncOp getOrAddPrintFuncDecl(ConversionPatternRewriter &rewriter, + bool printf) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = printf ? "printf" : "triton_vector_print"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType; + if (printf) + argsType = {ptr_ty(ctx)}; + else + argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx), ptr_ty(ctx), + i32_ty, i32_ty, i32_ty, i64_ty, i32_ty}; + + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ printf); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + +LLVM::LLVMFuncOp +getOrAddPrintMemrefFuncDecl(ConversionPatternRewriter &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = "triton_print_unranked_memref"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType; + + SmallVector elemTypes; + elemTypes.push_back(i64_ty); + elemTypes.push_back(ptr_ty(ctx)); + Type structTy = struct_ty(elemTypes); + + argsType = {/*pid serialization*/ i32_ty, + i32_ty, + i32_ty, /*end pids*/ + ptr_ty(ctx), + structTy, + /*type sreialization*/ i32_ty, + i32_ty, + i32_ty, /*end type*/ + i32_ty}; + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ false); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + +static StringRef makeNullTerminatedString(StringRef s) { + llvm::SmallString<64> ss(s); + ss.push_back(0); + return ss; +} + +void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, + std::array pid, StringRef prefix, + std::optional arg, bool hex = false, + bool isSigned = false) { + assert(!prefix.empty() && "printf with empty string not supported"); + auto loc = UnknownLoc::get(rewriter.getContext()); + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) + << ", " << getFormatSubstr(pid[2]) << ")" << prefix; + if (arg.has_value()) + os << getFormatSubstr(arg.value(), hex, std::nullopt, isSigned); + + llvm::SmallString<64> formatStrNewline(formatStr); + formatStrNewline.push_back('\n'); + formatStrNewline.push_back('\0'); + Value formatStrValue = + LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatStrNewline); + + SmallVector allArgs{formatStrValue}; + for (auto elem : pid) + allArgs.push_back(elem); + if (arg.has_value()) + allArgs.push_back(printfPromoteValue(rewriter, arg.value())); + call(getOrAddPrintFuncDecl(rewriter, true), allArgs); +} + +void createRuntimePrintCall(ConversionPatternRewriter &rewriter, + std::array pid, StringRef prefix, + Value ptr, Type dtype, bool isSigned, bool hex) { + assert(!prefix.empty()); + auto loc = UnknownLoc::get(rewriter.getContext()); + Value prefixValue = LLVM::addStringToModule( + loc, rewriter, "vectorPrintPrefix_", makeNullTerminatedString(prefix)); + + SmallVector allArgs; + for (auto elem : pid) + allArgs.push_back(elem); + + allArgs.push_back(prefixValue); + allArgs.push_back(ptr); + + allArgs.push_back(i32_val(dtype.getIntOrFloatBitWidth())); + allArgs.push_back(i32_val(dtype.isInteger())); + allArgs.push_back(i32_val(isSigned)); + allArgs.push_back(i32_val(hex)); + + call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs); +} + +bool usePrintf(triton::cpu::PrintOp op) { + // Simply use printf if no operand or the operand is scalar. + if (op.getNumOperands() == 0) + return true; + + // tt.print is already decomposed to triton_cpu.print per value. + assert(op.getNumOperands() == 1); + Type oprType = op.getOperands()[0].getType(); + return (oprType.isIntOrIndexOrFloat() || isa(oprType)); +} + +Value getPid(Operation *op, int axis) { + return getProgramId(op->getParentOfType(), axis); +}; + +struct PrintOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::cpu::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + std::array pid = {getPid(op, 0), getPid(op, 1), getPid(op, 2)}; + + if (usePrintf(op)) { + if (op.getNumOperands() == 0) { + createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), + std::nullopt); + } else { + createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), + adaptor.getOperands()[0], op.getHex(), + op.getIsSigned()[0]); + } + rewriter.eraseOp(op); + return success(); + } + + // TODO: support 2D+ vector printing. + std::string msg{op.getPrefix()}; + + createRuntimePrintCall( + rewriter, pid, op.getPrefix(), adaptor.getOperands()[0], + cast(op.getVal()[0].getType()).getElementType(), + op.getIsSigned()[0], op.getHex()); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct AssertOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::cpu::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + Value message = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", + makeNullTerminatedString(adaptor.getMessage())); + + // Based on lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp. + StringRef fileStr = "unknown"; + StringRef funcStr = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + fileStr = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + Value file = LLVM::addStringToModule(loc, rewriter, "assertFile_", + makeNullTerminatedString(fileStr)); + Value func = LLVM::addStringToModule(loc, rewriter, "assertFunc_", + makeNullTerminatedString(funcStr)); + SmallVector args{getPid(op, 0), getPid(op, 1), getPid(op, 2), + op.getCondition(), message, file, + i32_val(line), func}; + call(getAssertFuncDecl(rewriter), args); + rewriter.eraseOp(op); + return success(); + } + + static LLVM::LLVMFuncOp + getAssertFuncDecl(ConversionPatternRewriter &rewriter) { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = "triton_assert"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType{i32_ty, i32_ty, i32_ty, i1_ty, + ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx)}; + + auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); + } +}; + +using BarrierOp = mlir::gpu::BarrierOp; + +// This is part of the DebugOps pass because gpu::barrier is generated by +// tl.debug_barrier. +struct BarrierOpConversion : public ConvertOpToLLVMPattern { + explicit BarrierOpConversion(LLVMTypeConverter &typeConverter) + : mlir::ConvertOpToLLVMPattern(typeConverter) {} + + LogicalResult + matchAndRewrite(BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Just make it a no-op for now + rewriter.eraseOp(op); + return success(); + } +}; + +struct DebugOpsToLLVM + : public triton::impl::DebugOpsToLLVMBase { + using DebugOpsToLLVMBase::DebugOpsToLLVMBase; + + DebugOpsToLLVM() : DebugOpsToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +namespace mlir::triton::cpu { + +std::unique_ptr> createDebugOpsToLLVMPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000000..99962da6546a --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,293 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_FUNCOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + // 1. Modify the function type to add new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(ui32_ty); + amendedInputTy.push_back(ui32_ty); + amendedInputTy.push_back(ui32_ty); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add new arguments. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + SmallVector amendedArgAttrs; + if (funcOp.getAllArgAttrs()) { + amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back( + rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(), + rewriter.getArrayAttr(amendedArgAttrs))); + } + // 3. Add a new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + region.addArgument(ui32_ty, loc); + region.addArgument(ui32_ty, loc); + region.addArgument(ui32_ty, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto modifiedFuncOp = funcOp; + if (LLVM::isKernel(funcOp)) + modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + modifiedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) + return failure(); + + // required by AxisInfoAnalysis + if (LLVM::isKernel(funcOp)) + rewriter.eraseOp(modifiedFuncOp); + rewriter.eraseOp(funcOp); + return success(); + } +}; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto funcOp = op->getParentOfType(); + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = + insert_val(packedResultsTy, packedResults, it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { + using FuncOpToLLVMBase::FuncOpToLLVMBase; + + FuncOpToLLVM() : FuncOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + // Lower tt.func + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, + /*benefit=*/1); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + // Lower tt.call, tt.return + int benefit = 10; + RewritePatternSet patterns(context); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createFuncOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp new file mode 100644 index 000000000000..406b32cc7774 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -0,0 +1,110 @@ +#include "TypeConverter.h" +#include "Utility.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. +struct GetProgramIdOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); + rewriter.replaceOp(op, getProgramId(funcOp, op.getAxisAsInt())); + return success(); + } +}; + +struct GetNumProgramsOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetNumProgramsOp"); + rewriter.replaceOp(op, getNumPrograms(funcOp, op.getAxisAsInt())); + return success(); + } +}; + +struct GetProgramIdOpToLLVM + : public triton::impl::GetProgramIdOpToLLVMBase { + using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; + + GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createGetProgramIdOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp new file mode 100644 index 000000000000..74f81cb0f9cc --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp @@ -0,0 +1,61 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_LOWERMULTIREDUCTION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This pass exists because LowerVectorMultiReductionPass can be run on +// func::FuncOp only and we translate triton::FuncOp directly into llvm::FuncOp. +// So we run the same set of patterns on triton::FuncOp. +struct LowerMultiReduction + : public mlir::triton::impl::LowerMultiReductionBase { + using LowerMultiReductionBase::LowerMultiReductionBase; + + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + + RewritePatternSet loweringPatterns(context); + // The default lowering option is InnerParallel + vector::VectorMultiReductionLowering options = + vector::VectorMultiReductionLowering::InnerReduction; + vector::populateVectorMultiReductionLoweringPatterns(loweringPatterns, + options); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createLowerMultiReductionPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp new file mode 100644 index 000000000000..2b1877c1c17b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -0,0 +1,462 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_MATHTOVECLIB +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template struct VecOpToFp32 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + VecOpToFp32(MLIRContext *context) : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isBF16() && !elemTy.isF16()) + return failure(); + + Type fp32VecTy = vecTy.cloneWith(std::nullopt, rewriter.getF32Type()); + SmallVector fp32Ops; + for (auto operand : op->getOperands()) + fp32Ops.push_back( + rewriter.create(loc, fp32VecTy, operand)); + auto newOp = rewriter.create(loc, fp32VecTy, fp32Ops); + rewriter.replaceOpWithNewOp(op, vecTy, newOp); + return success(); + } +}; + +// Decompose vector operation to single-dimensional vector operations +// with a AVX512 for x86 or NEON for ARM. +template +struct DecomposeToNativeVecs : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // CPU SIMD vector size in bits + size_t vec_bits; + + DecomposeToNativeVecs(MLIRContext *context, + size_t native_vec_size_in_bits = 512) + : OpRewritePattern(context), vec_bits(native_vec_size_in_bits) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isF64()) + return failure(); + + int64_t numElems = vecTy.getNumElements(); + if (numElems * elemTy.getIntOrFloatBitWidth() < 128) + return failure(); + + // Produce a new shape where trailing dimensions wouldn't exceed the native + // vector size. + auto shape = vecTy.getShape(); + SmallVector newShape(1, 1); + int64_t elemsPerVec = vec_bits / elemTy.getIntOrFloatBitWidth(); + for (int64_t i = shape.size() - 1; i >= 0; --i) { + int64_t size = shape[i]; + if (newShape.size() > 1) { + newShape.insert(newShape.begin(), size); + } else { + int64_t combined = newShape[0] * size; + if (combined > elemsPerVec) { + newShape[0] = elemsPerVec; + newShape.insert(newShape.begin(), combined / elemsPerVec); + } else { + newShape[0] = combined; + } + } + } + if (newShape == shape) + return failure(); + + // Convert input operand to the new shape. + SmallVector reshapedInputs; + for (auto operand : op->getOperands()) { + auto operandTy = cast(operand.getType()); + auto newOperandTy = VectorType::get(newShape, operandTy.getElementType()); + reshapedInputs.push_back( + rewriter.create(loc, newOperandTy, operand)); + } + + // Decompose the original operation to a set of operations on native + // vectors. + auto newOpTy = VectorType::get(newShape, elemTy); + auto subResTy = VectorType::get(newShape.back(), elemTy); + Value newRes = rewriter.create( + loc, SplatElementsAttr::get(newOpTy, rewriter.getFloatAttr(elemTy, 0))); + auto strides = computeStrides(newShape); + // Remove the last stride to produce sub-vector indices. + strides.pop_back(); + for (int64_t idx = 0; idx < numElems; idx += newShape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(reshapedInputs.size()); + std::transform(reshapedInputs.begin(), reshapedInputs.end(), + subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, + indices); + }); + Value subRes = + rewriter.create(loc, subResTy, subInputs, op->getAttrs()); + newRes = rewriter.create(loc, subRes, newRes, indices); + } + + // Reshape the result back to the original type. + rewriter.replaceOpWithNewOp(op, vecTy, newRes); + return success(); + } +}; + +using ExternElementwiseOp = triton::cpu::ExternElementwiseOp; + +/* + * libsleef does not contain implementations for 2-element vectors, so we pad + * any such vectors to size 4 instead. + */ +struct PadSmallVecsForSleef : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PadSmallVecsForSleef(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(ExternElementwiseOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isF64()) + return failure(); + + int64_t numElems = vecTy.getNumElements(); + if (numElems >= 4) + return failure(); + + // Create a single-element vector for shuffle to use + auto paddingVec = rewriter.create( + loc, undef(elemTy), VectorType::get({1}, elemTy)); + // Assign indices such that shuffle will pad the original vector with + // elements from the paddingVec + SmallVector indices(4); + for (int i = 0; i < 4; ++i) { + if (i < numElems) + indices[i] = i; + else + indices[i] = numElems; + } + SmallVector newOperands; + for (auto argVal : op.getOperands()) { + auto shuf = + rewriter.create(loc, argVal, paddingVec, indices); + newOperands.push_back(shuf.getResult()); + } + // Update return type of extern call + auto newVecTy = VectorType::get({4}, elemTy); + auto extern_elem = rewriter.create( + loc, newVecTy, newOperands, op.getSymbol(), op.getPure()); + indices.resize(numElems); + // Truncate result to original size + rewriter.replaceOpWithNewOp(op, extern_elem.getResult(), + paddingVec, indices); + return success(); + } +}; + +using GetVecFnNameFn = std::function; + +class MvecNameGenerator { +public: + explicit MvecNameGenerator(StringRef baseName) : baseName(baseName) {} + + std::string operator()(unsigned bitwidth, unsigned numel, + ValueRange operands) const { + if (bitwidth != 32 && bitwidth != 64) + return ""; + unsigned vecSize = numel * bitwidth; + std::string isaPrefix; + if (vecSize == 128) { + isaPrefix = "b"; + } else if (vecSize == 256) { + isaPrefix = "d"; + } else if (vecSize == 512) { + isaPrefix = "e"; + } else { + return ""; + } + std::string fnName = "_ZGV" + isaPrefix + "N" + std::to_string(numel); + for (auto operand : operands) + fnName += "v"; + return fnName + "_" + baseName + (bitwidth == 32 ? "f" : ""); + } + +private: + std::string baseName; +}; + +class SleefNameGenerator { +public: + SleefNameGenerator(StringRef baseName, unsigned ulp = 10) + : baseName(baseName), ulpSuffix(4, '\0') { + if (ulp == 0) { + ulpSuffix = ""; + } else { + char buf[5]; // 4 char suffix + '\0' added by snprintf + snprintf(buf, 5, "_u%02u", ulp); + ulpSuffix = buf; + } + } + + std::string operator()(unsigned bitwidth, unsigned numel, + ValueRange /*operands*/) const { + if (bitwidth != 32 && bitwidth != 64) + return ""; + unsigned vecSize = numel * bitwidth; + if (vecSize < 128) + return ""; + return "Sleef_" + baseName + (bitwidth == 32 ? "f" : "d") + + std::to_string(numel) + ulpSuffix; + } + +private: + std::string baseName; + std::string ulpSuffix; +}; + +template +struct OpToVecLibConversion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + virtual std::string getVecFnName(OpT op, unsigned bitwidth, + unsigned numel) const = 0; + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy || vecTy.getRank() > 1) + return failure(); + + auto fnName = getVecFnName(op, vecTy.getElementTypeBitWidth(), + vecTy.getNumElements()); + if (fnName.empty()) + return failure(); + + auto module = SymbolTable::getNearestSymbolTable(op); + auto opFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, fnName)); + // Generate function declaration if it doesn't exists yet. + if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + auto fnTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = + rewriter.create(rewriter.getUnknownLoc(), fnName, fnTy); + opFunc.setPrivate(); + opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), + UnitAttr::get(rewriter.getContext())); + } + + rewriter.replaceOpWithNewOp(op, fnName, op.getType(), + op->getOperands()); + return success(); + } +}; + +template +struct VecOpToVecLibConversion : public OpToVecLibConversion { +public: + VecOpToVecLibConversion(MLIRContext *context, GetVecFnNameFn getVecFnName) + : OpToVecLibConversion(context), getVecFnNameImpl(getVecFnName) {} + + std::string getVecFnName(OpT op, unsigned bitwidth, + unsigned numel) const override { + return getVecFnNameImpl(bitwidth, numel, op->getOperands()); + } + +private: + GetVecFnNameFn getVecFnNameImpl; +}; + +struct ExternElementwiseOpConversion + : public OpToVecLibConversion { + using OpToVecLibConversion::OpToVecLibConversion; + + std::string getVecFnName(triton::cpu::ExternElementwiseOp op, + unsigned bitwidth, unsigned numel) const override { + auto fnName = op.getSymbol(); + auto numelIdx = fnName.find("%(numel)"); + if (numelIdx == StringRef::npos) + return fnName.str(); + return (fnName.take_front(numelIdx) + Twine(numel) + + fnName.drop_front(numelIdx + 8)) + .str(); + } +}; + +template +void populatePatternsForOp(RewritePatternSet &patterns, + GetVecFnNameFn getVecFnName, + size_t vec_size_in_bits = 512) { + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext(), + vec_size_in_bits); + patterns.add>(patterns.getContext(), + getVecFnName); +} + +struct MathToVecLibPass + : public mlir::triton::cpu::impl::MathToVecLibBase { + MathToVecLibPass() = default; + size_t vec_size_in_bits; + + explicit MathToVecLibPass(VecLib lib, std::set cpu_features) { + this->lib = lib; + update_vec_size(cpu_features); + } + + void update_vec_size(std::set &cpu_features) { + // TODO: + // Refactor this as an independent function. + // And improve this to support other x86 SIMD ISAs and also for arm SVE + // (VLA) + vec_size_in_bits = 512; + for (auto feature : cpu_features) { + // Arm NEON is fixed 128-bit SIMD ISA. + if (feature == "neon") { + vec_size_in_bits = 128; + break; + } + } + } + + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + + RewritePatternSet patterns(context); + + switch (lib) { + case VecLib::Mvec: { + populateCommonPatterns(patterns); + break; + } + case VecLib::Sleef: { + populateCommonPatterns(patterns); + populatePatternsForOp( + patterns, SleefNameGenerator("expm1"), vec_size_in_bits); + populatePatternsForOp( + patterns, SleefNameGenerator("floor", /*ulp=*/0), vec_size_in_bits); + populatePatternsForOp( + patterns, SleefNameGenerator("sqrt", /*ulp=*/5), vec_size_in_bits); + populatePatternsForOp( + patterns, SleefNameGenerator("trunc", /*ulp=*/0), vec_size_in_bits); + break; + } + } + + patterns.add>( + patterns.getContext(), vec_size_in_bits); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + signalPassFailure(); + } + + template + void populateCommonPatterns(RewritePatternSet &patterns) const { + populatePatternsForOp(patterns, VecFnNameGenerator("acos"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("acosh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("asin"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("asinh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("atan"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("atanh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cbrt"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cos"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cosh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("erf"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("exp"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("exp2"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log2"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log10"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log1p"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("sin"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("sinh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("tan"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("tanh"), + vec_size_in_bits); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> +createMathToVecLibPass(VecLib lib, std::set cpu_features) { + return std::make_unique(lib, cpu_features); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000000..a3fbf20a713e --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,350 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MEMORYOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. +struct ExtractMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto memRefTy = cast(op.getType()); + auto rank = memRefTy.getRank(); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + auto memRefStructFields = + cast(memRefStructTy).getBody(); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { + auto valueTy = memRefStructFields[idxTo]; + Value val = rewriter.create( + loc, valueTy, tensorPtrStruct, idxFrom); + return rewriter.create(loc, memRefStructTy, to, val, + idxTo); + }; + + Value res = undef(memRefStructTy); + // Copy base. + res = copyValue(res, 0, 1); + // Use 0 offset. + res = rewriter.create(loc, memRefStructTy, res, + i64_val(0), 2); + // Copy shape. + res = copyValue(res, 2, 3); + // Copy strides. + res = copyValue(res, 3, 4); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct ExtractIndicesOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto rank = op.getNumResults(); + auto i64Ty = IntegerType::get(getContext(), 64); + SmallVector indices; + + for (int64_t i = 0; i < rank; i++) { + indices.push_back(rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{1, i})); + } + + rewriter.replaceOp(op, indices); + + return success(); + } +}; + +struct PtrToMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PtrToMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getSrc()); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + + Value res = undef(memRefStructTy); + res = + rewriter.create(loc, memRefStructTy, res, ptr, 1); + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct MakeTensorPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto structTy = getTypeConverter()->convertType(op.getType()); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto insertArray = [&](Value structVal, auto values, int64_t idx, + Type zextTo = nullptr) { + for (int64_t i = 0; i < static_cast(values.size()); ++i) { + Value val = values[i]; + if (zextTo) + val = rewriter.create(loc, zextTo, val); + structVal = rewriter.create( + loc, structTy, structVal, val, SmallVector{idx, i}); + } + return structVal; + }; + + Value res = undef(structTy); + // 0 - base pointer. + auto base = rewriter.getRemappedValue(op.getBase()); + res = rewriter.create(loc, structTy, res, base, 0); + // 1 - array for offsets. Promote values to i64. + res = insertArray(res, op.getOffsets(), 1, i64Ty); + // 2 - array for shape. + res = insertArray(res, op.getShape(), 2); + // 3 - array for strides. + res = insertArray(res, op.getStrides(), 3); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct AdvanceOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i64Ty = IntegerType::get(getContext(), 64); + Value res = rewriter.getRemappedValue(op.getPtr()); + Type structTy = res.getType(); + auto offsets = op.getOffsets(); + + for (int64_t i = 0; i < offsets.size(); ++i) { + auto oldOffset = rewriter.create( + loc, i64Ty, res, SmallVector{1, i}); + auto step = rewriter.create(loc, i64Ty, offsets[i]); + auto newOffset = rewriter.create(loc, oldOffset, step); + res = rewriter.create(loc, structTy, res, newOffset, + SmallVector{1, i}); + } + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type ptrTy = LLVM::LLVMPointerType::get(getContext()); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, + op.getIsVolatile()); + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value val = rewriter.getRemappedValue(op.getValue()); + rewriter.replaceOpWithNewOp(op, val, ptr); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expect only scalar pointers here. + assert(isa(op.getType())); + auto ptrTy = cast(op.getPtr().getType()); + Type elemTy = getTypeConverter()->convertType(ptrTy.getPointeeType()); + Type resTy = getTypeConverter()->convertType(ptrTy); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + rewriter.replaceOpWithNewOp(op, resTy, elemTy, ptr, offset); + return success(); + } +}; + +struct PtrBitcastConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By this moment we expect tt.bitcast used only for scalar pointer casts. + // This cast becomes NOP for LLVM dialect, so simply return the source arg. + assert(isa(op.getType())); + assert(isa(op.getSrc().getType())); + Value src = rewriter.getRemappedValue(op.getSrc()); + rewriter.replaceOp(op, src); + return success(); + } +}; + +struct PtrSelectConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By this moment we expect tt.bitcast used only for scalar pointer casts. + // This cast becomes NOP for LLVM dialect, so simply return the source arg. + if (!isa(op.getType())) + return failure(); + + Value trueVal = rewriter.getRemappedValue(op.getTrueValue()); + Value falseVal = rewriter.getRemappedValue(op.getFalseValue()); + Value cond = rewriter.getRemappedValue(op.getCondition()); + rewriter.replaceOpWithNewOp(op, cond, trueVal, falseVal); + return success(); + } +}; + +struct MemoryOpToLLVM + : public triton::impl::MemoryOpToLLVMBase { + using MemoryOpToLLVMBase::MemoryOpToLLVMBase; + + MemoryOpToLLVM() : MemoryOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createMemoryOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000000..f9a02592f5d5 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -0,0 +1,53 @@ +#include "TypeConverter.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" + +using namespace mlir; +using namespace mlir::triton; + +TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([this](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); + addConversion([&](amx::TileType type) { + return LLVM::LLVMX86AMXType::get(type.getContext()); + }); +} + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + // struct { + // ptr base_ptr; + // array offsets; + // array shape; + // array strides; + // } + auto tensorTy = cast(pointeeType); + auto rank = tensorTy.getShape().size(); + auto i64Ty = IntegerType::get(ctx, 64); + SmallVector types; + types.push_back(LLVM::LLVMPointerType::get(ctx)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx); +} + +Type TritonCPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + if (isa(type.getElementType())) + return VectorType::get(type.getShape(), + IntegerType::get(type.getContext(), 64)); + llvm_unreachable("No tensor types are expected in TTCIR"); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h new file mode 100644 index 000000000000..02123796ff37 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -0,0 +1,23 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); +}; + +#endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp new file mode 100644 index 000000000000..e783497bd951 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp @@ -0,0 +1,32 @@ +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton::cpu { + +Value getProgramId(mlir::FunctionOpInterface funcOp, int axis) { + auto args = funcOp.getArguments(); + assert(funcOp && args.size() >= 6); + assert(axis >= 0 && axis < 3); + + // The first three of the last six args are x, y, z program ids. + auto argIdx = args.size() - 6 + axis; + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + return args[argIdx]; +} + +Value getNumPrograms(mlir::FunctionOpInterface funcOp, int axis) { + auto args = funcOp.getArguments(); + assert(funcOp && args.size() >= 6); + assert(axis >= 0 && axis < 3); + + // The last three of the args are gridX, gridY, gridZ (bounds) of grid. + auto argIdx = args.size() - 3 + axis; + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + return args[argIdx]; +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Utility.h b/third_party/cpu/lib/TritonCPUToLLVM/Utility.h new file mode 100644 index 000000000000..53ffcc6651ff --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Utility.h @@ -0,0 +1,14 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::cpu { + +Value getProgramId(mlir::FunctionOpInterface funcOp, int axis); +Value getNumPrograms(mlir::FunctionOpInterface funcOp, int axis); + +} // namespace mlir::triton::cpu + +#endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..c6e9b4ed69e6 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonCPUTransforms + ConvertDotOp/ConvertDotCommon.cpp + ConvertDotOp/ConvertDotGeneric.cpp + ConvertDotOp/ConvertDotToAMX.cpp + ConvertDotOp/ConvertDotToFMA.cpp + Canonicalize.cpp + ConvertDotProduct.cpp + ConvertUnsupportedOps.cpp + DecomposeFpConversions.cpp + OptimizeMasks.cpp + + DEPENDS + TritonCPUTransformsPassIncGen +) diff --git a/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp new file mode 100644 index 000000000000..65fed92d2b50 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp @@ -0,0 +1,110 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CANONICALIZE +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// Fold transfer read and the following shape cast that removes heading +// dimensions with size 1. +struct FoldReadShapeCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + if (!op->hasOneUse()) + return failure(); + + auto permMap = op.getPermutationMap(); + if (!permMap.isMinorIdentity()) + return failure(); + + auto reshape = dyn_cast(*op->user_begin()); + if (!reshape) + return failure(); + + VectorType ty = cast(op.getType()); + VectorType dstTy = cast(reshape.getType()); + if (ty.getRank() <= dstTy.getRank()) + return failure(); + + // Check all removed dimensions have size 1. + if (!all_of(drop_end(ty.getShape(), dstTy.getRank()), + [](int64_t val) { return val == 1; })) + return failure(); + + // Check shape prefix matches the resulting type. + if (!equal(drop_begin(ty.getShape(), ty.getRank() - dstTy.getRank()), + dstTy.getShape())) + return failure(); + + auto inBounds = op.getInBounds(); + if (std::any_of(inBounds.begin(), inBounds.end() - dstTy.getRank(), + [](Attribute attr) { + return !cast(attr).getValue(); + })) + return failure(); + + // Fold read and shape cast into a single read. + auto newPermMap = permMap.getMinorIdentityMap( + permMap.getNumDims(), dstTy.getRank(), getContext()); + auto newInBounds = rewriter.getArrayAttr(SmallVector(drop_begin( + op.getInBounds().getValue(), ty.getRank() - dstTy.getRank()))); + auto newRead = rewriter.create( + loc, dstTy, op.getSource(), op.getIndices(), newPermMap, + op.getPadding(), op.getMask(), newInBounds); + rewriter.replaceOp(reshape, newRead); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct Canonicalize : public triton::cpu::impl::CanonicalizeBase { + Canonicalize() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createCanonicalize() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp new file mode 100644 index 000000000000..4ad5de863fb4 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -0,0 +1,190 @@ +#include "ConvertDotCommon.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace mlir { +namespace triton { +namespace cpu { + +bool isLoopCarriedAcc(Value acc) { + LDBG("Check if accumulator can be held in tiles: " << acc); + if (!acc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + for (auto op : acc.getUsers()) + LDBG(" " << *op); + return false; + } + + auto blockArg = dyn_cast(acc); + if (!blockArg) { + LDBG(" No. Not a block argument."); + return false; + } + + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!forOp) { + LDBG(" No. Not in a for-loop."); + return false; + } + + blockArg.getArgNumber(); + + Value updAcc = acc.getUsers().begin()->getResult(0); + if (!updAcc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + return false; + } + + auto &updAccUse = *updAcc.getUses().begin(); + if (!isa(updAccUse.getOwner()) || + updAccUse.getOperandNumber() != + (blockArg.getArgNumber() - forOp.getNumInductionVars())) { + LDBG(" No. Loop carried dependency not detected."); + return false; + } + + LDBG(" Yes."); + return true; +} + +Value getInitAccValue(Value val) { + auto blockArg = cast(val); + auto forOp = cast(blockArg.getOwner()->getParentOp()); + int initValIdx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + return forOp.getInitArgs()[initValIdx]; +} + +MemBuffer findInputBuffer(Value val, bool allowTransposed) { + MemBuffer buf; + + if (allowTransposed) { + auto transposeOp = val.getDefiningOp(); + if (transposeOp) { + val = transposeOp.getVector(); + buf.transposed = true; + } + } + + auto valLoad = val.getDefiningOp(); + if (!valLoad || hasMaskOrBoundsCheck(valLoad)) { + LDBG("Couldn't find a buffer with input: " << val); + return buf; + } + + buf.memRef = valLoad.getSource(); + buf.indices = valLoad.getIndices(); + LLVM_DEBUG( + DBGS() << "Found buffer with input: " << val << "\n"; + DBGS() << " MemRef: " << buf.memRef << "\n"; DBGS() << " Indices: "; + llvm::interleaveComma(buf.indices, llvm::dbgs()); llvm::dbgs() << "\n"); + + auto forOp = dyn_cast(valLoad->getParentOp()); + if (!forOp) { + LDBG(" Skip steps. Not in a for-loop."); + return buf; + } + + auto extractMemRef = buf.memRef.getDefiningOp(); + if (!extractMemRef) { + LDBG(" Skip steps. No ExtractMemRefOp."); + return buf; + } + + ExtractIndicesOp extractIndices; + for (auto index : buf.indices) { + auto def = index.getDefiningOp(); + if (!def || (extractIndices && def != extractIndices)) { + LDBG(" Skip steps. No ExtractIndicesOp."); + return buf; + } + extractIndices = def; + } + + if (extractMemRef.getSrc() != extractIndices.getSrc()) { + LDBG(" Skip steps. Mismatched ExtractMemRefOp and ExtractIndicesOp."); + return buf; + } + + BlockArgument blockPtrArg = dyn_cast(extractMemRef.getSrc()); + if (!blockPtrArg) { + LDBG(" Skip steps. No block pointer arg."); + return buf; + } + + OpOperand *yieldOp = forOp.getTiedLoopYieldedValue(blockPtrArg); + if (!yieldOp) { + LDBG(" Skip steps. No block pointer in yield."); + return buf; + } + + auto advance = yieldOp->get().getDefiningOp(); + if (!advance) { + LDBG(" Skip steps. No AdvanceOp."); + return buf; + } + + if (advance.getPtr() != blockPtrArg) { + LDBG(" Skip steps. AdvanceOp doesn't use block pointer arg."); + return buf; + } + + buf.step = advance.getOffsets(); + LLVM_DEBUG(DBGS() << " Step: "; + llvm::interleaveComma(buf.step, llvm::dbgs()); + llvm::dbgs() << "\n"); + + return buf; +} + +Value maybeCast(Location loc, Value val, Type dstElemTy, + PatternRewriter &rewriter) { + VectorType srcTy = cast(val.getType()); + if (srcTy.getElementType() == dstElemTy) + return val; + + VectorType dstTy = srcTy.cloneWith(std::nullopt, dstElemTy); + if (srcTy.getElementType().isInteger()) { + if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) + return rewriter.create(loc, dstTy, val); + return rewriter.create(loc, dstTy, val); + } + + if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) + return rewriter.create(loc, dstTy, val); + return rewriter.create(loc, dstTy, val); +} + +MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, + Operation *allocaPoint, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(allocaPoint); + auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + Value memRef = rewriter.create( + loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(2, zeroIdx); + return {memRef, indices}; +} + +Value shiftIndex(Location loc, Value index, int64_t offs, + PatternRewriter &rewriter) { + if (!offs) + return index; + + // Do constant folding right away here for better code readability + // after the pass. + auto cstOp = dyn_cast(index.getDefiningOp()); + if (cstOp) { + int64_t oldVal = cast(cstOp.getValue()).getInt(); + return rewriter.create(loc, oldVal + offs); + } + + Value offsVal = rewriter.create(loc, offs); + return rewriter.create(loc, index.getType(), index, offsVal); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h new file mode 100644 index 000000000000..e26529d91882 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -0,0 +1,72 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#define DEBUG_TYPE "triton-cpu-dot-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace cpu { + +// This structure describes input/output buffer. +struct MemBuffer { + Value memRef; + SmallVector indices; + // If buffer is accessed in a loop and indices are advanced + // on each iteration, then step can hold those index offsets. + // Empty step doesn't mean indices are loop invariant. + SmallVector step; + // True if buffer holds transposed value. + bool transposed = false; + + bool empty() const { return !memRef; } +}; + +// Check if accumulator value is updated in a loop and has no other +// usages than a dot op that updates it. Loads, stores, and casts +// for such accumulator can be done outside of the loop. +bool isLoopCarriedAcc(Value acc); + +// Get initial value for a loop-carried accumulator. +Value getInitAccValue(Value val); + +// Check if vector transfer read/write operation uses a mask +// or involves a bounds check. +template bool hasMaskOrBoundsCheck(T op) { + auto inBounds = op.getInBounds(); + Value mask = op.getMask(); + bool hasBoundsCheck = + std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { + return !cast(attr).getValue(); + }); + return hasBoundsCheck || mask; +} + +// Search for a buffer holding required value. If allowTransposed is true, +// then buffer is allowed to hold both transposed and not transposed value. +// Return empty buffer if no memory holding value was found. +MemBuffer findInputBuffer(Value val, bool allowTransposed = false); + +// Cast vector to a specified element type using ext or trunc +// operations. Return the original value if it already matches +// the required element type. +Value maybeCast(Location loc, Value val, Type dstElemTy, + PatternRewriter &rewriter); + +// Allocate temporary buffer on stack for specified vector type. +MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, + Operation *allocaPoint, PatternRewriter &rewriter); + +// Move index by specified offset. Do constannt folding if possible. +Value shiftIndex(Location loc, Value index, int64_t offs, + PatternRewriter &rewriter); + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp new file mode 100644 index 000000000000..9465e67b36cf --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp @@ -0,0 +1,132 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTGENERIC +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DotConversionTarget : public ConversionTarget { +public: + explicit DotConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cpu::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value a = op.getA(); + Value b = op.getB(); + Value c = op.getC(); + VectorType aType = cast(a.getType()); + VectorType bType = cast(b.getType()); + VectorType cType = cast(c.getType()); + + uint32_t rank = aType.getRank(); + if (rank == 2) { + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } else if (rank == 3) { + auto aMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } + + return failure(); + } + + SmallVector deinterleave(Location loc, ArrayRef vals, + ConversionPatternRewriter &rewriter) const { + SmallVector res; + for (auto &val : vals) { + auto op = rewriter.create(loc, val); + res.push_back(op.getResult(0)); + res.push_back(op.getResult(1)); + } + return res; + } +}; + +struct ConvertDotGeneric + : public triton::cpu::impl::ConvertDotGenericBase { + using ConvertDotGenericBase::ConvertDotGenericBase; + + ConvertDotGeneric() : ConvertDotGenericBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + DotConversionTarget convTarget(*context); + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotGeneric() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp new file mode 100644 index 000000000000..1b6dd9269ac1 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -0,0 +1,882 @@ +#include "ConvertDotCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "include/triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTTOAMX +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This structure is used to hold candidates for conversion to AMX +// Mul[F|I]Op operations. +struct AmxDotOpCandidate { + // Operation to convert. + cpu::DotOp op; + // Available LHS, RHS, and accumulator types are limited in AMX and we might + // require additional casts. Here we keep actual element types used by LHS, + // RHS, and accumulator in AMX tiles. + Type lhsTileElemTy; + Type rhsTileElemTy; + Type accTileElemTy; + // AMX tile row size is limited by 64 bytes, so M and N dimensions are limited + // by 16 because accumulator always has 4-byte elements. K dimension for tiles + // is limited by 64 / . Here we keep actual tile sizes. + int64_t tileM; + int64_t tileN; + int64_t tileK; + // We have a limited number of available tiles, so if input/output is too + // big to fit available tiles, we need to split them into blocks. Here we + // keep a number of tiles in accumulator block. K dimension for input blocks + // is always 1 tile now. + int64_t tilesInBlockM; + int64_t tilesInBlockN; + // If accumulator is updated in a loop, then this flag indicates if we + // should keep it in tiles the whole loop and move back to vectors only + // after the loop. + bool keepAccOnTiles = false; + // If we want to keep accumulator in tiles but it's too big, then we might + // keep it bufferized instead. + bool keepAccInBuf = false; + // If resulting tiles are not required to be trasfered to vectors and can be + // directly stored to the output memory instead, then this field holds a + // buffer to use. + MemBuffer outBuf; + // If output buffer is used then keep the original vector store here. + Operation *origStore = nullptr; +}; + +// Check if input and output types can be handled by AMX (possibly, using +// additional casts for input/output). Returns true if AMX usage is possible. +// In this case, tile element type fields of the candidate structure are +// filled with actual types to be used in lowering. +bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, + Type resElemTy, bool supportInt8, bool supportFp16, + bool supportBf16, AmxDotOpCandidate &candidate) { + MLIRContext *ctx = lhsElemTy.getContext(); + if (lhsElemTy.isInteger()) { + if (!supportInt8) { + LDBG("Drop candidate because AMX_INT8 is not available."); + return false; + } + + // For integer case only i8 is allowed for LHS and RHS. + if (!lhsElemTy.isInteger(8) || !rhsElemTy.isInteger(8)) { + LDBG("Drop candidate with unsupported input integer type."); + return false; + } + + // Accumulator should be i32. If it's smaller, we will use casts. + if (!accElemTy.isInteger() || accElemTy.getIntOrFloatBitWidth() > 32 || + !resElemTy.isInteger() || resElemTy.getIntOrFloatBitWidth() > 32) { + LDBG("Drop candidate with unsupported output integer type."); + return false; + } + + candidate.lhsTileElemTy = IntegerType::get(ctx, 8); + candidate.rhsTileElemTy = IntegerType::get(ctx, 8); + candidate.accTileElemTy = IntegerType::get(ctx, 32); + + return true; + } + + // FP case. Expect no integer args or result. + if (rhsElemTy.isInteger() || accElemTy.isInteger() || resElemTy.isInteger()) { + LDBG("Drop candidate with mixed int/fp types."); + return false; + } + + // For fp case LHS and RHS types should match and can be either FP16 or + // BF16. + if (lhsElemTy.getIntOrFloatBitWidth() > 16 || + rhsElemTy.getIntOrFloatBitWidth() > 16) { + LDBG("Drop candidate with unsupported input fp type."); + return false; + } + + // Try to find a common input type. There is currently no support + // for FP8 types, so promote them to FP16/BF16. + Type commonInputElemTy; + if (lhsElemTy.getIntOrFloatBitWidth() == 16) { + commonInputElemTy = lhsElemTy; + if (rhsElemTy.getIntOrFloatBitWidth() == 16 && + rhsElemTy != commonInputElemTy) { + LDBG("Drop candidate with mismatched input types."); + return false; + } + } else if (rhsElemTy.getIntOrFloatBitWidth() == 16) + commonInputElemTy = rhsElemTy; + // Both inputs are FP8, choose 16-bit FP type to use. + else if (supportBf16) + commonInputElemTy = BFloat16Type::get(ctx); + else + commonInputElemTy = Float16Type::get(ctx); + + if (commonInputElemTy.isF16() && !supportFp16) { + LDBG("Drop candidate because AMX_FP16 is not available."); + return false; + } + + if (commonInputElemTy.isBF16() && !supportBf16) { + LDBG("Drop candidate because AMX_BF16 is not available."); + return false; + } + + // Accumulator type should be FP32, we can use casts if it is smaller. + if (accElemTy.getIntOrFloatBitWidth() > 32) { + LDBG("Drop candidate with unsupported accumulator type."); + return false; + } + + candidate.lhsTileElemTy = commonInputElemTy; + candidate.rhsTileElemTy = commonInputElemTy; + candidate.accTileElemTy = Float32Type::get(ctx); + + return true; +} + +// Check input shapes. Currently, support only 2D cases and ignore small +// inputs. +bool checkInputShapes(VectorType lhsTy, VectorType resTy) { + if (lhsTy.getRank() != 2) + return false; + + if (lhsTy.getDimSize(0) < 8 || lhsTy.getDimSize(1) < 8 || + resTy.getDimSize(1) < 8) + return false; + + return true; +} + +// Return a value that holds the resulting loop carried accumulator value. +// It's one of ForOp's results. +Value getResValueForLoopCarriedAcc(cpu::DotOp op) { + Value updAcc = op.getResult(); + auto forOp = dyn_cast(op->getParentOp()); + auto &use = *updAcc.getUses().begin(); + return forOp.getResult(use.getOperandNumber()); +} + +// Choose tile and block sizes for the candidate. Tile sizes are determined +// by input shapes and types. Block sizes are chosen to minimize number of +// tile loads/stores including tile register spills. +void setupBlockAndTileSizes(ArrayRef lhsShape, + ArrayRef resShape, + AmxDotOpCandidate &candidate) { + int64_t m = resShape[0]; + int64_t n = resShape[1]; + int64_t k = lhsShape[1]; + int64_t tileM = std::min(m, (int64_t)16); + int64_t tileN = std::min(n, (int64_t)16); + int64_t tileK = std::min( + k, (int64_t)512 / candidate.lhsTileElemTy.getIntOrFloatBitWidth()); + + int64_t accBlocksM = m / tileM; + int64_t accBlocksN = n / tileN; + + // All these sizes are power of 2. We have 8 tile registers and + // cannot use them all for accumulator. So, we will use up to 4 + // tiles for accumulator in a single block. + while (accBlocksM * accBlocksN > 4) { + if (accBlocksM > accBlocksN) + accBlocksM /= 2; + else + accBlocksN /= 2; + } + + candidate.tileM = tileM; + candidate.tileN = tileN; + candidate.tileK = tileK; + candidate.tilesInBlockM = accBlocksM; + candidate.tilesInBlockN = accBlocksN; +} + +// Check if vector transfer read/write operation uses a mask +// or involves a bounds check. +template bool hasMaskOrBoundsCheck(T op) { + auto inBounds = op.getInBounds(); + Value mask = op.getMask(); + bool hasBoundsCheck = + std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { + return !cast(attr).getValue(); + }); + return hasBoundsCheck || mask; +} + +// Check if a value is used only for a store and that this store can be +// replaced with tile stores. In this case fill appropriate fields in the +// candidate structure. +void findOutputBuffer(Value val, AmxDotOpCandidate &candidate) { + if (val.hasOneUse()) { + auto store = dyn_cast(*val.user_begin()); + if (store && !hasMaskOrBoundsCheck(store)) + candidate.outBuf = MemBuffer{store.getSource(), store.getIndices()}; + candidate.origStore = store; + } +} + +// Check if specified ContractionOp can be lowered to AMX operations. +// If conversion is possible, then true is returned and candidate +// structure is filled with detailed transformation info. +bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, + bool supportBf16, AmxDotOpCandidate &candidate) { + MLIRContext *ctx = op.getContext(); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getType()); + + LDBG("Considering candidate op: " << op); + + // Check if input and output types match available hardware capabilities. + // If check is successful then tile element types are filled with types + // to use in AMX operations. + if (!checkElemTypes(lhsTy.getElementType(), rhsTy.getElementType(), + accTy.getElementType(), resTy.getElementType(), + supportInt8, supportFp16, supportBf16, candidate)) + return false; + + // Check input shapes. + if (!checkInputShapes(lhsTy, resTy)) + return false; + + candidate.op = op; + setupBlockAndTileSizes(lhsTy.getShape(), resTy.getShape(), candidate); + candidate.keepAccOnTiles = isLoopCarriedAcc(op.getC()); + + // Can't keep acc in a tile the whole loop right now: + // https://github.com/llvm/llvm-project/issues/109481 + if (candidate.keepAccOnTiles) { + // We might not have enough tiles to hold the whole accumulator. If we + // have more than one block, keep it in a bufffer. + if (candidate.tilesInBlockM * candidate.tileM < resTy.getDimSize(0) || + candidate.tilesInBlockN * candidate.tileN < resTy.getDimSize(1)) { + LDBG("Accumulator is too big to keep on tiles. Keep it bufferized " + "insterad."); + candidate.keepAccOnTiles = false; + candidate.keepAccInBuf = true; + } else { + findOutputBuffer(getResValueForLoopCarriedAcc(op), candidate); + } + } else { + findOutputBuffer(op.getResult(), candidate); + } + + return true; +} + +template T getSwizzledRhsTileType(T origTileType) { + int64_t rowsPerGroup = 32 / origTileType.getElementTypeBitWidth(); + SmallVector shape({origTileType.getDimSize(0) / rowsPerGroup, + origTileType.getDimSize(1) * rowsPerGroup}); + return origTileType.cloneWith(shape, origTileType.getElementType()); +} + +// In AMX, element values shoud be packed to 32-bit groups that would be +// multiplied elementwise with following accumulation. It means that RHS +// needs to be pre-packed. E.g. for the following input +// B(0,0) B(0,1) B(0,2) ... B(0,15) +// B(1,0) B(1,1) B(1,2) ... B(1,15) +// B(2,0) B(2,1) B(2,2) ... B(2,15) +// B(3,0) B(3,1) B(3,2) ... B(3,15) +// and BF16/FP16 type we need to transform it to +// B(0,0) B(1,0) B(0,1), B(1,1) ... B(0,15) B(1,15) +// B(2,0) B(3,0) B(2,1), B(3,1) ... B(2,15) B(3,15) +// so that original columns are 32-bits now. In case of int8 type, the +// result would be: +// B(0,0) B(1,0) B(2,0), B(3,0) ... B(0,15) B(1,15), B(2,15) B(3,15) +void interleaveAndStore(Location loc, Value val, Value buf, + PatternRewriter &rewriter) { + LDBG("Repacking operand before storing to a buffer."); + VectorType valTy = cast(val.getType()); + int64_t rowsPerGroup = 32 / valTy.getElementTypeBitWidth(); + assert(rowsPerGroup == 2 || rowsPerGroup == 4); + assert(valTy.getDimSize(0) % rowsPerGroup == 0); + Value zeroIdx = index_cst(0); + for (int64_t i = 0; i < valTy.getDimSize(0); i += rowsPerGroup) { + Value row1, row2; + if (rowsPerGroup == 2) { + row1 = op_extract(val, i); + row2 = op_extract(val, i + 1); + } else { + row1 = op_interleave(op_extract(val, i), op_extract(val, i + 2)); + row2 = op_interleave(op_extract(val, i + 1), op_extract(val, i + 3)); + } + Value shuffled = op_interleave(row1, row2); + Value idx = index_cst(i / rowsPerGroup); + op_store(shuffled, buf, SmallVector({idx, zeroIdx})); + } +} + +Value loadWithPrefetch(Location loc, VectorType ty, Value memRef, + ArrayRef indices, ArrayRef step, + PatternRewriter &rewriter) { + Value res = op_read(ty, memRef, indices); + if (!step.empty()) { + SmallVector prefetchIndices; + for (int64_t i = 0; i < indices.size(); ++i) { + prefetchIndices.push_back( + op_addi(indices[i], rewriter.create( + loc, rewriter.getIndexType(), step[i]))); + } + rewriter.create(loc, memRef, prefetchIndices, false, 1, + true); + } + return res; +} + +// Copy tensor with packing using for-loop. See interleaveAndStore for more +// details. +void copyWithInterleave(Location loc, VectorType srcTy, const MemBuffer &src, + const MemBuffer &dst, PatternRewriter &rewriter) { + int64_t rowsPerGroup = 32 / srcTy.getElementTypeBitWidth(); + Value lower = index_cst(0); + Value upper = index_cst(srcTy.getDimSize(0) / rowsPerGroup); + Value one = index_cst(1); + Value rowsPerGroupVal = index_cst(rowsPerGroup); + VectorType srcVecTy = + VectorType::get({srcTy.getDimSize(1)}, srcTy.getElementType()); + auto forOp = rewriter.create(loc, lower, upper, one); + Value ivVal = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + SmallVector srcIndices = src.indices; + int64_t mDimIdx = srcIndices.size() - 2; + Value scaledM = op_muli(ivVal, rowsPerGroupVal); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], scaledM); + Value row1 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row2 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + if (rowsPerGroup == 4) { + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row3 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row4 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + row1 = op_interleave(row1, row3); + row2 = op_interleave(row2, row4); + } + Value shuffled = op_interleave(row1, row2); + SmallVector dstIndices = dst.indices; + dstIndices[dstIndices.size() - 2] = + op_addi(dstIndices[dstIndices.size() - 2], ivVal); + op_write(shuffled, dst.memRef, dstIndices); + rewriter.setInsertionPointAfter(forOp); +} + +// Prepare temporary buffers to be used for tile loads. If the original +// value can be directly loaded to tiles from its original memory, then +// use it instead. Return empty buffer if source value is all zeros and +// skipForZeros is set. +// +// If interleave flag is set, then pre-pack RHS before store. See +// interleaveAndStore for more details. +MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, + bool skipForZeros, bool readOnly, + Operation *allocaPoint, + PatternRewriter &rewriter) { + LDBG("Preparing buffer (interleave=" << interleave + << ") for a vector: " << val); + auto vecTy = cast(val.getType()); + MemBuffer inputBuf = findInputBuffer(val); + if (!inputBuf.empty()) { + if (interleave) { + LDBG(" Copying from the original memref with interleave: " + << inputBuf.memRef); + auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), + allocaPoint, rewriter); + copyWithInterleave(loc, vecTy, inputBuf, tmpBuf, rewriter); + return tmpBuf; + } + LDBG(" Reusing the original memref for a buffer: " << inputBuf.memRef); + return inputBuf; + } + + if (skipForZeros && isZeroConst(val)) { + LDBG("Skip buffer for zero vector."); + return {}; + } + + if (interleave) + vecTy = getSwizzledRhsTileType(vecTy); + MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + + if (interleave) { + interleaveAndStore(loc, val, buf.memRef, rewriter); + } else { + op_write(val, buf.memRef, buf.indices); + } + + return buf; +} + +// Return a buffer where the final result should be stored. If result can +// be directly stored to the output memory, then it is used as an output +// buffer. Otherwise, re-use accumulator buffer or create a new one. +MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, + const MemBuffer &outBuf, Operation *allocaPoint, + PatternRewriter &rewriter) { + if (!outBuf.empty()) { + LDBG("Output memory will be used for direct tile stores."); + return outBuf; + } + + if (!accBuf.empty()) { + LDBG("Result will be stored to accumulator buffer."); + return accBuf; + } + + LDBG("Allocating buffer for the result."); + return allocateTmpBuffer(loc, cast(val.getType()), allocaPoint, + rewriter); +} + +SmallVector shiftIndices(Location loc, ArrayRef indices, + amx::TileType tileTy, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { + int64_t blockOffsM = blockM * tilesInBlockM * tileTy.getDimSize(0); + int64_t blockOffsN = blockN * tilesInBlockN * tileTy.getDimSize(1); + int64_t tileOffsM = blockOffsM + tileM * tileTy.getDimSize(0); + int64_t tileOffsN = blockOffsN + tileN * tileTy.getDimSize(1); + SmallVector res(indices.begin(), indices.end() - 2); + res.push_back(shiftIndex(loc, *(indices.end() - 2), tileOffsM, rewriter)); + res.push_back(shiftIndex(loc, *(indices.end() - 1), tileOffsN, rewriter)); + return res; +} + +Value loadTile(Location loc, amx::TileType tileTy, const MemBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { + auto indices = + shiftIndices(loc, buf.indices, tileTy, tilesInBlockM, tilesInBlockN, + blockM, blockN, tileM, tileN, rewriter); + return rewriter.create(loc, tileTy, buf.memRef, indices); +} + +void storeTile(Location loc, amx::TileType tileTy, Value val, + const MemBuffer &buf, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, int64_t blockN, + int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { + auto indices = + shiftIndices(loc, buf.indices, tileTy, tilesInBlockM, tilesInBlockN, + blockM, blockN, tileM, tileN, rewriter); + rewriter.create(loc, buf.memRef, indices, val); +} + +SmallVector> +loadBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, PatternRewriter &rewriter) { + SmallVector> res(tilesInBlockM); + for (int64_t m = 0; m < tilesInBlockM; ++m) { + for (int64_t n = 0; n < tilesInBlockN; ++n) { + Value tile = buf.memRef + ? loadTile(loc, tileTy, buf, tilesInBlockM, + tilesInBlockN, blockM, blockN, m, n, rewriter) + : rewriter.create(loc, tileTy); + res[m].push_back(tile); + } + } + return res; +} + +void storeBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, + int64_t blockM, int64_t blockN, + const SmallVector> &tiles, + PatternRewriter &rewriter) { + int64_t tilesInBlockM = tiles.size(); + int64_t tilesInBlockN = tiles[0].size(); + for (int64_t m = 0; m < tilesInBlockM; ++m) { + for (int64_t n = 0; n < tilesInBlockN; ++n) { + storeTile(loc, tileTy, tiles[m][n], buf, tilesInBlockM, tilesInBlockN, + blockM, blockN, m, n, rewriter); + } + } +} + +// Multiply two blocks. LHS block is preloaded to tiles with the following +// iteration over RHS. Accumulator values are updated in accTiles. +// Optionally, results can also be stored to accBuf. +void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, + amx::TileType rhsTileTy, amx::TileType accTileTy, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, + int64_t blockN, int64_t blockK, + int64_t tilesInBlockM, int64_t tilesInBlockN, + SmallVector> &accTiles, + bool storeResult, PatternRewriter &rewriter) { + bool isInteger = accTileTy.getElementType().isInteger(); + SmallVector> lhsTiles = loadBlockTiles( + loc, lhsTileTy, lhsBuf, tilesInBlockM, 1, blockM, blockK, rewriter); + + for (int64_t tileN = 0; tileN < tilesInBlockN; ++tileN) { + Value rhsTile = loadTile(loc, rhsTileTy, rhsBuf, 1, tilesInBlockN, blockK, + blockN, 0, tileN, rewriter); + + for (int64_t tileM = 0; tileM < tilesInBlockM; ++tileM) { + if (isInteger) + accTiles[tileM][tileN] = + rewriter.create(loc, accTileTy, lhsTiles[tileM][0], + rhsTile, accTiles[tileM][tileN]); + else + accTiles[tileM][tileN] = + rewriter.create(loc, accTileTy, lhsTiles[tileM][0], + rhsTile, accTiles[tileM][tileN]); + + // Insert store here to better mix stores with multiplications. + if (storeResult) { + storeTile(loc, accTileTy, accTiles[tileM][tileN], accBuf, tilesInBlockM, + tilesInBlockN, blockM, blockN, tileM, tileN, rewriter); + } + } + } +} + +// Similar to multiplyBlocksPreloadLhs but here RHS is preloaded to tiles. +void multiplyBlocksPreloadRhs(Location loc, amx::TileType lhsTileTy, + amx::TileType rhsTileTy, amx::TileType accTileTy, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, + int64_t blockN, int64_t blockK, + int64_t tilesInBlockM, int64_t tilesInBlockN, + SmallVector> &accTiles, + bool storeResult, PatternRewriter &rewriter) { + bool isInteger = accTileTy.getElementType().isInteger(); + SmallVector> rhsTiles = loadBlockTiles( + loc, rhsTileTy, rhsBuf, 1, tilesInBlockN, blockK, blockN, rewriter); + + for (int64_t tileM = 0; tileM < tilesInBlockM; ++tileM) { + Value lhsTile = loadTile(loc, lhsTileTy, lhsBuf, tilesInBlockM, 1, blockM, + blockK, tileM, 0, rewriter); + + for (int64_t tileN = 0; tileN < tilesInBlockN; ++tileN) { + if (isInteger) + accTiles[tileM][tileN] = rewriter.create( + loc, accTileTy, lhsTile, rhsTiles[0][tileN], + accTiles[tileM][tileN]); + else + accTiles[tileM][tileN] = rewriter.create( + loc, accTileTy, lhsTile, rhsTiles[0][tileN], + accTiles[tileM][tileN]); + + // Insert store here to better mix stores with multiplications. + if (storeResult) { + storeTile(loc, accTileTy, accTiles[tileM][tileN], accBuf, tilesInBlockM, + tilesInBlockN, blockM, blockN, tileM, tileN, rewriter); + } + } + } +} + +LogicalResult convertCandidate(AmxDotOpCandidate &candidate, + PatternRewriter &rewriter) { + cpu::DotOp op = candidate.op; + Location loc = op.getLoc(); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getResult().getType()); + amx::TileType lhsTileTy = amx::TileType::get( + SmallVector({candidate.tileM, candidate.tileK}), + candidate.lhsTileElemTy); + amx::TileType rhsTileTy = getSwizzledRhsTileType(amx::TileType::get( + SmallVector({candidate.tileK, candidate.tileN}), + candidate.rhsTileElemTy)); + amx::TileType accTileTy = amx::TileType::get( + SmallVector({candidate.tileM, candidate.tileN}), + candidate.accTileElemTy); + + // If we don't work with a loop and want to directly store tiles into output + // memory, then use the original store as insertion point to have its buffer + // values available for generated code. + if (!candidate.keepAccInBuf && !candidate.keepAccOnTiles && + !candidate.outBuf.empty()) + rewriter.setInsertionPoint(candidate.origStore); + + Operation *allocaPoint = op; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Cast input data if required and prepare input buffer. It might be temporary + // buffers with stored vectors or the original input memory. + Value lhs = maybeCast(loc, op.getA(), candidate.lhsTileElemTy, rewriter); + MemBuffer lhsBuf = + prepareTensorBuffer(loc, lhs, false, false, true, allocaPoint, rewriter); + + Value rhs = maybeCast(loc, op.getB(), candidate.rhsTileElemTy, rewriter); + MemBuffer rhsBuf = + prepareTensorBuffer(loc, rhs, true, false, true, allocaPoint, rewriter); + + Value acc = maybeCast(loc, op.getC(), candidate.accTileElemTy, rewriter); + Value accToStore = acc; + scf::ForOp forOp; + if (candidate.keepAccInBuf || candidate.keepAccOnTiles) { + forOp = cast(op->getParentOp()); + accToStore = getInitAccValue(acc); + } + MemBuffer accBuf; + { + // If accumulator is bufferized then we should move initial values before + // the loop. + OpBuilder::InsertionGuard g(rewriter); + if (candidate.keepAccInBuf) + rewriter.setInsertionPoint(forOp); + accBuf = + prepareTensorBuffer(loc, accToStore, false, !candidate.keepAccInBuf, + false, allocaPoint, rewriter); + } + + MemBuffer resBuf = prepareResultBuffer( + loc, op.getResult(), accBuf, candidate.outBuf, allocaPoint, rewriter); + + SmallVector> accTiles; + SmallVector> accInitTiles; + if (candidate.keepAccOnTiles) { + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + LDBG("Loading accumulator to tiles before the loop."); + accInitTiles = + loadBlockTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, 0, 0, rewriter); + accTiles = accInitTiles; + } + + int64_t blocksInAccM = + accTy.getDimSize(0) / candidate.tileM / candidate.tilesInBlockM; + int64_t blocksInAccN = + accTy.getDimSize(1) / candidate.tileN / candidate.tilesInBlockN; + int64_t tilesInVectorK = lhsTy.getDimSize(1) / candidate.tileK; + for (int64_t blockM = 0; blockM < blocksInAccM; ++blockM) { + for (int64_t blockN = 0; blockN < blocksInAccN; ++blockN) { + if (!candidate.keepAccOnTiles) + accTiles = + loadBlockTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, blockM, blockN, rewriter); + + for (int64_t blocK = 0; blocK < tilesInVectorK; ++blocK) { + // We can store accumulator if it is the last block over K dimension. + // TODO: enable forward store for acc kept in tiles. + bool storeAcc = + !candidate.keepAccOnTiles && (blocK == (tilesInVectorK - 1)); + + // We need to choose which block (LHS or RHS) to keep on tiles. + // E.g. for ACC block 4x1 tiles, LHS block is also 4 tiles, so + // we would use all tile registers trying to keep both ACC and + // LHS blocks on registers. To decrease register pressure, keep + // the smallest block on tiles. + if (candidate.tilesInBlockM <= candidate.tilesInBlockN) + multiplyBlocksPreloadLhs( + loc, lhsTileTy, rhsTileTy, accTileTy, lhsBuf, rhsBuf, resBuf, + blockM, blockN, blocK, candidate.tilesInBlockM, + candidate.tilesInBlockN, accTiles, storeAcc, rewriter); + else + multiplyBlocksPreloadRhs( + loc, lhsTileTy, rhsTileTy, accTileTy, lhsBuf, rhsBuf, resBuf, + blockM, blockN, blocK, candidate.tilesInBlockM, + candidate.tilesInBlockN, accTiles, storeAcc, rewriter); + } + } + } + + if (candidate.keepAccOnTiles) { + // In this case we have the whole accumulator/result on tiles. Loop + // carried dependencies are not in place yet and should be added. + // After the loop, resulting tiles should either be stored to the + // output buffer, or moved to a vector though a temporary buffer. + + // We don't need the original accumulator and contraction op anymore. + // Directly yield orig accumulator value, so it would be later removed + // as unused. The original contraction can be removed right away. + int64_t origResIdx = op.getResult().getUses().begin()->getOperandNumber(); + rewriter.replaceOp(op, op.getC()); + + // Now, replace the loop with a new one to add loop carried dependency for + // accumulator tiles. + LDBG("Rewrite loop to introduce loop carried dependencies for accumulator " + "tiles."); + SmallVector newInitOperands; + SmallVector newYieldedValues; + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + LDBG("Initial value\n " << accInitTiles[m][n] + << "\nis combined with\n " << accTiles[m][n]); + newInitOperands.push_back(accInitTiles[m][n]); + newYieldedValues.push_back(accTiles[m][n]); + } + auto newForOp = cast(*forOp.replaceWithAdditionalYields( + rewriter, newInitOperands, true, + [&newYieldedValues](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return newYieldedValues; + })); + + // The resulting tiles are now in the new loop results. + auto resTiles = newForOp.getResults().take_back(newYieldedValues.size()); + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + accTiles[m][n] = resTiles[m * candidate.tilesInBlockN + n]; + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newForOp); + if (candidate.outBuf.empty()) { + // Move tiles to a vector through a temporary buffer and use it instead + // of the original one. + LDBG("Moving resulting tiles to a vector through memory."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accTileElemTy); + storeBlockTiles(loc, accTileTy, resBuf, 0, 0, accTiles, rewriter); + Value newVal = op_read(resTy, resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(newForOp.getResult(origResIdx), newVal); + } else { + // Store tiles directly to the output buffer and remove the original + // store. + LDBG("Storing resulting tiles to the output memory."); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(candidate.origStore); + storeBlockTiles(loc, accTileTy, candidate.outBuf, 0, 0, accTiles, + rewriter); + rewriter.eraseOp(candidate.origStore); + } + } else if (candidate.keepAccInBuf) { + // The result is in the buffer. We should load it and replace one of the + // loop results. The original contraction op can be removed. + // TODO: should we try to store to the output buffer on the last iteration? + Value loopRes = forOp.getTiedLoopResult(cast(op.getC())); + LDBG( + "Loading buffererized accumulator to a vector to replace loop result."); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(forOp); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(loopRes, newVal); + // Directly yield orig accumulator iter value. It will be removed as unused + // later. + rewriter.replaceOp(op, op.getC()); + } else if (candidate.outBuf.empty()) { + // The result is in the buffer. We should load it and replace the original + // constraction result. + LDBG("Loading the result to a vector to replace orig op result."); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceOp(op, newVal); + } else { + // The result is already in the output buffer. We just need to remove the + // original contraction and store operation. + LDBG("Removing original operation and its use."); + rewriter.eraseOp(candidate.origStore); + rewriter.eraseOp(op); + } + + return success(); +} + +struct ConvertDotToAMX + : public triton::cpu::impl::ConvertDotToAMXBase { + ConvertDotToAMX() = default; + ConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16) { + this->convertInt8 = convertInt8; + this->convertFp16 = convertFp16; + this->convertBf16 = convertBf16; + } + + void runOnOperation() override { + if (!convertInt8 && !convertFp16 && !convertBf16) + return; + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + SmallVector candidates; + mod->walk([this, &candidates](cpu::DotOp op) { + AmxDotOpCandidate candidate; + if (isAmxCandidate(op, convertInt8, convertFp16, convertBf16, + candidate)) { + LLVM_DEBUG({ + LDBG("Found AMX candidate"); + LDBG(" Op: " << candidate.op); + LDBG(" LhsTileElemTy: " << candidate.lhsTileElemTy); + LDBG(" RhsTileElemTy: " << candidate.rhsTileElemTy); + LDBG(" AccTileElemTy: " << candidate.accTileElemTy); + LDBG(" TileM: " << candidate.tileM); + LDBG(" TileN: " << candidate.tileN); + LDBG(" TileK: " << candidate.tileK); + LDBG(" TilesInBlockM: " << candidate.tilesInBlockM); + LDBG(" TilesInBlockN: " << candidate.tilesInBlockN); + LDBG(" KeepAccOnTiles: " << candidate.keepAccOnTiles); + LDBG(" KeepAccInBuf: " << candidate.keepAccInBuf); + LDBG(" Has output buffer: " << !candidate.outBuf.empty()); + }); + candidates.push_back(candidate); + } + return WalkResult::advance(); + }); + + for (auto &candidate : candidates) { + LDBG("Starting conversion of candidate: " << candidate.op); + PatternRewriter rewriter(context); + rewriter.setInsertionPoint(candidate.op); + if (succeeded(convertCandidate(candidate, rewriter))) { + LDBG("Conversion succeeded!"); + } else { + LDBG("Conversion failed!"); + } + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotToAMX() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16) { + return std::make_unique(convertInt8, convertFp16, + convertBf16); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp new file mode 100644 index 000000000000..4d1832ca8cf9 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp @@ -0,0 +1,462 @@ +#include "ConvertDotCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTTOFMA +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This structure is used to hold candidates for conversion to FMA operations. +struct FmaDotOpCandidate { + // Operation to convert. + cpu::DotOp op; + // Here we keep actual element types used by LHS, RHS, and accumulator for + // computation. + Type lhsElemTy; + Type rhsElemTy; + Type accElemTy; + // Accumulator size. + int64_t accVecSize; + int64_t accRows; + // If accumulator is updated in a loop, then this flag indicates if we + // should keep it in registers the whole loop. + bool keepAccOnRegs = false; + // Memory buffer holding LHS. Can be empty if LHS is not a result of a + // simple load. + MemBuffer lhsBuf; + // Memory buffer holding RHS. Can be empty if RHS is not a result of a + // simple load. + MemBuffer rhsBuf; +}; + +// Check if input and output types can be handled by FMA (possibly, using +// additional casts for input/output). Returns true if FMA lowering is possible. +// In this case, element type fields of the candidate structure are filled +// with actual types to be used in lowering. +bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, + Type resElemTy, FmaDotOpCandidate &candidate) { + MLIRContext *ctx = lhsElemTy.getContext(); + if (lhsElemTy.isInteger() || rhsElemTy.isInteger() || resElemTy.isInteger()) { + LDBG("Drop candidate because int types are not supported."); + return false; + } + + // Find a type to use for computations. Here we assume FMA works on FP32 + // and FP64, so smaller types are promoted. Flags should be added to cover + // other cases. + Type commonInputElemTy; + if (lhsElemTy.isF64() || rhsElemTy.isF64() || resElemTy.isF64()) + commonInputElemTy = Float64Type::get(ctx); + else + commonInputElemTy = Float32Type::get(ctx); + + candidate.lhsElemTy = commonInputElemTy; + candidate.rhsElemTy = commonInputElemTy; + candidate.accElemTy = commonInputElemTy; + + return true; +} + +// Check input shapes. Currently, support only 2D cases and ignore small +// inputs. +bool checkInputShapes(VectorType lhsTy, VectorType resTy) { + if (lhsTy.getRank() != 2) + return false; + + if (resTy.getDimSize(1) < 8) + return false; + + return true; +} + +// Check if specified ContractionOp can be lowered to FMA operations. +// If conversion is possible, then true is returned and candidate +// structure is filled with detailed transformation info. +bool isFmaCandidate(cpu::DotOp op, FmaDotOpCandidate &candidate) { + MLIRContext *ctx = op.getContext(); + VectorType lhsTy = op.getA().getType(); + VectorType rhsTy = op.getB().getType(); + VectorType accTy = op.getC().getType(); + VectorType resTy = op.getType(); + + LDBG("Considering candidate op: " << op); + + // Check if input and output types match available hardware capabilities. + // If check is successful then effective element types are assigned to the + // candidate. + if (!checkElemTypes(lhsTy.getElementType(), rhsTy.getElementType(), + accTy.getElementType(), resTy.getElementType(), + candidate)) + return false; + + // Check input shapes. + if (!checkInputShapes(lhsTy, resTy)) + return false; + + candidate.op = op; + candidate.accVecSize = resTy.getDimSize(1); + candidate.accRows = resTy.getDimSize(0); + candidate.keepAccOnRegs = isLoopCarriedAcc(op.getC()); + + if (lhsTy.getElementType() == candidate.lhsElemTy) + candidate.lhsBuf = findInputBuffer(op.getA(), true); + if (rhsTy.getElementType() == candidate.rhsElemTy) + candidate.rhsBuf = findInputBuffer(op.getB(), false); + + return true; +} + +MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, + PatternRewriter &rewriter) { + LDBG("Storing vector to a temporary buffer: " << val); + auto vecTy = cast(val.getType()); + MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + rewriter.create(loc, val, buf.memRef, buf.indices); + return buf; +} + +SmallVector shiftIndices(Location loc, ArrayRef indices, + bool transposed, int64_t m, int64_t n, + PatternRewriter &rewriter) { + SmallVector res(indices.begin(), indices.end() - 2); + if (transposed) + std::swap(m, n); + res.push_back(shiftIndex(loc, *(indices.end() - 2), m, rewriter)); + res.push_back(shiftIndex(loc, *(indices.end() - 1), n, rewriter)); + return res; +} + +SmallVector shiftIndices(Location loc, const MemBuffer &buf, int64_t m, + int64_t n, PatternRewriter &rewriter) { + return shiftIndices(loc, buf.indices, buf.transposed, m, n, rewriter); +} + +Value loadRow(Location loc, VectorType resTy, const MemBuffer &buf, int64_t m, + PatternRewriter &rewriter) { + assert(!buf.empty()); + SmallVector indices = buf.indices; + indices[indices.size() - 2] = + shiftIndex(loc, indices[indices.size() - 2], m, rewriter); + return rewriter.create(loc, resTy, buf.memRef, indices); +} + +void storeRow(Location loc, const MemBuffer &buf, int64_t rowIdx, Value vec, + PatternRewriter &rewriter) { + SmallVector indices = buf.indices; + indices[indices.size() - 2] = + shiftIndex(loc, buf.indices[indices.size() - 2], rowIdx, rewriter); + rewriter.create(loc, vec, buf.memRef, indices); +} + +void storeRows(Location loc, const MemBuffer &buf, + const SmallVector &vecs, PatternRewriter &rewriter) { + SmallVector indices = buf.indices; + for (int64_t m = 0; m < vecs.size(); ++m) + storeRow(loc, buf, m, vecs[m], rewriter); +} + +SmallVector extractRows(Location loc, Value vec, + PatternRewriter &rewriter) { + VectorType vecTy = cast(vec.getType()); + SmallVector res; + for (int64_t m = 0; m < vecTy.getDimSize(0); ++m) { + auto row = + rewriter.create(loc, vec, SmallVector({m})); + res.push_back(row); + } + return res; +} + +Value mergeRows(Location loc, VectorType resTy, const SmallVector &tiles, + PatternRewriter &rewriter) { + Value res = + rewriter.create(loc, rewriter.getZeroAttr(resTy)); + for (int64_t m = 0; m < tiles.size(); ++m) + res = rewriter.create(loc, tiles[m], res, + SmallVector({m})); + return res; +} + +Value broadcastElem(Location loc, VectorType tileTy, const MemBuffer &buf, + int64_t m, int64_t n, PatternRewriter &rewriter) { + SmallVector indices = shiftIndices(loc, buf, m, n, rewriter); + Value scalar = rewriter.create(loc, buf.memRef, indices); + return rewriter.create(loc, tileTy, scalar); +} + +SmallVector computePrefetchIndices(Location loc, const MemBuffer &buf, + int64_t iters, + PatternRewriter &rewriter) { + SmallVector scaledStep; + Value itersVal; + for (auto step : buf.step) { + if (iters == 1) + scaledStep.push_back(rewriter.create( + loc, rewriter.getIndexType(), step)); + else if (auto cstOp = dyn_cast(step.getDefiningOp())) { + int64_t oldVal = cast(cstOp.getValue()).getInt(); + scaledStep.push_back( + rewriter.create(loc, oldVal * iters)); + } else { + if (!itersVal) + itersVal = + rewriter.create(loc, iters, step.getType()); + scaledStep.push_back(rewriter.create( + loc, rewriter.getIndexType(), + rewriter.create(loc, step.getType(), step, itersVal))); + } + } + + SmallVector res; + for (int64_t i = 0; i < scaledStep.size(); ++i) + res.push_back(rewriter.create( + loc, buf.indices[i].getType(), buf.indices[i], scaledStep[i])); + return res; +} + +void prefetch(Location loc, const MemBuffer &buf, int64_t m, int64_t n, + ArrayRef prefetchIndices, int64_t hint, + PatternRewriter &rewriter) { + SmallVector indices = + shiftIndices(loc, prefetchIndices, buf.transposed, m, n, rewriter); + rewriter.create(loc, buf.memRef, indices, false, hint, + true); +} + +LogicalResult convertCandidate(FmaDotOpCandidate &candidate, + PatternRewriter &rewriter) { + cpu::DotOp op = candidate.op; + Location loc = op.getLoc(); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getResult().getType()); + VectorType rhsVecTy = + VectorType::get(candidate.accVecSize, candidate.rhsElemTy); + VectorType accVecTy = + VectorType::get(candidate.accVecSize, candidate.accElemTy); + + Operation *allocaPoint = op; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Cast input data if required and prepare input buffer. It might be temporary + // buffers with stored vectors or the original input memory. + MemBuffer lhsBuf = candidate.lhsBuf; + if (lhsBuf.empty()) { + Value lhs = maybeCast(loc, op.getA(), candidate.lhsElemTy, rewriter); + lhsBuf = storeToTmpBuffer(loc, lhs, allocaPoint, rewriter); + } + + MemBuffer rhsBuf = candidate.rhsBuf; + if (rhsBuf.empty()) { + Value rhs = maybeCast(loc, op.getB(), candidate.rhsElemTy, rewriter); + rhsBuf = storeToTmpBuffer(loc, rhs, allocaPoint, rewriter); + } + + Value acc = maybeCast(loc, op.getC(), candidate.accElemTy, rewriter); + Value accToStore = acc; + scf::ForOp forOp; + if (candidate.keepAccOnRegs) { + forOp = cast(op->getParentOp()); + accToStore = getInitAccValue(acc); + } + + SmallVector accVecs; + SmallVector accInitVecs; + if (candidate.keepAccOnRegs) { + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + LDBG("Loading accumulator to tiles before the loop."); + accInitVecs = extractRows(loc, accToStore, rewriter); + accVecs = accInitVecs; + } else { + accVecs = extractRows(loc, acc, rewriter); + } + + // Compute indices to be used by prefetch. + int64_t lhsPrefetchIters = + std::max(int64_t(128) / lhsTy.getNumElements(), int64_t(1)); + auto lhsPrefetchIndices = + computePrefetchIndices(loc, candidate.lhsBuf, lhsPrefetchIters, rewriter); + int64_t rhsPrefetchIters = + std::max(int64_t(128) / rhsTy.getNumElements(), int64_t(1)); + auto rhsPrefetchIndices = + computePrefetchIndices(loc, candidate.rhsBuf, rhsPrefetchIters, rewriter); + Value nextRhsVec = loadRow(loc, rhsVecTy, rhsBuf, 0, rewriter); + for (int64_t k = 0; k < lhsTy.getDimSize(1); ++k) { + Value rhsVec = nextRhsVec; + + // Load next vector in advance to hide load latency. + if (k != lhsTy.getDimSize(1) - 1) + nextRhsVec = loadRow(loc, rhsVecTy, rhsBuf, k + 1, rewriter); + + // Prefetch RHS to LLC cache. + if (!rhsPrefetchIndices.empty()) + prefetch(loc, candidate.rhsBuf, k, 0, rhsPrefetchIndices, 1, rewriter); + + Value nextLhsBroadcasted = + broadcastElem(loc, accVecTy, lhsBuf, 0, k, rewriter); + for (int64_t m = 0; m < candidate.accRows; ++m) { + Value lhsBroadcasted = nextLhsBroadcasted; + + // Load next value in advance to hide load latency. + if (m != candidate.accRows - 1) + nextLhsBroadcasted = + broadcastElem(loc, accVecTy, lhsBuf, m + 1, k, rewriter); + + // Prefetch LHS to L1 cache. + if (!lhsPrefetchIndices.empty()) { + if ((candidate.lhsBuf.transposed && (m % 8 == 0)) || + (!candidate.lhsBuf.transposed && (k % 8 == 0))) + prefetch(loc, candidate.lhsBuf, m, k, lhsPrefetchIndices, 3, + rewriter); + } + + accVecs[m] = rewriter.create(loc, rhsVec, lhsBroadcasted, + accVecs[m]); + } + } + + if (candidate.keepAccOnRegs) { + // In this case we have the whole accumulator/result on tiles. Loop + // carried dependencies are not in place yet and should be added. + // After the loop, resulting tiles should either be stored to the + // output buffer, or moved to a vector through a temporary buffer. + + // We don't need the original accumulator and contraction op anymore. + // Directly yield orig accumulator value, so it would be later removed + // as unused. The original contraction can be removed right away. + int64_t origResIdx = op.getResult().getUses().begin()->getOperandNumber(); + rewriter.replaceOp(op, op.getC()); + + // Now, replace the loop with a new one to add loop carried dependency for + // accumulator tiles. + LDBG("Rewrite loop to introduce loop carried dependencies for accumulator " + "tiles."); + SmallVector newInitOperands; + SmallVector newYieldedValues; + for (int64_t m = 0; m < candidate.accRows; ++m) { + LDBG("Initial value\n " << accInitVecs[m] << "\nis combined with\n " + << accVecs[m]); + newInitOperands.push_back(accInitVecs[m]); + newYieldedValues.push_back(accVecs[m]); + } + auto newForOp = cast(*forOp.replaceWithAdditionalYields( + rewriter, newInitOperands, true, + [&newYieldedValues](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return newYieldedValues; + })); + + // The resulting tiles are now in the new loop results. + auto resVecs = newForOp.getResults().take_back(newYieldedValues.size()); + for (int64_t m = 0; m < candidate.accRows; ++m) + accVecs[m] = resVecs[m]; + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newForOp); + // Collect all results into a single vector. + LDBG("Merging resulting rows to replace loop result."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accElemTy); + Value newVal = mergeRows(loc, resTy, accVecs, rewriter); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(newForOp.getResult(origResIdx), newVal); + } else { + // The result is in the buffer. We should load it and replace the original + // constraction result. + LDBG("Merging resulting rows to replace orig op result."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accElemTy); + Value newVal = mergeRows(loc, resTy, accVecs, rewriter); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceOp(op, newVal); + } + + return success(); +} + +struct ConvertDotToFMA + : public triton::cpu::impl::ConvertDotToFMABase { + ConvertDotToFMA() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + SmallVector candidates; + mod->walk([this, &candidates](cpu::DotOp op) { + FmaDotOpCandidate candidate; + if (isFmaCandidate(op, candidate)) { + LLVM_DEBUG({ + LDBG("Found FMA candidate"); + LDBG(" Op: " << candidate.op); + LDBG(" LhsElemTy: " << candidate.lhsElemTy); + LDBG(" RhsElemTy: " << candidate.rhsElemTy); + LDBG(" AccElemTy: " << candidate.accElemTy); + LDBG(" AccVecSize: " << candidate.accVecSize); + LDBG(" AccRows: " << candidate.accRows); + LDBG(" KeepAccOnRegs: " << candidate.keepAccOnRegs); + if (!candidate.lhsBuf.empty()) { + LDBG(" LhsBuf: " << candidate.lhsBuf.memRef); + LDBG(" Transposed: " << candidate.lhsBuf.transposed); + } + if (!candidate.rhsBuf.empty()) { + LDBG(" RhsBuf: " << candidate.rhsBuf.memRef); + LDBG(" Transposed: " << candidate.rhsBuf.transposed); + } + }); + candidates.push_back(candidate); + } + return WalkResult::advance(); + }); + + for (auto &candidate : candidates) { + LDBG("Starting conversion of candidate: " << candidate.op); + PatternRewriter rewriter(context); + rewriter.setInsertionPoint(candidate.op); + if (succeeded(convertCandidate(candidate, rewriter))) { + LDBG("Conversion succeeded!"); + } else { + LDBG("Conversion failed!"); + } + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotToFMA() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp new file mode 100644 index 000000000000..da96eea967cd --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp @@ -0,0 +1,492 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "include/triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTPRODUCT +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// TODO: support SVE and different vector width +// We currently only supported Arm Neon (128 bit vector). +// To support scalable vectors in SVE, we need to generate +// vector-length agnostic (VLA) code using vector.vscale. +// To support other platform (AVX512 for X86), we need to +// change the vectorBitWidth and the intrinsics. +constexpr int vectorBitWidth = 128; + +// This function is used to identify bf16 dot product (expressed by elementwise +// multiplication follwed by a sum). +// For example, the following pattern: tl.sum(a * x[None, :], axis=1) +// is used to express a dot product. +// Since x is broadcated for the elementwise multiplication. And tl.sum will +// cast its bf16 input to fp32. +// The pattern in MLIR will be: +// BroadcastOp -> MulFOp -> ExtFOp -> MultiDimReductionOp +bool isBf16DotProduct(vector::MultiDimReductionOp op, bool useHorizontalSum, + Value &matInput, Value &vecInput, + PatternRewriter &rewriter) { + Value src = op.getSource(); + Value acc = op.getAcc(); + auto srcTy = cast(src.getType()); + auto accTy = cast(acc.getType()); + auto resTy = cast(op.getType()); + + auto srcRank = srcTy.getRank(); + auto outNum = srcTy.getDimSize(0); + + if (resTy != accTy || srcRank != 2 || !isFp32(srcTy)) + return false; + + if (op.isReducedDim(0) || !op.isReducedDim(1)) + return false; + + if (op.getKind() != vector::CombiningKind::ADD) + return false; + + auto extFOp = src.getDefiningOp(); + + if (!extFOp || !extFOp->hasOneUse()) + return false; + + auto mulFOp = extFOp.getIn().getDefiningOp(); + + if (!mulFOp || !mulFOp->hasOneUse()) + return false; + + Value lhs = mulFOp.getLhs(); + Value rhs = mulFOp.getRhs(); + + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); + + if (!isBf16(lhsTy) || !isBf16(rhsTy)) + return false; + + const int lanes = + vectorBitWidth / lhsTy.getElementType().getIntOrFloatBitWidth(); + const int resultLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = lhsTy.getDimSize(1); + + if (outNum < 1) + return false; + + if (!useHorizontalSum) { + // TODO: masking is not currrently supported + if (outNum % resultLanes != 0) + return false; + } + + // TODO: masking is not currrently supported + if (kVal % lanes != 0) + return false; + + if (outNum == 1) { + matInput = lhs; + vecInput = rhs; + } else { + vector::BroadcastOp broadCastOp; + if (rhs.getDefiningOp()) { + matInput = lhs; + broadCastOp = rhs.getDefiningOp(); + } else { + matInput = rhs; + broadCastOp = lhs.getDefiningOp(); + } + if (!broadCastOp || !broadCastOp->hasOneUse()) + return false; + vecInput = broadCastOp.getSource(); + } + + if (cast(vecInput.getType()).getDimSize(0) != 1 || + cast(matInput.getType()).getDimSize(0) != outNum) + return false; + + return true; +} + +struct ConvertMulSumToDotHorizontalSum + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value acc = op.getAcc(); + auto resTy = cast(op.getType()); + + Value matInput; + Value vecInput; + + bool isMatch = isBf16DotProduct(op, /*useHorizontalSum=*/true, matInput, + vecInput, rewriter); + if (!isMatch) + return failure(); + + // Once we get the matrix input (NxK) and vector input (K), + // where N is the output channel dimension + // and K is the reduction dimension. + // We will generate the following code to perform the dot product. + // For each output channel: + // we will pull 8 bf16 elements from the vector and matrix each time when + // we iterate over the K dimension. + // We will then use bfdot to perform sum-of-products on pairs of + // bf16 elements, accumulate and get 4 fp32 outputs. + // After the iteration over the K dimension finishes, we will use a + // horizontal sum (faddv) to sum the 4 fp32 into a single fp32. + // We will also share the vector input across the output channels + // to reduce the number of loads. + // For example, if we dot product a size 2x16 matrix with a size 16 vector, + // the pseudo code will be: + // matrix = shapecast(matrix, 2x2x8) + // vector = shapecast(vector, 2x8) + // out = zeros(2x4, fp32) + // out[0] = bfdot(out[0], matrix[0][0], vector[0]) + // out[1] = bfdot(out[1], matrix[1][0], vector[0]) + // out[0] = bfdot(out[0], matrix[0][1], vector[1]) + // out[1] = bfdot(out[1], matrix[1][1], vector[1]) + // out_0 = faddv(out[0]) : 4xfp32 -> fp32 + // out_1 = faddv(out[1]) : 4xfp32 -> fp32 + + auto matInputTy = cast(matInput.getType()); + auto vecInputTy = cast(vecInput.getType()); + + const int lanes = + vectorBitWidth / matInputTy.getElementType().getIntOrFloatBitWidth(); + const int resultLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = matInputTy.getDimSize(1); + + // numOfOutputChannels is the number of output channels (N) + const int numOfOutputChannels = matInputTy.getDimSize(0); + // numOfBfdotOps is the number of bfdots needed for each output channel. + const int numOfBfdotOps = kVal / lanes; + + matInput = shapeCast(loc, matInput, + {numOfOutputChannels, numOfBfdotOps, lanes}, rewriter); + vecInput = shapeCast(loc, vecInput, {numOfBfdotOps, lanes}, rewriter); + + SmallVector outRes(numOfOutputChannels); + SmallVector mats(numOfOutputChannels); + + Type outResTy = VectorType::get(resultLanes, resTy.getElementType()); + + Value zeroRes = rewriter.create( + loc, outResTy, rewriter.getZeroAttr(outResTy)); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + outRes[outIdx] = zeroRes; + // Intermediate array to store each row of the input matrix. + mats[outIdx] = rewriter.create(loc, matInput, outIdx); + } + + SmallVector resultTypes = {outResTy}; + // TODO: this intrinsic is hard-coded for Arm Neon + auto bfdot = StringAttr::get(ctx, "llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + SmallVector args; + + for (int64_t idx = 0; idx < numOfBfdotOps; idx += 1) { + auto subVec = rewriter.create(loc, vecInput, idx); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + auto subMat = + rewriter.create(loc, mats[outIdx], idx); + args = {outRes[outIdx], subMat, subVec}; + // bfdot instruction: + // https://developer.arm.com/documentation/ddi0602/2024-06/SIMD-FP-Instructions/BFDOT--vector---BFloat16-floating-point-dot-product--vector-- + // LLVM fast math flags: + // https://llvm.org/docs/LangRef.html#fast-math-flags + // This bfdot intrinsic will perform an unfused sum-of-products of each + // pair of adjacent bf16 elements in the source vectors (8 bf16), and + // output 4 fp32 elements. + auto callIntrOp = rewriter.create( + loc, resultTypes, bfdot, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); + outRes[outIdx] = callIntrOp.getResult(0); + } + } + + Value res = rewriter.create(loc, resTy, + rewriter.getZeroAttr(resTy)); + + resultTypes = {resTy.getElementType()}; + // TODO: this intrinsic is hard-coded for Arm Neon + auto horzSum = StringAttr::get(ctx, "llvm.aarch64.neon.faddv.f32.v4f32"); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + args = {outRes[outIdx]}; + // This horizontal sum intrinsic will sum all fp32 elements in the source + // vector into a single fp32 element + auto callIntrOp = rewriter.create( + loc, resultTypes, horzSum, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); + res = rewriter.create(loc, callIntrOp.getResult(0), res, + outIdx); + } + + if (!isZeroConst(acc)) { + res = rewriter.create(loc, res, acc); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertMulSumToDotPack + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value acc = op.getAcc(); + auto resTy = cast(op.getType()); + + Value matInput; + Value vecInput; + + bool isMatch = isBf16DotProduct(op, /*useHorizontalSum=*/false, matInput, + vecInput, rewriter); + if (!isMatch) + return failure(); + + // Once we get the matrix input (NxK) and vector input (K), + // where N is the output channel dimension + // and K is the reduction dimension. + // We will generate the following code to perform the dot product. + // We will first transpose the matrix so that the output channel dimension + // is continuous, so we can store multiple output channels in one + // SIMD register. + // Then we will loop over the K dimension. + // For each iteration over K, we will pull 2 bf16 from the input vector. + // Inside the K loop, we will also iterate over the output channels. + // For each iteration over the output channel, we will pull + // 4 output channel (each containing 2 bf16). + // Then we will broadcast the 2 bf16 from the input vector, + // dot product it with the 4 output channels (each containing 2 bf16), + // and accumulate it with 4 outputs. + // We will iterate over N until all output channels are processed. + // Then we will move to the next 2 bf16 from the input vector (the K loop). + // We will also share the vector input across the output channels. + // For example, if we dot product a size 8x8 matrix with a size 8 vector, + // the generated pseudo code will be: + // Dimension: + // N: the output channel dimension + // n0: the number of SIMD registers needed to store the output + // -- N / 4 (2 in this case) + // n1: the number of outputs stored per SIMD register + // -- 4 + // K: the reduction dimension + // k0: the number of SIMD registers needed for the input vector + // -- K / 8 (1 in this case) + // k1: the number of lanes per SIMD register + // -- 4 + // k2: the number of bf16 elements per SIMD lane + // -- 2 + // matrix = shapecast(matrix, 8x4x2) + // shape: NxK -> Nx(k0xk1)xk2 + // matrix = tranpose(matrix, 1, 0, 2) : 8x4x2xbf16 -> 4x8x2xbf16 + // shape: Nx(k0xk1)xk2 -> (k0xk1)xNxk2 + // matrix = shapecast(matrix, 1x4x2x4x2xbf16) + // shape: (k0xk1)xNxk2 -> k0xk1xn0xn1xk2 + // vector = shapecast(vector, 1x4x2) + // shape: K -> k0xk1xk2 + // out = zeros(2x4, fp32) + // shape: n0xn1 + // subvec = broadcast(vector[0][0]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][0][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][0][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][1]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][1][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][1][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][2]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][2][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][2][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][3]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][3][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][3][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out = shapecast(out, 8) : 2x4xfp32 -> 8xfp32 + // shape: n0xn1 -> N + + auto matInputTy = cast(matInput.getType()); + auto vecInputTy = cast(vecInput.getType()); + + const int lanes = + vectorBitWidth / matInputTy.getElementType().getIntOrFloatBitWidth(); + const int resultLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = matInputTy.getDimSize(1); + + // numOfOutputChannels is the number of output channels (N) + const int numOfOutputChannels = matInputTy.getDimSize(0); + // numOfOutputRegs is the number of SIMD registers needed to store the + // output. + const int numOfOutputRegs = numOfOutputChannels / resultLanes; + // numOfVecRegs is the number of SIMD registers needed for the + // input vector. + const int numOfVecRegs = kVal / lanes; + // numOfVecPairs is the number of pairs (pair of bf16 elements) for the + // input vector. + const int numOfVecPairs = numOfVecRegs * resultLanes; + + VectorType fullResTy = + VectorType::get({numOfOutputRegs, resultLanes}, resTy.getElementType()); + + VectorType subResTy = VectorType::get(resultLanes, resTy.getElementType()); + + acc = shapeCast(loc, acc, fullResTy, rewriter); + + Type inElemTy = matInputTy.getElementType(); + // Integer type for a pair of bf16 elements + Type pairTy = IntegerType::get(ctx, 32); + + vecInput = + shapeCast(loc, vecInput, {numOfVecRegs, resultLanes, 2}, rewriter); + // We bitcast here because we are pulling pairs of bf16 each time. + vecInput = rewriter.create( + loc, VectorType::get({numOfVecRegs, resultLanes, 1}, pairTy), vecInput); + vecInput = shapeCast(loc, vecInput, {numOfVecRegs, resultLanes}, rewriter); + + matInput = shapeCast(loc, matInput, {numOfOutputChannels, numOfVecPairs, 2}, + rewriter); + // We bitcast here because we are pulling pairs of bf16 each time. + matInput = rewriter.create( + loc, VectorType::get({numOfOutputChannels, numOfVecPairs, 1}, pairTy), + matInput); + matInput = shapeCast(loc, matInput, {numOfOutputChannels, numOfVecPairs}, + rewriter); + // Packing/Transposing the weight matrix so that + // the output channel is continuous + matInput = rewriter.create( + loc, matInput, SmallVector{1, 0}); + matInput = shapeCast( + loc, matInput, + {numOfVecRegs, resultLanes, numOfOutputRegs, resultLanes}, rewriter); + + Value res = rewriter.create( + loc, fullResTy, rewriter.getZeroAttr(fullResTy)); + SmallVector resultTypes = {subResTy}; + // TODO: this intrinsic is hard-coded for Arm Neon + auto bfdot = StringAttr::get(ctx, "llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + SmallVector args; + + SmallVector subRes(numOfOutputRegs); + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + subRes[outIdx] = rewriter.create(loc, acc, outIdx); + } + for (int64_t idx = 0; idx < numOfVecRegs; idx += 1) { + Value fullVec = rewriter.create(loc, vecInput, idx); + for (int64_t vecIdx = 0; vecIdx < resultLanes; vecIdx += 1) { + // shuffle mask used to broadcast the 'vecIdx'th lane of fullVec + SmallVector shuffleMask(resultLanes, vecIdx); + // Broadcasting the 'vecIdx'th lane of fullVec + Value subVec = rewriter.create(loc, fullVec, fullVec, + shuffleMask); + subVec = rewriter.create( + loc, VectorType::get({lanes}, inElemTy), subVec); + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + Value subMat = rewriter.create( + loc, matInput, SmallVector{idx, vecIdx, outIdx}); + subMat = rewriter.create( + loc, VectorType::get({lanes}, inElemTy), subMat); + args = {subRes[outIdx], subMat, subVec}; + // bfdot instruction: + // https://developer.arm.com/documentation/ddi0602/2024-06/SIMD-FP-Instructions/BFDOT--vector---BFloat16-floating-point-dot-product--vector-- + // LLVM fast math flags: + // https://llvm.org/docs/LangRef.html#fast-math-flags + // This bfdot intrinsic will perform an unfused sum-of-products of + // each pair of adjacent bf16 elements in the source vectors + // (8 bf16), and output 4 fp32 elements. + auto callIntrOp = rewriter.create( + loc, resultTypes, bfdot, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); + subRes[outIdx] = callIntrOp.getResult(0); + } + } + } + + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + res = rewriter.create(loc, subRes[outIdx], res, outIdx); + } + + res = shapeCast(loc, res, resTy, rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertDotProduct + : public triton::cpu::impl::ConvertDotProductBase { + ConvertDotProduct() = default; + ConvertDotProduct(bool useHorizontalSum) { + this->useHorizontalSum = useHorizontalSum; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + + if (useHorizontalSum) { + patterns.add(context); + } else { + patterns.add(context); + } + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotProduct() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertDotProduct(bool useHorizontalSum) { + return std::make_unique(useHorizontalSum); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp new file mode 100644 index 000000000000..06ad1f1f6802 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -0,0 +1,464 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct ConvertBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + // TODO: support mixed-type ops? + if (!isAllBf16(op->getOperandTypes()) || !isAllBf16(op->getResultTypes())) + return failure(); + + Location loc = op.getLoc(); + OperationState newState(loc, OpT::getOperationName()); + // Convert operands to fp32 and generate fp32 op. + for (auto operand : op->getOperands()) { + Value newOperand = rewriter.create( + loc, toFp32(operand.getType()), operand); + newState.operands.push_back(newOperand); + } + newState.types = toFp32(op->getResultTypes()); + newState.attributes = op->getAttrs(); + auto newOp = rewriter.create(newState); + + // Convert op results back to Bf16 + SmallVector results; + for (auto res : llvm::enumerate(newOp->getResults())) + results.push_back(rewriter.create( + loc, op->getResult(res.index()).getType(), res.value())); + rewriter.replaceOp(op, results); + + return success(); + } + + bool isAllBf16(TypeRange types) const { + return std::all_of(types.begin(), types.end(), + [this](auto ty) { return isBf16(ty); }); + } + + SmallVector toFp32(TypeRange types) const { + SmallVector res; + for (auto ty : types) + res.push_back(::toFp32(ty)); + return res; + } +}; + +template +struct ConvertIToBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value fp32Val = + rewriter.create(loc, toFp32(op.getType()), op.getOperand()); + Value res = rewriter.create(loc, op.getType(), fp32Val); + rewriter.replaceOp(op, res); + return success(); + } +}; + +Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) { + MemRefType memRefTy = cast(memRef.getType()); + if (memRefTy.getElementType().isInteger()) + return memRef; + + Value res; + MemRefType newMemRefTy = + MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(), + memRefTy.getLayout(), memRefTy.getMemorySpace()); + auto insPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(memRef.getDefiningOp()); + // Memory references for masked operations and transfers are always built + // with PtrToMemRefOp, ExtractMemRefOp, or memref::AllocaOp. + if (auto castOp = memRef.getDefiningOp()) { + res = rewriter.create(memRef.getLoc(), newMemRefTy, + castOp.getSrc()); + } else if (auto extractOp = memRef.getDefiningOp()) { + res = rewriter.create(memRef.getLoc(), newMemRefTy, + extractOp.getSrc()); + } else { + auto allocaOp = memRef.getDefiningOp(); + assert(allocaOp && "Unexpected memref producer"); + res = rewriter.create(allocaOp.getLoc(), newMemRefTy, + allocaOp.getAlignmentAttr()); + rewriter.replaceOp(allocaOp, res); + } + rewriter.restoreInsertionPoint(insPoint); + return res; +} + +struct ConvertBf16MaskedLoadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedLoadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value newPassThru = rewriter.create( + loc, toInt16(op.getPassThru().getType()), op.getPassThru()); + Value intVal = rewriter.create( + loc, toInt16(op.getType()), newBase, op.getIndices(), op.getMask(), + newPassThru); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16MaskedStoreOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedStoreOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getValueToStore().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getValueToStore().getType()), op.getValueToStore()); + rewriter.replaceOpWithNewOp( + op, newBase, op.getIndices(), op.getMask(), intVal); + return success(); + } +}; + +struct ConvertBf16TransferReadOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newSource = convertMemRefToI16(op.getSource(), rewriter); + Value newPadding = + op.getPadding() + ? rewriter.create( + loc, toInt16(op.getPadding().getType()), op.getPadding()) + : nullptr; + Value intVal = rewriter.create( + loc, toInt16(op.getType()), newSource, op.getIndices(), + op.getPermutationMapAttr(), newPadding, op.getMask(), + op.getInBoundsAttr()); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16TransferWriteOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getVector().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newSource = convertMemRefToI16(op.getSource(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getVector().getType()), op.getVector()); + rewriter.replaceOpWithNewOp( + op, intVal, newSource, op.getIndices(), op.getPermutationMapAttr(), + op.getMask(), op.getInBoundsAttr()); + return success(); + } +}; + +struct ConvertBf16LoadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::LoadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newMemRef = convertMemRefToI16(op.getMemRef(), rewriter); + Value intVal = + rewriter.create(loc, newMemRef, op.getIndices()); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16StoreOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::StoreOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getValue().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newMemRef = convertMemRefToI16(op.getMemRef(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getValue().getType()), op.getValue()); + rewriter.replaceOpWithNewOp(op, intVal, newMemRef, + op.getIndices()); + return success(); + } +}; + +struct ConvertBf16Abs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType()) || !isBf16(op.getOperand().getType())) + return failure(); + + Location loc = op.getLoc(); + Value src = op.getOperand(); + Value intSrc = + rewriter.create(loc, toInt16(op.getType()), src); + TypedAttr maskAttr = rewriter.getI16IntegerAttr(0x7fff); + if (auto vecTy = dyn_cast(intSrc.getType())) + maskAttr = SplatElementsAttr::get(vecTy, maskAttr); + Value mask = rewriter.create(loc, maskAttr); + Value res = rewriter.create(loc, intSrc, mask); + res = rewriter.create(loc, op.getType(), res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertF8Abs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + if (!isFp8(op.getType()) || !isFp8(op.getOperand().getType())) + return failure(); + + Location loc = op.getLoc(); + Value src = op.getOperand(); + Type srcType = op.getType(); + + Value i8Src = op_bitcast(toInt8(srcType), src); + // Mask out the sign bit + Value nosign = op_and(i8Src, cst_like(i8Src, 0x7f)); + Value res = op_bitcast(srcType, nosign); + + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertMixedPrecisionMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value acc = op.getAcc(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); + auto accTy = cast(acc.getType()); + auto resTy = cast(op.getType()); + + if (lhsTy.getElementType() == resTy.getElementType() && + rhsTy.getElementType() == resTy.getElementType() && + accTy.getElementType() == resTy.getElementType()) + return failure(); + + Type commonElemTy = resTy.getElementType(); + if (lhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = lhsTy; + if (rhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = rhsTy; + if (accTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = accTy; + + lhs = castElemTy(loc, lhs, commonElemTy, rewriter); + rhs = castElemTy(loc, rhs, commonElemTy, rewriter); + acc = castElemTy(loc, acc, commonElemTy, rewriter); + + Value newRes = rewriter.create( + loc, lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes()); + newRes = castElemTy(loc, newRes, resTy.getElementType(), rewriter); + + rewriter.replaceOp(op, newRes); + return success(); + } + + Value castElemTy(Location loc, Value val, Type elemTy, + PatternRewriter &rewriter) const { + auto valTy = cast(val.getType()); + if (valTy.getElementType() == elemTy) + return val; + + auto resTy = toTyOrVectorOf(valTy, elemTy); + if (valTy.getElementType().isInteger()) { + if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth()) + return rewriter.create(loc, resTy, val); + else + return rewriter.create(loc, resTy, val); + } else { + if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth()) + return rewriter.create(loc, resTy, val); + else + return rewriter.create(loc, resTy, val); + } + } +}; + +template struct PromoteOpToFp32 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PromoteOpToFp32(MLIRContext *context) : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Type opTy = op.getType(); + + if (!isFp8(opTy) && !isFp16(opTy) && !isBf16(opTy)) + return failure(); + + Type fp32Ty = toFp32(opTy); + SmallVector fp32Ops; + for (auto operand : op->getOperands()) + fp32Ops.push_back(rewriter.create(loc, fp32Ty, operand)); + auto newOp = rewriter.create(loc, fp32Ty, fp32Ops); + rewriter.replaceOpWithNewOp(op, opTy, newOp); + return success(); + } +}; + +struct ConvertUnsupportedOps + : public triton::cpu::impl::ConvertUnsupportedOpsBase< + ConvertUnsupportedOps> { + ConvertUnsupportedOps() = default; + + ConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32) { + this->promoteBf16ToFp32 = promoteBf16ToFp32; + this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul; + this->promoteLibMathToFp32 = promoteLibMathToFp32; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + if (promoteBf16ToFp32) { + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + } + patterns.add(context); + if (convertMixedPrecisionMatmul) { + patterns.add(context); + } + if (promoteLibMathToFp32) { + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + } + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertUnsupportedOps() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32) { + return std::make_unique( + promoteBf16ToFp32, convertMixedPrecisionMatmul, promoteLibMathToFp32); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp new file mode 100644 index 000000000000..4a4c8bd8e448 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -0,0 +1,547 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_DECOMPOSEFPCONVERSIONS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +struct Fp32ToBf16Conversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override { + Value src = op.getIn(); + if (!isBf16(op.getType()) || !isFp32(src.getType())) + return failure(); + + Location loc = op.getLoc(); + Value i32Src = op_bitcast(toInt32(src.getType()), src); + Value shiftedSrc = op_lshr(i32Src, cst_like(i32Src, 16)); + Value i16Res = op_trunci(toInt16(src.getType()), shiftedSrc); + Value res = op_bitcast(op.getType(), i16Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct Bf16ToFp32Conversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const override { + Value src = op.getIn(); + if (!isFp32(op.getType()) || !isBf16(src.getType())) + return failure(); + + Location loc = op.getLoc(); + Value i16Src = op_bitcast(toInt16(src.getType()), src); + Value i32Src = op_zext(toInt32(src.getType()), i16Src); + Value i32Res = op_shl(i32Src, cst_like(i32Src, 16)); + Value res = op_bitcast(op.getType(), i32Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +typedef std::function FpToFpConvFn; + +// Convert FP8 to FP16/FP32. +Value convertFp8(Location loc, Value src, int srcExpBits, int srcExpBias, + Type dstFpTy, PatternRewriter &rewriter) { + assert(srcExpBits >= 4 && srcExpBits <= 5 && "Unexpect FP8 type conversion"); + assert(srcExpBias >= 0 && srcExpBias <= 16 && "Unexpect FP8 type conversion"); + assert((dstFpTy.isF16() || dstFpTy.isF32()) && + "Unsupported FP8 type conversion"); + Type srcTy = src.getType(); + Type dstTy = toTyOrVectorOf(srcTy, dstFpTy); + int dstExpBits = dstFpTy.isF16() ? 5 : 8; + int dstMantBits = dstFpTy.isF16() ? 10 : 23; + int dstExpBias = dstFpTy.isF16() ? 15 : 127; + int srcMantBits = 7 - srcExpBits; + assert(dstExpBias >= srcExpBias && "Unsupported FP8 type conversion"); + Type dstIntTy = + dstFpTy.isF16() ? rewriter.getI16Type() : rewriter.getI32Type(); + Value i8Src = op_bitcast(toInt8(srcTy), src); + Value intSrc = op_zext(toTyOrVectorOf(srcTy, dstIntTy), i8Src); + Value shiftedVal; + if (srcExpBits != dstExpBits) { + Value sign = op_and(intSrc, cst_like(intSrc, 0x80)); + Value nosign = op_and(intSrc, cst_like(intSrc, 0x7f)); + shiftedVal = op_addi( + op_shl(sign, cst_like(sign, dstFpTy.getIntOrFloatBitWidth() - 8)), + op_shl(nosign, cst_like(nosign, dstMantBits - srcMantBits))); + } else { + shiftedVal = + op_shl(intSrc, cst_like(intSrc, dstFpTy.getIntOrFloatBitWidth() - 8)); + } + Value res = op_bitcast(dstTy, shiftedVal); + if (srcExpBias != dstExpBias) { + double scale = pow(2, dstExpBias - srcExpBias); + res = op_mulf(res, cst_like(res, scale)); + } + return res; +} + +Value convertFp8E4M3ToFp16(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 4, 7, rewriter.getF16Type(), rewriter); +} + +Value convertFp8E5M2ToFp16(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 15, rewriter.getF16Type(), rewriter); +} + +Value convertFp8E5M2B16ToFp16(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toFp16(src.getType()), f32Res); +} + +Value convertFp8E4M3ToBf16(Location loc, Value src, PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 4, 7, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E5M2ToBf16(Location loc, Value src, PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 15, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E5M2B16ToBf16(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E4M3ToFp32(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 4, 7, rewriter.getF32Type(), rewriter); +} + +Value convertFp8E5M2ToFp32(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 15, rewriter.getF32Type(), rewriter); +} + +Value convertFp8E5M2B16ToFp32(Location loc, Value src, + PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); +} + +// Convert F16/FP32 to FP8. +Value convertToFp8(Location loc, Value src, Type dstFpTy, int dstExpBits, + int dstExpBias, bool rtneRounding, bool unsignedZero, + PatternRewriter &rewriter) { + assert(dstExpBits >= 4 && dstExpBits <= 5 && "Unexpect FP8 type conversion"); + assert(dstExpBias >= 0 && dstExpBias <= 16 && "Unexpect FP8 type conversion"); + Type srcTy = src.getType(); + Type srcFpTy = getElemTyOrTy(srcTy); + assert((srcFpTy.isF16() || srcFpTy.isF32()) && + "Unsupported FP8 type conversion"); + int dstMantBits = 7 - dstExpBits; + int srcExpBits = srcFpTy.isF16() ? 5 : 8; + int srcMantBits = srcFpTy.isF16() ? 10 : 23; + int srcExpBias = srcFpTy.isF16() ? 15 : 127; + assert(dstExpBias <= srcExpBias && "Unsupported FP8 type conversion"); + Type srcIntTy = + srcFpTy.isF16() ? rewriter.getI16Type() : rewriter.getI32Type(); + Value intSrc = op_bitcast(toTyOrVectorOf(srcTy, srcIntTy), src); + // Extract sign and put it to the proper place for FP8. + Value sign = + op_lshr(op_and(intSrc, cst_like(intSrc, 1 << (srcExpBits + srcMantBits))), + cst_like(intSrc, srcFpTy.getIntOrFloatBitWidth() - 8)); + // Extract mantissa and exponent. + Value mant = op_and(intSrc, cst_like(intSrc, (1 << srcMantBits) - 1)); + Value exp = op_and(op_lshr(intSrc, cst_like(intSrc, srcMantBits)), + cst_like(intSrc, (1 << srcExpBits) - 1)); + Value isZeroExp = op_icmp_eq(exp, cst_like(exp, 0)); + mant = op_select(isZeroExp, mant, + op_addi(mant, cst_like(mant, 1 << srcMantBits))); + exp = op_select(isZeroExp, exp, op_subi(exp, cst_like(exp, 1))); + double adjustment = pow(0.5, srcMantBits - dstMantBits); + exp = op_subi(exp, cst_like(exp, srcExpBias - dstExpBias)); + mant = op_mulf(op_sitofp(srcTy, mant), cst_like(src, adjustment)); + // Make exponent non-negative. + if (dstExpBias - srcExpBias <= -8) { + // In this case we don't have enough mantissa bits, so can round to 0. + Value mask = op_icmp_sgt(exp, cst_like(exp, -8)); + exp = op_select(mask, exp, cst_like(exp, 0)); + mant = op_select(mask, mant, cst_like(mant, 0.0)); + } + if (dstExpBias - srcExpBias <= -4) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -4)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 4))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.0625))); + } + if (dstExpBias - srcExpBias <= -2) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -2)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 2))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.25))); + } + if (dstExpBias - srcExpBias <= -1) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -1)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 1))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.5))); + } + if (rtneRounding) { + // Bring the value to the range [2 ** 10/23, 2 ** 11/24] + // where the representable fp16/fp32 map exactly to integers. + // Addition has RTNE semantics. + Value offs = cst_like(mant, static_cast(1 << srcMantBits)); + mant = op_addf(mant, offs); + mant = op_subf(mant, offs); + } + mant = op_fptosi(toTyOrVectorOf(srcTy, srcIntTy), mant); + + Value res = + op_addi(sign, op_addi(op_shl(exp, cst_like(exp, 7 - dstExpBits)), mant)); + res = op_trunci(toInt8(srcTy), res); + if (unsignedZero) { + Value isNegativeZero = op_icmp_eq(res, cst_like(res, 0x80)); + res = op_select(isNegativeZero, cst_like(res, 0), res); + } + res = op_bitcast(toTyOrVectorOf(srcTy, dstFpTy), res); + return res; +} + +Value convertFp16ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + false, rewriter); +} + +Value convertFp16ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + false, rewriter); +} + +Value convertFp16ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type dstTy = toFp8E5M2(srcTy); + Value i16Src = op_bitcast(toInt16(srcTy), src); + Value shiftedSrc = op_lshr(i16Src, cst_like(i16Src, 8)); + Value i8Res = op_trunci(toInt8(srcTy), shiftedSrc); + Value res = op_bitcast(dstTy, i8Res); + return res; +} + +Value convertFp16ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type dstTy = toFp8E5M2(srcTy); + Value i16Src = op_bitcast(toInt16(srcTy), src); + Value sign = op_and(i16Src, cst_like(i16Src, 0x8000)); + Value truncated = op_and(i16Src, cst_like(i16Src, 0x7f00)); + Value tail = op_and(i16Src, cst_like(i16Src, 0xff)); + Value odd_trunc = op_icmp_ne(op_and(truncated, cst_like(truncated, 0x100)), + cst_like(truncated, 0)); + Value round_up = + op_or(op_icmp_ugt(tail, cst_like(tail, 0x80)), + op_and(op_icmp_eq(tail, cst_like(tail, 0x80)), odd_trunc)); + // Skip round-up if it leads to inf/nan. + round_up = + op_and(round_up, op_icmp_ult(truncated, cst_like(truncated, 0x7b00))); + truncated = op_select( + round_up, op_addi(truncated, cst_like(truncated, 0x100)), truncated); + + Value res = op_lshr(op_or(truncated, sign), cst_like(truncated, 8)); + res = op_bitcast(dstTy, op_trunci(toInt8(srcTy), res)); + return res; +} + +Value convertFp16ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + false, true, rewriter); +} + +Value convertFp16ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + true, true, rewriter); +} + +Value convertBf16ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + false, rewriter); +} + +Value convertBf16ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + false, rewriter); +} + +Value convertBf16ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, false, + false, rewriter); +} + +Value convertBf16ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, true, + false, rewriter); +} + +Value convertBf16ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + false, true, rewriter); +} + +Value convertBf16ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + true, true, rewriter); +} + +Value convertFp32ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + false, rewriter); +} + +Value convertFp32ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + false, rewriter); +} + +Value convertFp32ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, false, + false, rewriter); +} + +Value convertFp32ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, true, + false, rewriter); +} + +Value convertFp32ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, false, + true, rewriter); +} + +Value convertFp32ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, true, + true, rewriter); +} + +FpToFpConvFn +getFpToFpConversionFn(Type srcTy, Type dstTy, + std::optional roundMode) { + auto F8E4M3TyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F8E5M2B16TyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + + static DenseMap, FpToFpConvFn> fpExtFnMap = { + {{F8E4M3TyID, F16TyID}, convertFp8E4M3ToFp16}, + {{F8E5M2TyID, F16TyID}, convertFp8E5M2ToFp16}, + {{F8E5M2B16TyID, F16TyID}, convertFp8E5M2B16ToFp16}, + {{F8E4M3TyID, BF16TyID}, convertFp8E4M3ToBf16}, + {{F8E5M2TyID, BF16TyID}, convertFp8E5M2ToBf16}, + {{F8E5M2B16TyID, BF16TyID}, convertFp8E5M2B16ToBf16}, + {{F8E4M3TyID, F32TyID}, convertFp8E4M3ToFp32}, + {{F8E5M2TyID, F32TyID}, convertFp8E5M2ToFp32}, + {{F8E5M2B16TyID, F32TyID}, convertFp8E5M2B16ToFp32}, + }; + static DenseMap, FpToFpConvFn> + fpTruncFnMap = { + {{F16TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E4M3Rtz}, + {{F16TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E4M3Rtne}, + {{F16TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E5M2Rtz}, + {{F16TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E5M2Rtne}, + {{F16TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E5M2B16Rtz}, + {{F16TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E5M2B16Rtne}, + {{BF16TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E4M3Rtz}, + {{BF16TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E4M3Rtne}, + {{BF16TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E5M2Rtz}, + {{BF16TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E5M2Rtne}, + {{BF16TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E5M2B16Rtz}, + {{BF16TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E5M2B16Rtne}, + {{F32TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E4M3Rtz}, + {{F32TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E4M3Rtne}, + {{F32TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E5M2Rtz}, + {{F32TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E5M2Rtne}, + {{F32TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E5M2B16Rtz}, + {{F32TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E5M2B16Rtne}, + }; + + if (roundMode) { + auto key = + std::make_tuple(srcTy.getTypeID(), dstTy.getTypeID(), *roundMode); + if (fpTruncFnMap.count(key)) + return fpTruncFnMap.at(key); + } else { + auto key = std::make_tuple(srcTy.getTypeID(), dstTy.getTypeID()); + if (fpExtFnMap.count(key)) + return fpExtFnMap.at(key); + } + + return FpToFpConvFn(); +} + +Value convertFpToFp(Location loc, Value src, Type dstTy, + std::optional roundMode, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type srcElemTy = getElemTyOrTy(srcTy); + Type dstElemTy = getElemTyOrTy(dstTy); + auto fn = getFpToFpConversionFn(srcElemTy, dstElemTy, roundMode); + if (!fn) { + llvm::errs() << "Unsupported conversion from " << srcElemTy << " to " + << dstElemTy; + if (roundMode) + llvm::errs() << " with rounding mode " + << arith::stringifyRoundingMode(*roundMode); + llvm::errs() << "\n"; + llvm_unreachable(""); + } + return fn(loc, src, rewriter); +} + +struct RewriteTruncFp8 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value src = op.getIn(); + Type srcTy = src.getType(); + Type dstTy = op.getType(); + if (!isFp8(dstTy)) + return failure(); + Value res = convertFpToFp(loc, src, dstTy, op.getRoundingmode(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct RewriteExtFp8 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value src = op.getIn(); + Type srcTy = src.getType(); + if (!isFp8(srcTy)) + return failure(); + Type dstTy = op.getType(); + Value res = convertFpToFp(loc, src, dstTy, std::nullopt, rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct DecomposeFpConversions + : public triton::cpu::impl::DecomposeFpConversionsBase< + DecomposeFpConversions> { + DecomposeFpConversions() = default; + + DecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + this->decomposeBf16Conversions = decomposeBf16Conversions; + this->decomposeFp8Conversions = decomposeFp8Conversions; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + if (decomposeBf16Conversions) { + patterns.add(context); + patterns.add(context); + } + if (decomposeFp8Conversions) { + patterns.add(context); + patterns.add(context); + } + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createDecomposeFpConversions() { + return std::make_unique(); +} + +std::unique_ptr> +createDecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + return std::make_unique(decomposeBf16Conversions, + decomposeFp8Conversions); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp new file mode 100644 index 000000000000..e747ef16c957 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -0,0 +1,374 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_OPTIMIZEMASKS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +int64_t getDivisibility(Value val) { + BlockArgument blockArg = dyn_cast(val); + if (!blockArg) + return 1; + + Operation *argOp = blockArg.getOwner()->getParentOp(); + if (auto fn = dyn_cast(argOp)) { + Attribute attr = fn.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); + if (auto iattr = dyn_cast_or_null(attr)) { + return iattr.getInt(); + } + } + + return 1; +} + +bool isAlwaysDivisible(Value val, int64_t divisor) { + if (auto cst = val.getDefiningOp()) { + auto intAttr = dyn_cast(cst.getValue()); + return intAttr && (intAttr.getInt() % divisor == 0); + } + return getDivisibility(val) % divisor == 0; +} + +bool isAlwaysDivisible(Value val, Value divisor) { + if (auto cst = divisor.getDefiningOp()) { + auto intAttr = dyn_cast(cst.getValue()); + if (intAttr) + return isAlwaysDivisible(val, intAttr.getInt()); + } + return false; +} + +// Optimize cdiv pattern using divisibility hints. If value is known to be +// divisible by N then we can transform +// (val + K - 1) / K +// to +// val / K +// if N % K == 0 and val is not negative. Usually, we cannot prove value to be +// non-negative but still can apply transformation for contexts that assume +// positive value (e.g. as an upper bound in a for-loop with non-negative +// lower bound). +struct CdivToDiv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Looking for a scalar op only. + if (isa(op.getType())) + return failure(); + + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + auto addOpDef = lhs.getDefiningOp(); + auto divisorDef = rhs.getDefiningOp(); + if (!addOpDef || !divisorDef) + return failure(); + + arith::ConstantOp addCstDef; + Value addOtherVal; + if ((addCstDef = addOpDef.getLhs().getDefiningOp())) + addOtherVal = addOpDef.getRhs(); + else if ((addCstDef = addOpDef.getRhs().getDefiningOp())) + addOtherVal = addOpDef.getLhs(); + else + return failure(); + + int64_t divisorCst = cast(divisorDef.getValue()).getInt(); + int64_t addCst = cast(addCstDef.getValue()).getInt(); + if (divisorCst <= addCst) + return failure(); + + if (!isAlwaysDivisible(addOtherVal, divisorCst)) + return failure(); + + Value res = op.getResult(); + Value newRes = + rewriter.create(loc, addOtherVal, divisorDef); + int replaced = 0; + rewriter.replaceUsesWithIf(res, newRes, [&](OpOperand &use) { + if (auto forOp = dyn_cast(use.getOwner())) { + auto lowerDef = + forOp.getLowerBound().getDefiningOp(); + if (lowerDef && use.getOperandNumber() == 1 && + cast(lowerDef.getValue()).getInt() >= 0) { + ++replaced; + return true; + } + } + return false; + }); + + if (!replaced) { + rewriter.eraseOp(newRes.getDefiningOp()); + return failure(); + } + + return success(); + } +}; + +// This pattern rewrites for-loops used for tiling to optimize out division +// and multiplication using divisibility hints. +// Typical tiled loop looks like: +// for i in range(0, tl.cdiv(size, TILE_SIZE)): +// offs = i * TILE_SIZE +// ... +// If size is known to be divisible by TILE_SIZE then it can be written as: +// for offs in range(0, size, TILE_SIZE): +// ... +// This pattern is used after an attempt to replace cdiv with a regular +// division. Possible input pattern is: +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c16 = arith.constant 16 : index +// %init = arith.constant dense<0x00000000> : vector<16xf32> +// %1 = arith.divsi %arg4, %c16 +// %2 = scf.for %arg5 = %c0 to %1 step %c1 iter_args(%arg6 = %init) -> +// (vector<16xf32>) : i32 { +// %3 = arith.muli %arg5, %c16 : i32 +// ... +// } +// where %arg4 is known to be divisible by 16. The resulting code would be: +// %c0 = arith.constant 0 : index +// %c16 = arith.constant 16 : index +// %init = arith.constant dense<0x00000000> : vector<16xf32> +// %2 = scf.for %arg5 = %c0 to %arg4 step %c16 iter_args(%arg6 = %init) -> +// (vector<16xf32>) : i32 { +// ... +// } +// This removes division and simplifies the following analysis to optimize +// masked memory acccess for the tile. +struct ScaleInductionVariable : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value iv = op.getInductionVar(); + Value lower = op.getLowerBound(); + Value upper = op.getUpperBound(); + Value step = op.getStep(); + auto lowerDef = lower.getDefiningOp(); + auto upperDef = upper.getDefiningOp(); + if (!lowerDef || !upperDef) + return failure(); + + int64_t lowerVal = cast(lowerDef.getValue()).getInt(); + if (lowerVal < 0) + return failure(); + + // TODO: This is a strong requirement. With more generic value range + // analysis we should be able to not rely on this transformation. + if (!iv.hasOneUse()) + return failure(); + + auto ivUse = dyn_cast(*iv.getUsers().begin()); + if (!ivUse) + return failure(); + + Value scale = ivUse.getLhs() == iv ? ivUse.getRhs() : ivUse.getLhs(); + auto scaleDef = scale.getDefiningOp(); + auto divRhsDef = upperDef.getRhs().getDefiningOp(); + auto divLhs = upperDef.getLhs(); + if (!scaleDef || !divRhsDef) + return failure(); + + int64_t scaleVal = cast(scaleDef.getValue()).getInt(); + int64_t divisorVal = cast(divRhsDef.getValue()).getInt(); + if (scaleVal != divisorVal || !isAlwaysDivisible(divLhs, scaleVal) || + lowerVal % scaleVal != 0) + return failure(); + + // Build new lower bound. + Value newLower = lower; + if (lowerVal != 0) { + rewriter.setInsertionPointAfterValue(lower); + newLower = rewriter.create( + lower.getLoc(), lowerVal * scaleVal, lower.getType()); + } + // New Upper bound. + Value newUpper = divLhs; + // Build new step. + rewriter.setInsertionPoint(op); + auto newStep = rewriter.create(ivUse.getLoc(), step, scale); + + // Modify ForOp. + rewriter.startOpModification(op); + op.setLowerBound(newLower); + op.setUpperBound(newUpper); + op.setStep(newStep); + rewriter.finalizeOpModification(op); + + // Replace iv uses. + rewriter.replaceAllUsesWith(ivUse, iv); + + return success(); + } +}; + +// Build affine expression to express min/max value of the given SSA name. +// symbolTable is used to map SSA names to affine symbols. +AffineExpr buildMinOrMaxExpr(Value val, bool isSigned, bool isMax, + llvm::DenseMap &symbolTable) { + if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getInput(), isSigned, isMax, symbolTable); + } else if (auto def = val.getDefiningOp()) { + auto attr = def.getValueAttr(); + if (auto intAttr = dyn_cast(attr)) + return getAffineConstantExpr(intAttr.getInt(), val.getContext()); + if (auto denseAttr = dyn_cast(attr)) { + auto valueBegin = denseAttr.value_begin(); + auto valueEnd = denseAttr.value_end(); + auto cmpVals = [isSigned](const APInt &lhs, const APInt &rhs) { + return isSigned ? lhs.slt(rhs) : lhs.ult(rhs); + }; + auto valueIt = isMax ? std::max_element(valueBegin, valueEnd, cmpVals) + : std::min_element(valueBegin, valueEnd, cmpVals); + return getAffineConstantExpr((*valueIt).getSExtValue(), val.getContext()); + } + } else if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getLhs(), isSigned, isMax, symbolTable) + + buildMinOrMaxExpr(def.getRhs(), isSigned, isMax, symbolTable); + } else if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getLhs(), isSigned, isMax, symbolTable) - + buildMinOrMaxExpr(def.getRhs(), isSigned, !isMax, symbolTable); + } else if (auto blockArg = dyn_cast(val)) { + auto op = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(op)) { + if (val == forOp.getInductionVar()) { + Value lower = forOp.getLowerBound(); + Value upper = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // For min value return lower bound. + if (!isMax) + return buildMinOrMaxExpr(forOp.getLowerBound(), isSigned, isMax, + symbolTable); + + // For max value we use upper bound - 1 in generic case and bound - step + // if both bounds are divisible by the step. + if (isAlwaysDivisible(lower, step) && isAlwaysDivisible(upper, step)) { + return buildMinOrMaxExpr(upper, isSigned, isMax, symbolTable) - + buildMinOrMaxExpr(step, isSigned, false, symbolTable); + } + return buildMinOrMaxExpr(upper, isSigned, isMax, symbolTable) - + getAffineConstantExpr(1, val.getContext()); + } + } + } + + if (symbolTable.count(val)) + return getAffineSymbolExpr(symbolTable.at(val), val.getContext()); + + unsigned pos = symbolTable.size(); + symbolTable.insert(std::make_pair(val, pos)); + return getAffineSymbolExpr(pos, val.getContext()); +} + +// Check if vector mask is all-ones by checking compared values ranges. +// Only simplest cases are covered here, so affine expression is used +// to represent a range for now. +bool isAlwaysAllOnes(arith::CmpIOp maskDef) { + auto pred = maskDef.getPredicate(); + if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) + return false; + + bool isSigned = + pred == arith::CmpIPredicate::sgt || pred == arith::CmpIPredicate::sge || + pred == arith::CmpIPredicate::sle || pred == arith::CmpIPredicate::slt; + llvm::DenseMap symbolTable; + AffineExpr maxOffs; + AffineExpr minLen; + if (pred == arith::CmpIPredicate::slt || pred == arith::CmpIPredicate::sle || + pred == arith::CmpIPredicate::ult || pred == arith::CmpIPredicate::ule) { + maxOffs = buildMinOrMaxExpr(maskDef.getLhs(), isSigned, true, symbolTable); + minLen = buildMinOrMaxExpr(maskDef.getRhs(), isSigned, false, symbolTable); + } else { + maxOffs = buildMinOrMaxExpr(maskDef.getRhs(), isSigned, true, symbolTable); + minLen = buildMinOrMaxExpr(maskDef.getLhs(), isSigned, false, symbolTable); + } + + // The mask is all-ones if max offset is always less than min length. + auto diff = maxOffs - minLen; + if (auto diffCst = dyn_cast(diff)) { + int64_t diffVal = diffCst.getValue(); + if (pred == arith::CmpIPredicate::slt || + pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::sgt || pred == arith::CmpIPredicate::ugt) + return diffVal < 0; + else + return diffVal <= 0; + } + + return false; +} + +struct OptimizeMask : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::CmpIOp op, + PatternRewriter &rewriter) const override { + if (!isAlwaysAllOnes(op)) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), rewriter.getOneAttr(op.getType())); + return success(); + } +}; + +struct OptimizeMasks + : public triton::cpu::impl::OptimizeMasksBase { + OptimizeMasks() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + // TODO: This pass optimizes out masks applying a set of very strict + // patterns. We should use more generic range and divisibility analysis + // to cover more cases and remove dependency on other transformations. + RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + + // TODO: if masks removal failed for loads/stores in a for-loop, we might + // still optimize it using loop peeling. + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createOptimizeMasks() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonRaiseBlockPointer/CMakeLists.txt b/third_party/cpu/lib/TritonRaiseBlockPointer/CMakeLists.txt new file mode 100644 index 000000000000..f9d88d6d2007 --- /dev/null +++ b/third_party/cpu/lib/TritonRaiseBlockPointer/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonRaiseBlockPointer + TritonRaiseBlockPointer.cpp + + DEPENDS + TritonRaiseBlockPointerPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR +) diff --git a/third_party/cpu/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/cpu/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp new file mode 100644 index 000000000000..71e83155e969 --- /dev/null +++ b/third_party/cpu/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -0,0 +1,1446 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/TritonRaiseBlockPointer/Passes.h" + +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "triton-cpu-raise-block-pointer" + +// This pass does manage to raise tensor of pointers into block pointers for +// simple cases (e.g. 03 matmul tutorial). However, this pass has several know +// limitations: +// - Masks and modulos are not correctly handled by this pass. Issue #1784 +// (https://github.com/intel/intel-xpu-backend-for-triton/issues/1784) has +// been created to address this limitation. +// - The pattern matching method used in this pass makes it prone to fail +// raising memory accesses. For the moment, the most fragile part of the pass +// is probably the support for fixing the axis of the offsets +// (see comment l.867). + +using namespace mlir; + +namespace mlir::triton::cpu { +#define GEN_PASS_DEF_TRITONRAISEBLOCKPOINTER +#include "cpu/include/TritonRaiseBlockPointer/Passes.h.inc" +} // namespace mlir::triton::cpu + +namespace { +constexpr unsigned offsetBitwidth = 32; +constexpr unsigned shapeAndStridesBitwidth = 64; + +// FROM intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +static std::optional getIntAttr(const OpFoldResult ofr) { + if (ofr.is() && isa(ofr.get())) + return cast(ofr.get()).getInt(); + return std::nullopt; +} + +// FROM intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +std::optional getFoldedConstantValue(Operation *op) { + SmallVector results; + if (failed(op->fold(results))) { + return std::nullopt; + } + + // If fold succeeded but `results` is empty, we give a second try, after the + // operands have been switched during the first call to `fold()`. + if (results.empty()) { + if (failed(op->fold(results))) { + return std::nullopt; + } + } + + if (results.size() != 1) { + return std::nullopt; + } + + auto intAttr = getIntAttr(results[0]); + if (intAttr.has_value()) { + return intAttr.value(); + } + + auto val = cast(results[0]); + auto constOp = val.getDefiningOp(); + if (!constOp) + return std::nullopt; + + return getIntAttr(constOp.getValue()); +} + +// FROM intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +bool isConstant(Value val, const unsigned expected) { + auto defOp = val.getDefiningOp(); + if (!defOp) + return false; + return (getFoldedConstantValue(defOp) == expected); +} + +// Data structure used to decode pointer arithmetics. Offsets, sizes, and +// strides are in unit of elements in a linearly laid-out memory, which is the +// same as pointer arithmetic operations in Triton language. Scalar is a +// shortcut used when the entire state describes a single scalar value. Source +// is the base pointer. If order is present, PtrState describes block pointer; +// otherwise it describes non-block pointers. When it describes block pointer, +// shape field means the same field as tt.make_tensor_ptr; when it describes a +// non-block pointer, shape field indicates how address wraps around (i.e., +// modulo); a constant 0 indicates no modulo for the dimension. +struct PtrState { + + SmallVector offsets; + SmallVector strides; + SmallVector shape; + SmallVector sizes; + SmallVector order; + + Value source; + Value scalar; + + int32_t getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + offsets.size() == strides.size()); + return offsets.size(); + } + + // @return true if the `PtrState` structure describes a block pointer, + // otherwise it describes a non-block pointer. + bool isBlockPtr() const { return !order.empty(); } + + // This function checks whether the pointer addresses wraps around on the + // dimention `dim`. + // @return true if the address wraps around, (i.e. has modulo). + // Note that this function should only be called when PtrState describes a + // non-block pointer. + bool dimHasModulo(uint32_t dim) const { + assert( + !isBlockPtr() && + "Analysis should not check modulo if PtrState describes block pointer"); + + assert(dim < getRank() && "Dim cannot be higher than the tensor rank."); + + // When PtrState describes a non-block pointer, shape field indicates how + // address wraps around. As a result, a constant 0 indicates no wrap around + // (i.e. modulo) for the dimension. + return !isConstant(shape[dim], 0); + } + + // @return true if addresses wrap around in any of the pointer dimension. + bool hasModulo() const { + for (int32_t i = 0; i < getRank(); i++) { + if (dimHasModulo(i)) { + return true; + } + } + return false; + } + + bool isEmpty() const { return getRank() == 0 && !source && !scalar; } + + // Process addition of two PtrStates. + LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState, + Operation *op, OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + Location loc = op->getLoc(); + + if (lhsState.source && rhsState.source) { + op->emitRemark("TritonRaiseBlockPointer: do not support adding two " + "pointer states that both have base pointers"); + return failure(); + } + + source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { // both lhs and rhs are scalars + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { + scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + + ArithBuilder abuilder(builder, loc); + for (uint64_t i = 0; i < lhsState.getRank(); ++i) { + Value newOffset = abuilder.add(lhsState.offsets[i], rhsState.offsets[i]); + offsets.push_back(newOffset); + + Value newStride = abuilder.add(lhsState.strides[i], rhsState.strides[i]); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + } + + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark( + "TritonRaiseBlockPointer: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + assert( + !(lhsState.hasModulo() || rhsState.hasModulo()) || + (lhsState.getRank() <= 2) && + "cannot have rank > 2 if operand one of the operands has a modulo"); + + // dealing with modulo: + // - If lhs has no modulo, skip + // - If rhs has zero offset on dim i, we can just use lhs's modulo + // - Else, the analysis fails + + // An example for the 3rd condition above can look like: + // %0 = tt.splat %scalar + // %1 = tt.splat %ptr + // %2 = tt.arange + // %3 = arith.remsi %2, %size + // %4 = tt.addptr %1, %3 + // %5 = tt.addptr %4, %0 + // %5 may also occur in a loop to increment %4 every iteration. + + const PtrState *lhs = &lhsState; + const PtrState *rhs = &rhsState; + + if (rhs->hasModulo()) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->getRank(); i++) { + if (!lhs->dimHasModulo(i)) { + shape.push_back(lhs->shape[i]); + } else if (isConstant(rhs->offsets[i], 0)) { + shape.push_back(lhs->shape[i]); + } else { + op->emitRemark("TritonRaiseBlockPointer: do not support adding to " + "operand with modulo"); + return failure(); + } + } + + return success(); + } + + LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState, + Operation *op, OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + Location loc = op->getLoc(); + + assert(!lhsState.source && !rhsState.source && + "Multiplying base pointer does not make sense"); + + assert(!(lhsState.scalar && rhsState.scalar) && + "do not expect to see both lhs and rhs are scalars"); + + // currently do not support both tensors are effectively non-scalar + if (!lhsState.scalar && !rhsState.scalar) { + op->emitRemark("TritonRaiseBlockPointer: only support multiplying " + "pointer states when one of them represent a scalar"); + return failure(); + } + + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; + + if (!rhs->scalar && lhs->scalar) + std::swap(lhs, rhs); + + Value i32Scalar = getValueOrCreateCastToIndexLike( + builder, loc, builder.getI32Type(), rhs->scalar); + Value i64Scalar = getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), rhs->scalar); + ArithBuilder abuilder(builder, loc); + for (const auto &[offset, stride, dim, size] : + llvm::zip(lhs->offsets, lhs->strides, lhs->shape, lhs->sizes)) { + + Value newOffset = + abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI32Type(), offset), + i32Scalar); + Value newStride = + abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), stride), + i64Scalar); + Value newDim = abuilder.mul(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), dim), + i64Scalar); + + offsets.push_back(newOffset); + strides.push_back(newStride); + shape.push_back(newDim); + sizes.push_back(size); + } + + return success(); + } + + triton::MakeTensorPtrOp createTTMakeTensorPtrOp(OpBuilder &builder, + Location loc) { + + SmallVector newOffsets; + SmallVector newStrides; + SmallVector newShape; + ArithBuilder abuilder(builder, loc); + for (const auto &[offset, stride, dim] : + llvm::zip(offsets, strides, shape)) { + + if (isConstant(stride, 0)) { + newOffsets.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI32Type(), offset)); + } else { + auto divOffset = builder.create( + loc, builder.getI32Type(), + getValueOrCreateCastToIndexLike(builder, loc, builder.getI32Type(), + offset), + getValueOrCreateCastToIndexLike(builder, loc, builder.getI32Type(), + stride)); + newOffsets.push_back(divOffset); + } + newStrides.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), stride)); + newShape.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI64Type(), dim)); + } + + auto op = builder.create( + loc, source, newShape, newStrides, newOffsets, sizes, order); + LLVM_DEBUG(llvm::dbgs() << "creating tt.make_tensor_ptr:\n" << op << "\n";); + return op; + } +}; + +#ifndef NDEBUG +template +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const SmallVector &v) { + os << "{"; + if (!v.empty()) { + os << v.front(); + llvm::for_each(ArrayRef(v).drop_front(), + [&os](const T &el) { os << ", " << el; }); + } + return os << "}"; +} + +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const PtrState &state) { + return os << " "; +} +#endif + +struct TritonRaiseBlockPointer + : triton::cpu::impl::TritonRaiseBlockPointerBase< + TritonRaiseBlockPointer> { + using Base::Base; + using IndexMapSet = std::map>; + SmallVector cleanUp; + + void runOnOperation() final { + auto moduleOp = getOperation(); + + if (failed(rewriteOp(moduleOp))) { + moduleOp->emitWarning("TritonRaiseToBlockPointer failed"); + } + + for (auto op : cleanUp) { + if (op->getUsers().empty()) + op->erase(); + } + } + + LogicalResult rewriteOp(Operation *rootOp) { + LLVM_DEBUG({ + llvm::dbgs() << "rewriting rootOp\n"; + rootOp->dump(); + }); + + rootOp->walk([&](Operation *op) { + if (op == rootOp) { + return WalkResult::advance(); + } + return TypeSwitch(op) + .Case([this](triton::AddPtrOp addptr) { + if (failed(rewriteAddPtrOp(addptr))) + addptr->emitRemark( + "TritonRaiseToBlockPointer: Failed to rewrite"); + return WalkResult::advance(); + }) + .Case([&](auto maketptr) { + if (failed(remapMakeTensorPtrOp(maketptr))) { + maketptr->emitRemark("TritonRaiseToBlockPointer: Failed to " + "rewrite MakeTensorPtrOp"); + } + return WalkResult::advance(); + }) + .Case([this](auto loadstore) { + if (failed(rewriteLoadStoreOp(loadstore))) { + loadstore->emitRemark( + "TritonRaiseToBlockPointer: Failed to rewrite"); + return WalkResult::advance(); + } + return WalkResult::skip(); + }) + .Case([&](auto forOp) { + if (failed(rewriteForOp(forOp))) { + forOp->emitRemark( + "TritonRaiseToBlockPointer: Failed to rewrite ForOp"); + return WalkResult::interrupt(); + } + return WalkResult::skip(); + }) + .Default([&](auto) { return WalkResult::advance(); }); + }); + + return success(); + } + + LogicalResult rewriteForOp(scf::ForOp op) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + llvm::SmallDenseMap initArgIndexMap; + + OpBuilder builder(op); + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = ptrMap.lookupOrNull(arg); + PtrState state; + if (mappedV) { + if (auto makeTensorPtrOp = + mappedV.getDefiningOp()) { + + if (llvm::any_of(op.getRegionIterArgs()[i].getUsers(), + [](Operation *user) { + return isa(user); + })) { + op->emitRemark("TritonRaiseToBlockPointer: ExpandDims Ops in loops " + "are currently not supported"); + return failure(); + } + + if (succeeded(visitOperandMakeTensorPtr( + makeTensorPtrOp, state, op.getLoc(), builder, true))) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + continue; + } + } else if (auto addptrOp = mappedV.getDefiningOp()) { + // We always use tt.addptr for scalar pointers. If the defininig op is + // tt.addptr and we have a non-scalar pointer, something must have + // gone wrong with the pass. + assert(!isa(addptrOp.getResult().getType()) && + "Result type of AddPtrOp must be a tensor!"); + if (succeeded( + visitOperandAddptr(addptrOp, state, op.getLoc(), builder))) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + continue; + } + } + } + // If any of the analysis failed, or init arg is not pointer related or + // prior rewrite has failed. Pass as is + newInitArgs.push_back(arg); + } + + // For each of the PtrState recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto &[i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. + for (auto [j, s] : llvm::enumerate(state.offsets)) { + newInitArgs.push_back(s); + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + newInitArgs.push_back(s); + } + + if (state.getRank() == 0) { + assert(state.scalar && + "The state must have a scalar if its rank is equal to zero"); + // for scalar pointers, the scalar contains the offset and is the only + // relevant state that could be updated by the loop. + newInitArgs.push_back(state.scalar); + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to use the newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, state)); + levelToBlockArgIndex[level].insert(i); + } + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = builder.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping cloneMap; + cloneMap.map(op.getInductionVar(), iv); + cloneMap.map(op.getInitArgs(), newInitArgs); + cloneMap.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, cloneMap); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. + // Key is converted from init arg index to newly created block arg, and + // Value's PtrState fields are converted from init arg to newly created + // block arg + int cnt = op.getRegionIterArgs().size(); + for (auto &[i, state] : knownPtrsTmp) { + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + if (state.getRank() == 0) { + assert(state.scalar && + "The state must have a scalar if its rank is equal to zero"); + state.scalar = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + // Record the PtrState for this pointer + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs[key] = state; + initArgIndexMap[i] = state; + + // For tensors of pointers, create a tt.make_block_ptr at the beginning of + // the loop body that correspond to this region iter arg. In case it is + // used by tt.load/tt.store in the loop body before pointer updates, this + // will make sure rewriteLoadOp/rewriteStoreOp can use the analysis + // result. E.g., given the following input (%tensor_of_ptr is a block + // arg): + // scf.for (%tensor_of_ptr) { + // %data = tt.load %tensor_of_ptr + // // more operations to update %tensor_of_ptr + // } + // We may produce the following output: + // scf.for (%base_ptr, %stride, %offset) { + // %tensor_of_ptr = tt.make_block_ptr(%base_ptr, %stride, %offset) + // %data = tt.load %tensor_of_ptr + // // more operations to update %offset + // } + // If %tensor_of_ptr is not used (i.e., %tensor_of_ptr is updated before + // used in the original IR), it will simply be removed by + // canonicalization. + + // For scalar pointers, there is no need to create a tts.addptr at the + // beginning of the loop body. We don't lower tt.load and tt.store on + // scalars in this pass; pointer arithmetics can also just use the + // original pointer. + if (state.getRank() != 0) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&newOp.getRegion().front()); + triton::MakeTensorPtrOp makePtrOp = + state.createTTMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(key, makePtrOp.getResult()); + knownPtrs[makePtrOp.getResult()] = std::move(state); + } + } + + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto forOp = dyn_cast(bodyOp)) { + forOp->emitRemark( + "TritonRaiseToBlockPointer: nested loops currently not supported"); + return failure(); + } + } + // Update the loop body. + if (failed(rewriteOp(newOp))) { + newOp->erase(); + op->emitRemark("TritonRaiseToBlockPointer: update loop body failed when " + "rewriting for op"); + return failure(); + } + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + if (failed(rewriteYieldOp(yieldOp, initArgIndexMap))) { + newOp->erase(); + return failure(); + }; + } + + levelToBlockArgIndex.erase(level); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + + llvm::dbgs() << "old for\n"; + op->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + op->replaceAllUsesWith(resultsToReplaceWith); + op->erase(); + + return success(); + } + + LogicalResult + rewriteYieldOp(scf::YieldOp op, + llvm::SmallDenseMap &knownPtrsFor) { + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) { + // no need to rewrite this op + return success(); + } + + OpBuilder builder(op); + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below + // gathers PtrState for those values. + SmallVector initArgState; + for (auto [i, v] : llvm::enumerate(op->getOperands())) { + // If this operand is not rewritten by forOp, skip + auto &thisSet = levelToBlockArgIndex.find(level)->second; + if (thisSet.find(i) == thisSet.end()) + continue; + + auto mappedV = ptrMap.lookupOrNull(v); + if (!mappedV) { + op->emitRemark("Prior rewrite failure lead to yield rewrite failure"); + return failure(); + } + + PtrState state; + LogicalResult ret = failure(); + if (auto makeTPtrOp = mappedV.getDefiningOp()) { + ret = visitOperandMakeTensorPtr(makeTPtrOp, state, op.getLoc(), builder, + true); + } else if (auto addptrOp = mappedV.getDefiningOp()) { + ret = visitOperandAddptr(addptrOp, state, op.getLoc(), builder); + } + if (ret.failed()) { + op->emitRemark("Failed to rewrite yield op"); + return failure(); + } + initArgState.push_back(state); + + // Verify that shape is not updated during the for loop + auto forState = knownPtrsFor[i]; + for (auto i = 0; i < forState.getRank(); ++i) { + if (forState.shape[i] != state.shape[i]) { + // Special case, see comments in addState in dealing with shape/modulo + if (i == 0 && forState.getRank() == 2) { + if (forState.shape[1] == state.shape[0] && + forState.shape[0] == state.shape[1]) { + break; + } + } + op->emitRemark( + "TritonRaiseToBlockPointer: operand's shape/modulo state changed " + "within loop body"); + return failure(); + } + } + } + + SmallVector operands; + for (auto opnd : op->getOperands()) { + auto mappedV = ptrMap.lookupOrNull(opnd); + operands.push_back(mappedV ? mappedV : opnd); + } + + // For each of the PtrState recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto s : state.offsets) { + operands.push_back(s); + } + + for (auto s : state.strides) { + operands.push_back(s); + } + + if (state.getRank() == 0) { + operands.push_back(state.scalar); + } + } + + auto newOp = builder.create(op->getLoc(), operands); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + op->erase(); + return success(); + } + + LogicalResult remapMakeTensorPtrOp(triton::MakeTensorPtrOp op) { + OpBuilder builder(op); + + PtrState state; + if (failed(visitOperandMakeTensorPtr(op, state, op.getLoc(), builder))) { + return failure(); + } + + knownPtrs[op.getResult()] = std::move(state); + return success(); + } + + Value getFinalValue(Value value) { + auto defOp = value.getDefiningOp(); + if (!defOp) { + // look init values outside the loop + BlockArgument blockArg = dyn_cast(value); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + scf::ForOp forOp = dyn_cast(parentOp); + return forOp ? getFinalValue( + forOp.getInitArgs()[blockArg.getArgNumber() - 1]) + : value; + } + + if (isa(defOp) || isa(defOp) || + isa(defOp) || isa(defOp)) + return getFinalValue(defOp->getOperand(0)); + if (auto addOp = dyn_cast(defOp)) { + if (isConstant(addOp.getLhs(), 0)) + return getFinalValue(addOp.getRhs()); + if (isConstant(addOp.getRhs(), 0)) + return getFinalValue(addOp.getLhs()); + return addOp.getResult(); + } else if (auto mulOp = dyn_cast(defOp)) { + if (isConstant(mulOp.getLhs(), 1)) + return getFinalValue(mulOp.getRhs()); + if (isConstant(mulOp.getRhs(), 1)) + return getFinalValue(mulOp.getLhs()); + return mulOp.getResult(); + } + return value; + } + + bool lookForMulitplyingValueInDefiningPath(Value &val, Value &ref) { + Operation *defOp = getFinalValue(val).getDefiningOp(); + if (!defOp) + return false; + + if (auto mulOp = dyn_cast(defOp)) { + if ((mulOp.getLhs() == ref) || (mulOp.getRhs() == ref)) + return true; + } + return false; + } + + bool areValuesEqual(Value val1, Value val2) { + if (val1 == val2) + return true; + Operation *op1 = val1.getDefiningOp(); + Operation *op2 = val2.getDefiningOp(); + if (op1 && op2) { + auto intVal1 = getFoldedConstantValue(op1); + auto intVal2 = getFoldedConstantValue(op2); + if (intVal1.has_value() && intVal2.has_value()) { + return intVal1.value() == intVal2.value(); + } + } + return false; + } + + int checkIfOffsetMultipliedByStride(Value operand, + SmallVector &strides) { + Operation *defOp = operand.getDefiningOp(); + + SmallVector finalStrides; + // check all strides different + // if not => skip + for (auto stride : strides) { + Value currentVal = getFinalValue(stride); + if (llvm::any_of(finalStrides, [&](Value val) { + return areValuesEqual(val, currentVal); + })) + return -1; + finalStrides.push_back(currentVal); + } + + int axis = 0; + for (auto finalStride : finalStrides) { + // search for a mul to finalStride in the predecessors + if (lookForMulitplyingValueInDefiningPath(operand, finalStride)) + return axis; + if (isConstant(finalStride, 1)) + return axis; + ++axis; + } + return -1; + } + + // Return true if a `triton::ExpandOp` has been found is the defining path. + bool hasExpandOpInDefiningPath(Value value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + // look init values outside the loop + BlockArgument blockArg = dyn_cast(value); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + scf::ForOp forOp = dyn_cast(parentOp); + return forOp ? hasExpandOpInDefiningPath( + forOp.getInitArgs()[blockArg.getArgNumber() - 1]) + : false; + } + + if (isa(defOp)) + return true; + if (isa(defOp)) + return false; + if (isa(defOp)) + return false; + if (isa(defOp) || isa(defOp) || + isa(defOp) || isa(defOp) || + isa(defOp)) + return hasExpandOpInDefiningPath(defOp->getOperand(0)); + if (isa(defOp) || isa(defOp)) + return hasExpandOpInDefiningPath(defOp->getOperand(0)) || + hasExpandOpInDefiningPath(defOp->getOperand(1)); + + return true; + } + + LogicalResult rewriteAddPtrOp(triton::AddPtrOp op) { + OpBuilder builder(op); + Location loc = op.getLoc(); + + PtrState state; + if (failed(visitOperandAddptr(op, state, loc, builder))) + return failure(); + + knownPtrs[op.getResult()] = state; + + Value result = op.getResult(); + Value mapped = result; + if (isa(result.getType())) { + triton::MakeTensorPtrOp makePtrOp = + state.createTTMakeTensorPtrOp(builder, loc); + knownPtrs[makePtrOp.getResult()] = std::move(state); + mapped = makePtrOp.getResult(); + } + + ptrMap.map(result, mapped); + + // AddPtrOps that have been rewritten and no longer used in the code must be + // removed in the pass to avoid type matching issue. + cleanUp.push_back(op); + + return success(); + } + + LogicalResult visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, + PtrState &state, const Location loc, + OpBuilder &builder, + bool addedByPass = false) { + assert(state.isEmpty() && "state is a return argument"); + + if (auto iter = knownPtrs.find(makeTPtrOp.getResult()); + iter != knownPtrs.end()) { + state = iter->second; + return success(); + } + + state.source = makeTPtrOp.getBase(); + + auto resType = cast(makeTPtrOp.getResult().getType()); + auto pointeeType = cast(resType.getPointeeType()); + auto shape = pointeeType.getShape(); + + for (int64_t i = 0; i < pointeeType.getRank(); i++) { + state.sizes.push_back(shape[i]); + + auto strideCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + auto offsetCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + auto scaledOffset = builder.create( + loc, offsetCst.getResult(), strideCst.getResult()); + state.offsets.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getIntegerType(offsetBitwidth), + scaledOffset.getResult())); + } + state.strides = makeTPtrOp.getStrides(); + state.shape = makeTPtrOp.getShape(); + state.order = SmallVector(makeTPtrOp.getOrder()); + + return success(); + } + + LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state, + Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + PtrState ptrState; + if (failed(visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), + builder))) { + return failure(); + } + + PtrState offsetState; + if (failed(visitOperand(addptrOp.getOffset(), offsetState, + addptrOp.getLoc(), builder))) { + return failure(); + } + + // The axis to which the offset must be applied need to be known. + // However, in some cases, the pass fails to detect whether an offset should + // be applied to an axis other than the first. We, therefore, try to find + // out if the offset is multiplied by a known stride. Example: + // off += BLOCK_SIZE_K * stride_ak + // Indeed, as the axis of the stride is known with certainty, we can assume + // that if the offset is multiplied by a known stride, the axis of offset + // should correspond to the axis of the stride axis. In the previous + // example, suppose we have strides = [stride_am, stride_ak] but offsets = + // [off, 0] As we found that `off` is multiplied by `stride_ak`, we correct + // the axis of the offsets to align the axis of `off` with axis of + // `stride_ak`. The corrected offsets then become: [0, off] Limitations: + // - this approach based on pattern matching + user code assumptions is + // (very) fragile. + // if user code does not directly multiply the offset by the stride + // value identified by the pass, the analysis will fail. + // - in theory, this correction support should fail if the analysis + // cannot reach a certain level of certainty. + // Typically, if stride values are the same (e.g. [512, 512]), the + // support is unable to determine the right axis and will not correct + // anything. That said, we do not guarantee the current support does + // not give rise to false positive detections. + auto parentOp = addptrOp->getParentOp(); + if (isa(parentOp)) { + // ExpandOp direclty sets offset to the expected axis. + // So if an ExpandOp has been found in defining path, the analysis is + // skipped. + if (!hasExpandOpInDefiningPath(addptrOp.getOffset())) { + auto axis = checkIfOffsetMultipliedByStride(addptrOp.getOffset(), + ptrState.strides); + if (axis >= 1) + std::swap(offsetState.offsets[0], offsetState.offsets[axis]); + } + } + + assert(ptrState.source && "ptr field should provide source / base pointer"); + + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + LLVM_DEBUG(llvm::dbgs() << "Base: " << ptrState << "\n" + << "Offset: " << offsetState << "\n";); + + return state.addState(ptrState, offsetState, addptrOp, builder); + } + + LogicalResult visitOperand(Value operand, PtrState &state, const Location loc, + OpBuilder &builder) { + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return success(); + } + + if (isa(operand.getType())) { + OpBuilder::InsertionGuard guard(builder); + if (Operation *definingOp = operand.getDefiningOp()) + builder.setInsertionPointAfter(definingOp); + auto castOp = builder.create( + loc, builder.getIndexType(), operand); + state.scalar = castOp.getResult(); + return success(); + } + + if (isa(operand.getType())) { + state.scalar = operand; + return success(); + } + + if (isa(operand.getType())) { + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (Operation *op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) + return visitOperandAddptr(addPtrOp, state, loc, builder); + if (isa(op)) + llvm_unreachable( + "Unexpected operand defining operation tt.make_tensor_ptr"); + llvm_unreachable("Unexpected operand defining operation"); + } + state.source = operand; + return success(); + } + + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) { + llvm::errs() << "TritonRaiseBlockPointer: encountered addptr block " + "argument operand\n" + << operand << "\n"; + } + + return TypeSwitch(definingOp) + .Case( + [this, &state, loc, &builder](auto op) { + return visitAddPointerOperand(op, state, loc, builder); + }) + .Default([](Operation *op) { + llvm::dbgs() << "TritonRaiseBlockPointer: encountered addptr operand " + "produced by an unsupported operation\n" + << op << "\n"; + return failure(); + }); + } + + template + LogicalResult visitAddPointerOperand(OpTy op, PtrState &state, Location loc, + OpBuilder &builder); + + template ::value, + bool> = true> + LogicalResult visitAddPointerRemOperand(OpTy remOp, PtrState &state, + Location loc, OpBuilder &builder); + + template ::value, + bool> = true> + LogicalResult rewriteLoadStoreOp(OpTy op) { + constexpr bool isLoad = std::is_same_v; + constexpr StringLiteral opName = + isLoad ? StringLiteral("loadOp") : StringLiteral("storeOp"); + + Value ptr = ptrMap.lookupOrNull(op.getPtr()); + + if (!ptr) { + op->emitRemark("TritonRaiseBlockPointer: pointer is not replaced with " + "tt.make_tensor_ptr so ") + << opName << " cannot be rewritten"; + return failure(); + } + + auto ptrType = dyn_cast(ptr.getType()); + if (ptrType && !isa(ptrType.getPointeeType())) { + op->emitRemark("TritonRaiseBlockPointer: scalar ") + << opName << " will not be rewritten"; + return failure(); + } + + // As masks are incompatible with block pointer load/store ops + // Masks must be handled before the operation can be rewritten. + // This will be done in a future PR (Issue #1784). + // In the meantime, operations with a mask are not rewrtitten. + if (op.getMask()) { + return success(); + } + + SmallVector boundary; + if (auto iter = knownPtrs.find(ptr); iter != knownPtrs.end()) { + auto state = iter->second; + for (int axis = 0; axis < state.shape.size(); ++axis) { + if (!isConstant(state.shape[axis], 0)) + boundary.push_back(axis); + } + } + ArrayRef newBoundaryCheck(boundary); + + OpBuilder builder(op); + if constexpr (isLoad) { + auto loadOp = builder.create( + op.getLoc(), ptr, newBoundaryCheck, op.getPadding(), op.getCache(), + op.getEvict(), op.getIsVolatile()); + + LLVM_DEBUG(llvm::dbgs() << "creating tt.load: " << loadOp << "\n";); + + op.replaceAllUsesWith(loadOp.getResult()); + } else { + [[maybe_unused]] auto storeOp = builder.create( + op.getLoc(), ptr, op.getValue(), op.getBoundaryCheck(), op.getCache(), + op.getEvict()); + + LLVM_DEBUG(llvm::dbgs() << "creating tt.store: " << storeOp << "\n";); + } + + op->erase(); + return success(); + } + + llvm::SmallDenseMap knownPtrs; + IRMapping ptrMap; + IndexMapSet levelToBlockArgIndex; + int level = 0; +}; + +template < + typename OpTy, + std::enable_if_t< + llvm::is_one_of::value, bool>> +LogicalResult TritonRaiseBlockPointer::visitAddPointerRemOperand( + OpTy remOp, PtrState &state, Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + PtrState rhsState; + if (failed(visitOperand(remOp.getRhs(), rhsState, loc, builder))) { + return failure(); + } + + if (!rhsState.scalar) { + remOp->emitRemark( + "TritonRaiseBlockPointer: only support cases when rhs of remainder " + "contains scalar"); + return failure(); + } + + if (failed(visitOperand(remOp.getLhs(), state, loc, builder))) { + return failure(); + } + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + if (state.hasModulo()) { + remOp->emitRemark("TritonRaiseBlockPointer: do not support multiple modulo " + "within an expression"); + return failure(); + } + + switch (state.getRank()) { + case 1: + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.shape.back() = rhsState.scalar; + break; + case 2: { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.shape[1] = rhsState.scalar; + } else if (shape[1] == 1) { + state.shape[0] = rhsState.scalar; + } else { + remOp->emitRemark("TritonRaiseBlockPointer: taking modulo on a 2D tensor " + "with no singleton dimension not supported"); + return failure(); + } + break; + } + default: + remOp->emitRemark("TritonRaiseBlockPointer: unsupported modulo pattern"); + return failure(); + } + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::RemSIOp remOp, PtrState &state, Location loc, OpBuilder &builder) { + return visitAddPointerRemOperand(remOp, state, loc, builder); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::RemUIOp remOp, PtrState &state, Location loc, OpBuilder &builder) { + return visitAddPointerRemOperand(remOp, state, loc, builder); +} + +template <> +LogicalResult +TritonRaiseBlockPointer::visitAddPointerOperand(triton::MakeRangeOp rangeOp, + PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + ArrayRef shape = cast(rangeOp.getType()).getShape(); + + uint32_t start = rangeOp.getStart(); + uint32_t end = rangeOp.getEnd(); + uint32_t stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back( + builder.create(loc, start, offsetBitwidth)); + state.strides.push_back(builder.create( + loc, stride, shapeAndStridesBitwidth)); + state.shape.push_back( + builder.create(loc, 0, shapeAndStridesBitwidth)); + state.sizes.push_back(shape[0]); + + LLVM_DEBUG(llvm::dbgs() << "MakeRange state: " << state << "\n";); + + return success(); +} + +template <> +LogicalResult +TritonRaiseBlockPointer::visitAddPointerOperand(triton::SplatOp splatOp, + PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + Value src = splatOp.getSrc(); + Value dst = splatOp.getResult(); + ArrayRef dstShape = cast(dst.getType()).getShape(); + + if (failed(visitOperand(src, state, loc, builder))) + return failure(); + + if (!isa(src.getType())) { + splatOp->emitRemark("TritonRaiseBlockPointer: unsupported splat pattern"); + return failure(); + } + + for (int64_t s : dstShape) { + Value c0i32 = builder.create(loc, 0, offsetBitwidth); + Value c0i64 = + builder.create(loc, 0, shapeAndStridesBitwidth); + state.offsets.push_back(c0i32); + state.strides.push_back(c0i64); + state.shape.push_back(c0i64); + state.sizes.push_back(s); + } + + // If we splat a integer value, scalar should become the offset of the + // outer most dimension + if (state.scalar) { + state.offsets[0] = getValueOrCreateCastToIndexLike( + builder, loc, builder.getIntegerType(offsetBitwidth), state.scalar); + } + + LLVM_DEBUG(llvm::dbgs() << "Splat state: " << state << "\n";); + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::AddIOp addOp, PtrState &state, Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + PtrState lhsState; + if (failed(visitOperand(addOp.getLhs(), lhsState, loc, builder))) + return failure(); + + PtrState rhsState; + if (failed(visitOperand(addOp.getRhs(), rhsState, loc, builder))) + return failure(); + + if (failed(state.addState(lhsState, rhsState, addOp, builder))) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Add state: " << state << "\n";); + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::MulIOp mulOp, PtrState &state, Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + PtrState lhsState; + if (failed(visitOperand(mulOp.getLhs(), lhsState, loc, builder))) + return failure(); + + PtrState rhsState; + if (failed(visitOperand(mulOp.getRhs(), rhsState, loc, builder))) + return failure(); + + if (failed(state.mulState(lhsState, rhsState, mulOp, builder))) + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "Mul state: " << state << "\n";); + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + arith::ConstantOp op, PtrState &state, Location loc, OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + auto attr = cast(op.getValue()); + Type elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType) && + "Expecting constant tensor"); + + state.scalar = builder.create( + loc, attr.getValues()[0].getValue().getSExtValue()); + + Type offsetType = builder.getIntegerType(offsetBitwidth); + auto resultType = cast(op.getResult().getType()); + Value offset = convertScalarToDtype(builder, loc, state.scalar, offsetType, + /*isUnsignedCast=*/true); + state.offsets.push_back(offset); + state.offsets.insert( + state.offsets.end(), resultType.getShape().size() - 1, + builder.create(loc, 0, offsetBitwidth)); + state.strides.insert( + state.strides.end(), resultType.getShape().size(), + builder.create(loc, 0, shapeAndStridesBitwidth)); + state.shape.insert( + state.shape.end(), resultType.getShape().size(), + builder.create(loc, 0, shapeAndStridesBitwidth)); + + for (int32_t dim : resultType.getShape()) { + state.sizes.push_back(dim); + } + + return success(); +} + +template <> +LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( + triton::ExpandDimsOp expandDimsOp, PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + if (failed(visitOperand(expandDimsOp.getSrc(), state, loc, builder))) { + return failure(); + } + + ArrayRef dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + Value c0i32 = builder.create(loc, 0, offsetBitwidth); + Value c0i64 = + builder.create(loc, 0, shapeAndStridesBitwidth); + state.offsets.insert(state.offsets.begin() + axis, c0i32); + state.sizes.insert(state.sizes.begin() + axis, 1); + state.strides.insert(state.strides.begin() + axis, c0i64); + state.shape.insert(state.shape.begin() + axis, c0i64); + + if (state.hasModulo() && state.getRank() > 2) { + expandDimsOp->emitRemark("TritonRaiseBlockPointer: unsupported scenario " + "where expand_dims result " + "has modulo and rank > 2"); + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "ExpandDims state: " << state << "\n";); + + return success(); +} + +template <> +LogicalResult +TritonRaiseBlockPointer::visitAddPointerOperand(triton::BroadcastOp broadcastOp, + PtrState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty() && "state is a return argument"); + + Value src = broadcastOp.getSrc(); + Value dst = broadcastOp.getResult(); + + if (!isa(src.getType())) { + broadcastOp->emitRemark( + "TritonRaiseBlockPointer: Unsupported broadcast source type"); + return failure(); + } + + ArrayRef srcShape = cast(src.getType()).getShape(); + ArrayRef dstShape = cast(dst.getType()).getShape(); + + assert(srcShape.size() <= dstShape.size() && + "rank of source cannot be greater than the rank of destination"); + + if (failed(visitOperand(src, state, loc, builder))) { + return failure(); + } + + if (srcShape.size() == dstShape.size()) { + llvm::copy(dstShape, state.sizes.begin()); + } else { + // Offset must be equal, otherwise we don.t know which offset should be + // propagated to the new axis. + for (int i = 1; i < state.offsets.size(); ++i) { + if (state.offsets[0] != state.offsets[i]) { + broadcastOp->emitRemark( + "TritonRaiseBlockPointer: Unsupported broadcast with different " + "offsets while source rank and destination rank differ."); + return failure(); + } + } + + // Create the new axis. + // The positions of the new axis are determined based and the shape values. + // If shape are the same, the new axis are added at the end. + size_t srcAxis = 0; + for (size_t axis = 0; axis < dstShape.size(); ++axis) { + if ((srcAxis < srcShape.size()) && + (srcShape[srcAxis] == dstShape[axis])) { + ++srcAxis; + continue; + } + Value c0i32 = + builder.create(loc, 0, offsetBitwidth); + Value c0i64 = + builder.create(loc, 0, shapeAndStridesBitwidth); + state.offsets.insert(state.offsets.begin() + axis, + getValueOrCreateCastToIndexLike( + builder, loc, + builder.getIntegerType(offsetBitwidth), + state.offsets[0])); + state.sizes.insert(state.sizes.begin() + axis, dstShape[axis]); + state.strides.insert(state.strides.begin() + axis, c0i64); + state.shape.insert(state.shape.begin() + axis, c0i64); + } + + // The following condition has been duplicated from the expand_dim support + // TODO : Verify if we need still need it given that triton `make_block_ptr` + // op differs from triton-shared `make_block_ptr` op regarding how address + // wrap around are handled. + if (state.hasModulo() && state.getRank() > 2) { + broadcastOp->emitRemark("TritonRaiseBlockPointer: unsupported scenario " + "where broadcast result " + "has modulo and rank > 2"); + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Broadcast state: " << state << "\n";); + + return success(); +} +} // namespace diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..0c097fb5923e --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,25 @@ +add_triton_library(TritonToTritonCPU + ConvertAtomicOps.cpp + ConvertControlFlowOps.cpp + ConvertDebugOps.cpp + ConvertDotOp.cpp + ConvertElementwiseOps.cpp + ConvertElemManipOps.cpp + ConvertHistogramOp.cpp + ScalarizeInterface.cpp + ScalarizeUsingForOps.cpp + ConvertMemoryOps.cpp + ConvertPtrOps.cpp + ConvertReductionOp.cpp + ConvertScanOp.cpp + TypeConverter.cpp + + DEPENDS + TritonToTritonCPUPassIncGen + ScalarizeInterfaceIncGen + MLIRDialectUtils + + LINK_LIBS PUBLIC + TritonCPUIR + MLIRVectorDialect +) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp new file mode 100644 index 000000000000..bab0cd94c57e --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -0,0 +1,212 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTATOMICOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class AtomicConversionTarget : public ConversionTarget { +public: + explicit AtomicConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalOp( + [&](triton::AtomicRMWOp op) -> std::optional { + return converter.isLegal(op) && !op.getMask(); + }); + addDynamicallyLegalOp( + [&](triton::AtomicCASOp op) -> std::optional { + return converter.isLegal(op); + }); + } +}; + +struct AtomicRMWOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto mask = + op.getMask() ? rewriter.getRemappedValue(op.getMask()) : nullptr; + arith::ConstantOp maskCst = mask ? getConstMaskDef(mask) : nullptr; + auto rmwOp = op.getAtomicRmwOp(); + auto ptrs = rewriter.getRemappedValue(op.getPtr()); + auto vals = rewriter.getRemappedValue(op.getVal()); + auto sem = op.getSem(); + auto scope = op.getScope(); + + if (mask && !isa(mask.getType())) { + auto res = lowerScalarMaskToCF(loc, rmwOp, ptrs, vals, mask, sem, scope, + rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + auto ptrTy = cast(op.getPtr().getType()).getElementType(); + auto vecTy = cast(vals.getType()); + auto strides = computeStrides(vecTy.getShape()); + Value res = + rewriter.create(loc, rewriter.getZeroAttr(vecTy)); + int64_t numElems = vecTy.getNumElements(); + for (int64_t idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + Value resElem; + + if (mask && !maskCst) { + // Non-const mask values are lowered to CF. + Value maskVal = rewriter.create(loc, mask, indices); + resElem = lowerScalarMaskToCF(loc, rmwOp, ptr, val, maskVal, sem, scope, + rewriter); + } else if (!mask || + (maskCst && cast(maskCst.getValue()) + .getValues()[idx])) { + // Const true mask case. + resElem = rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + } + + // Elements with const false mask are skipped. + if (resElem) { + res = rewriter.create(loc, resElem, res, indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + Value lowerScalarMaskToCF(Location loc, RMWOp rmwOp, Value ptr, Value val, + Value mask, MemSemantic sem, MemSyncScope scope, + ConversionPatternRewriter &rewriter) const { + // Check for constant mask. + if (auto maskDef = mask.getDefiningOp()) { + auto maskVal = cast(maskDef.getValue()); + if (maskVal.getValue().isZero()) { + return rewriter.create( + loc, rewriter.getZeroAttr(val.getType())); + } else { + return rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + } + } + + auto ifOp = rewriter.create( + loc, mask, + [&](OpBuilder &builder, Location loc) { + Value resVal = rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + rewriter.create(loc, resVal); + }, + [&](OpBuilder &builder, Location loc) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(val.getType())); + rewriter.create(loc, zero); + }); + return ifOp.getResult(0); + } + + arith::ConstantOp getConstMaskDef(Value mask) const { + while (auto cast = mask.getDefiningOp()) + mask = cast.getOperand(0); + return mask.getDefiningOp(); + } +}; + +struct AtomicCASOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptrs = rewriter.getRemappedValue(op.getPtr()); + auto cmpVals = rewriter.getRemappedValue(op.getCmp()); + auto vals = rewriter.getRemappedValue(op.getVal()); + auto sem = op.getSem(); + auto scope = op.getScope(); + auto ptrTy = cast(op.getPtr().getType()).getElementType(); + auto vecTy = cast(vals.getType()); + auto strides = computeStrides(vecTy.getShape()); + auto res = + rewriter.create(loc, rewriter.getZeroAttr(vecTy)); + int64_t numElems = vecTy.getNumElements(); + for (int64_t idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + Value cmpVal = rewriter.create(loc, cmpVals, indices); + Value resElem = rewriter.create( + loc, val.getType(), ptr, cmpVal, val, sem, scope); + rewriter.create(loc, resElem, res, indices); + } + + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertAtomicOps + : public triton::impl::ConvertAtomicOpsBase { + using ConvertAtomicOpsBase::ConvertAtomicOpsBase; + + ConvertAtomicOps() : ConvertAtomicOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + AtomicConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertAtomicOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp new file mode 100644 index 000000000000..491b647103a7 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -0,0 +1,216 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTCONTROLFLOWOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ControlFlowOpConversionTarget : public ConversionTarget { +public: + explicit ControlFlowOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +struct ForOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lowerBound = rewriter.getRemappedValue(op.getLowerBound()); + Value upperBound = rewriter.getRemappedValue(op.getUpperBound()); + Value step = rewriter.getRemappedValue(op.getStep()); + SmallVector initArgs; + if (failed(rewriter.getRemappedValues(op.getInitArgs(), initArgs))) + return failure(); + // Create new for op with remapped values. + auto newOp = rewriter.create(op.getLoc(), lowerBound, + upperBound, step, initArgs); + // Move the old op block and convert its sigature. + Block *oldBlock = op.getBody(); + Block *newBlock = newOp.getBody(); + rewriter.moveBlockBefore(oldBlock, newOp.getBody()); + rewriter.eraseBlock(newBlock); + if (failed(rewriter.convertRegionTypes(oldBlock->getParent(), + *getTypeConverter()))) + return failure(); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; + +// This is borrowed from SCFWhilePattern in +// lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +class WhileOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +// and +// lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +struct ConvertControlFlowOps + : public triton::impl::ConvertControlFlowOpsBase { + using ConvertControlFlowOpsBase::ConvertControlFlowOpsBase; + + ConvertControlFlowOps() : ConvertControlFlowOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ControlFlowOpConversionTarget convTarget(*context, typeConverter); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertControlFlowOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp new file mode 100644 index 000000000000..80edcf69f239 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -0,0 +1,160 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDEBUGOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DebugOpsConversionTarget : public ConversionTarget { +public: + explicit DebugOpsConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addLegalOp(); + addLegalOp(); + addLegalOp(); + addLegalOp(); + + addIllegalOp(); + addIllegalOp(); + } +}; + +struct PrintOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // It lowers to triton_cpu.print after converting tensor types to vectors. + // (tt.print doesn't accept vector types, so we have this intermediate op.) + if (op.getNumOperands() == 0) { + rewriter.create(loc, op.getPrefix(), op.getHex(), + ValueRange{}, + llvm::SmallVector{}); + rewriter.eraseOp(op); + return success(); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + Value operand = op.getOperands()[i]; + auto isSigned = {op.getIsSigned()[i]}; + if (!isa(operand.getType())) { + rewriter.create( + loc, op.getPrefix(), op.getHex(), + rewriter.getRemappedValue(operand), isSigned); + continue; + } + + auto tensorTy = cast(operand.getType()); + auto elemTy = tensorTy.getElementType(); + if (isa(elemTy)) { + elemTy = rewriter.getI64Type(); + } + MemRefType memRefTy = MemRefType::get(tensorTy.getShape(), elemTy); + + Value allocVal = rewriter.create( + loc, memRefTy, rewriter.getI64IntegerAttr(64)); + + Value vec = rewriter.getRemappedValue(operand); + VectorType vecTy = cast(vec.getType()); + + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(vecTy.getRank(), zeroIdx); + + rewriter.create(loc, vec, allocVal, indices); + + Value allocUnrankedVal = rewriter.create( + loc, UnrankedMemRefType::get(elemTy, memRefTy.getMemorySpace()), + allocVal); + + rewriter.create(loc, op.getPrefix(), op.getHex(), + allocUnrankedVal, isSigned); + + rewriter.create(loc, allocVal); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct AssertOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value acc = rewriter.create(loc, i1_ty, + rewriter.getOneAttr(i1_ty)); + Value condition = rewriter.getRemappedValue(op.getCondition()); + SmallVector dimsToReduce( + cast(condition.getType()).getRank(), true); + condition = rewriter.create( + loc, condition, acc, dimsToReduce, vector::CombiningKind::AND); + rewriter.replaceOpWithNewOp(op, condition, + op.getMessage()); + return success(); + } +}; + +struct ConvertDebugOps + : public triton::impl::ConvertDebugOpsBase { + using ConvertDebugOpsBase::ConvertDebugOpsBase; + + ConvertDebugOps() : ConvertDebugOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + DebugOpsConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDebugOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp new file mode 100644 index 000000000000..5d2f3c179ee1 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -0,0 +1,104 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDOTOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DotConversionTarget : public ConversionTarget { +public: + explicit DotConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Value a = rewriter.getRemappedValue(op.getA()); + Value b = rewriter.getRemappedValue(op.getB()); + Value c = rewriter.getRemappedValue(op.getC()); + + auto aType = cast(a.getType()); + auto bType = cast(b.getType()); + auto cType = cast(c.getType()); + assert(aType.getRank() == bType.getRank() && + bType.getRank() == cType.getRank() && + (aType.getRank() == 2 || aType.getRank() == 3) && + "Mixed ranks, not 2d or 3d matmul, unknown type of op"); + + rewriter.replaceOpWithNewOp(op, a, b, c, op.getInputPrecision(), + op.getMaxNumImpreciseAcc()); + return success(); + } +}; + +struct ConvertDotOp : public triton::impl::ConvertDotOpBase { + using ConvertDotOpBase::ConvertDotOpBase; + + ConvertDotOp() : ConvertDotOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + DotConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp new file mode 100644 index 000000000000..a39a93e42446 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -0,0 +1,249 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMMANIPOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElemManipOpConversionTarget : public ConversionTarget { +public: + explicit ElemManipOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ReshapeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. + while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct TransOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getSrc()); + auto order = op.getOrder(); + SmallVector permutation(order.begin(), order.end()); + rewriter.replaceOpWithNewOp(op, val, permutation); + return success(); + } +}; + +struct JoinOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto interleave = rewriter.create(loc, lhs, rhs); + // JoinOp creates a new dimension, but InterleaveOp doubles the final one. + // Use ShapeCastOp to get the required shape. + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, interleave); + return success(); + } +}; + +struct CatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); + return success(); + } +}; + +struct SplitOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = cast(src.getType()); + auto resTy = getTypeConverter()->convertType(op.getType(0)); + + SmallVector results; + if (srcTy.getRank() == 1) { + results.push_back(rewriter.create(loc, src, 0)); + results.push_back(rewriter.create(loc, src, 1)); + } else { + SmallVector tmpShape({srcTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, srcTy.getElementType()), src); + + SmallVector evenIndices; + SmallVector oddIndices; + for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) { + evenIndices.push_back(i); + oddIndices.push_back(i + 1); + } + + Value res1 = + rewriter.create(loc, tmp, tmp, evenIndices); + Value res2 = + rewriter.create(loc, tmp, tmp, oddIndices); + results.push_back(rewriter.create(loc, resTy, res1)); + results.push_back(rewriter.create(loc, resTy, res2)); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct ConvertElemManipOps + : public triton::impl::ConvertElemManipOpsBase { + using ConvertElemManipOpsBase::ConvertElemManipOpsBase; + + ConvertElemManipOps() : ConvertElemManipOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElemManipOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElemManipOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp new file mode 100644 index 000000000000..87e0914e1e41 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -0,0 +1,277 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElementwiseOpConversionTarget : public ConversionTarget { +public: + explicit ElementwiseOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + + addDynamicallyLegalOp( + [](triton::BitcastOp op) { return isa(op.getType()); }); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ConstantOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + assert(resTy); + if (auto denseAttr = dyn_cast(op.getValueAttr())) { + rewriter.replaceOpWithNewOp(op, resTy, + denseAttr.reshape(resTy)); + } else { + llvm_unreachable("Unexpected constant attribute"); + } + return success(); + } +}; + +struct MulhiUIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getX()); + auto rhs = rewriter.getRemappedValue(op.getY()); + Value res = + rewriter.create(loc, lhs, rhs).getHigh(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ClampFOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getX()); + auto minVal = rewriter.getRemappedValue(op.getMin()); + auto maxVal = rewriter.getRemappedValue(op.getMax()); + Value res; + if (op.getPropagateNanAttr().getValue() == PropagateNan::ALL) { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } else { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct FpToFpOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = src.getType(); + auto resTy = getTypeConverter()->convertType(op.getType()); + auto srcElemTy = isa(srcTy) + ? cast(srcTy).getElementType() + : srcTy; + auto resElemTy = isa(resTy) + ? cast(resTy).getElementType() + : resTy; + + if (srcElemTy.getIntOrFloatBitWidth() > resElemTy.getIntOrFloatBitWidth()) { + std::optional rounding = op.getRounding(); + assert(rounding && "Rounding mode expected for truncate conversions"); + auto roundingAttr = arith::RoundingModeAttr::get( + getContext(), *rounding == RoundingMode::RTZ + ? arith::RoundingMode::toward_zero + : arith::RoundingMode::to_nearest_even); + rewriter.replaceOpWithNewOp(op, resTy, src, roundingAttr, + nullptr); + return success(); + } + + if (srcElemTy.getIntOrFloatBitWidth() < resElemTy.getIntOrFloatBitWidth()) { + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } + + return failure(); + } +}; + +struct ConvertElementwiseOps + : public triton::impl::ConvertElementwiseOpsBase { + using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; + + ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElementwiseOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElementwiseOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp new file mode 100644 index 000000000000..0bcbfcc9f264 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp @@ -0,0 +1,134 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTHISTOGRAMOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class HistogramConversionTarget : public ConversionTarget { +public: + explicit HistogramConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addIllegalOp(); + } +}; + +struct HistogramOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = dyn_cast(src.getType()); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + if (srcTy.getRank() != 1) + llvm_unreachable("unsupported input for histogram op (rank != 1)"); + + Value zero = rewriter.create( + loc, resTy, rewriter.getZeroAttr(resTy)); + Value one = rewriter.create(loc, resTy, + rewriter.getOneAttr(resTy)); + VectorType cmpVecTy = + VectorType::get(resTy.getShape(), srcTy.getElementType()); + Value rangeVec = rewriter.create( + loc, resTy, makeRangeAttr(cmpVecTy, rewriter)); + Value res = zero; + for (int64_t i = 0; i < srcTy.getShape()[0]; ++i) { + Value idx = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(i)); + Value elem = rewriter.create(loc, src, idx); + Value elemVec = rewriter.create(loc, cmpVecTy, elem); + Value mask = rewriter.create(loc, arith::CmpIPredicate::eq, + elemVec, rangeVec); + Value delta = vector::selectPassthru(rewriter, mask, one, zero); + res = rewriter.create(loc, res, delta); + } + + rewriter.replaceOp(op, res); + + return success(); + } + + TypedAttr makeRangeAttr(VectorType resTy, + ConversionPatternRewriter &rewriter) const { + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(32)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI32VectorAttr(range); + } else if (elemTy.isInteger(64)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI64VectorAttr(range); + } else { + llvm_unreachable( + "unsupported src elem type for histogram (expected i32 or i64)"); + } + } +}; + +struct ConvertHistogramOp + : public triton::impl::ConvertHistogramOpBase { + using ConvertHistogramOpBase::ConvertHistogramOpBase; + + ConvertHistogramOp() : ConvertHistogramOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + HistogramConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertHistogramOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp new file mode 100644 index 000000000000..51729ca6618f --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -0,0 +1,651 @@ +#include "TypeConverter.h" + +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTMEMORYOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct MemoryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getContext; + using OpConversionPattern::getTypeConverter; + + MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, + TypeConverter &typeConverter, MLIRContext *context, + bool useGatherScatter) + : OpConversionPattern(typeConverter, context), + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) { + this->useGatherScatter = useGatherScatter; + } + + Value extractScalarPointer(Location loc, Value ptrs, + ArrayRef indices, + ConversionPatternRewriter &rewriter) const { + // If we build a vector of pointers and the extract a pointer from it, then + // compiler doesn't always optimize it to a simple scalar pointer + // computation. Here we try to follow a data flow of the tensor to rebuild a + // scalar pointer for more efficient resulting code. + if (canComputeScalarValue(ptrs)) { + return computeScalarValue(ptrs.getDefiningOp(), ptrs, indices, rewriter); + } + + // Fall back to a scalar pointer extraction from the vector. + Value ptr = rewriter.create( + loc, rewriter.getRemappedValue(ptrs), indices); + auto ptrTy = dyn_cast(ptrs.getType()).getElementType(); + ptr = rewriter.create(loc, ptrTy, ptr); + return ptr; + } + + Value extractMemRef(Location loc, Value ptr, + ConversionPatternRewriter &rewriter) const { + auto tensorTy = dyn_cast( + dyn_cast(ptr.getType()).getPointeeType()); + auto elemTy = tensorTy.getElementType(); + auto shapeInfo = shapeAnalysis.getPtrShapeInfo(ptr); + Type memRefTy; + if (shapeInfo && shapeInfo->getRank() > 0) { + auto layout = + StridedLayoutAttr::get(getContext(), 0, shapeInfo->getStrides()); + memRefTy = MemRefType::get(shapeInfo->getShape(), elemTy, layout); + } else { + SmallVector dynVals(tensorTy.getRank(), ShapedType::kDynamic); + auto layout = StridedLayoutAttr::get(getContext(), 0, dynVals); + memRefTy = MemRefType::get(dynVals, elemTy, layout); + } + return rewriter.create(loc, memRefTy, ptr); + } + + Value convertOtherVal(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + if (loadOp.getOther()) + return rewriter.getRemappedValue(loadOp.getOther()); + + auto resTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + return rewriter.create( + loadOp.getLoc(), resTy, + SplatElementsAttr::get(resTy, + rewriter.getZeroAttr(resTy.getElementType()))); + } + + Value createAlloca(Location loc, MemRefType ty, Operation *before, + ConversionPatternRewriter &rewriter) const { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(before); + return rewriter.create( + loc, ty, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + } + + // If tensor is not null and its element cannot be recomputed in a scalar + // loop, then store it to a temporary buffer. + Value maybeStoreVecToTempBuf(Location loc, Value vals, Value zeroIdx, + Operation *allocaPoint, + ConversionPatternRewriter &rewriter) const { + if (!vals || canComputeScalarValue(vals)) + return nullptr; + + auto vec = rewriter.getRemappedValue(vals); + auto vecTy = cast(vec.getType()); + auto elemTy = vecTy.getElementType(); + // Memref of i1 assumes one element per byte when we load/store element, + // but vector store (through transfer write) would write 1 bit per element. + if (elemTy.isInteger(1)) { + elemTy = rewriter.getI8Type(); + vec = rewriter.create( + loc, VectorType::get(vecTy.getShape(), elemTy), vec); + } + auto memRefTy = MemRefType::get(vecTy.getShape(), elemTy); + Value memRef = createAlloca(vals.getLoc(), memRefTy, allocaPoint, rewriter); + SmallVector indices(vecTy.getRank(), zeroIdx); + rewriter.create(vals.getLoc(), vec, memRef, + indices); + return memRef; + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysis; + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; + bool useGatherScatter; +}; + +struct LoadOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + static Value + getPaddingValue(Location loc, Type type, + const std::optional &padding, + ConversionPatternRewriter &rewriter) { + auto padding_option = padding.value_or(PaddingOption::PAD_ZERO); + + TypedAttr attr; + switch (padding_option) { + case PaddingOption::PAD_ZERO: + attr = rewriter.getZeroAttr(type); + break; + case PaddingOption::PAD_NAN: + assert(!type.isIntOrIndex()); + auto apNaN = + llvm::APFloat::getNaN(cast(type).getFloatSemantics()); + attr = FloatAttr::get(type, apNaN); + break; + } + + return rewriter.create(loc, attr); + } + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto mask = loadOp.getMask(); + auto ptr = loadOp.getPtr(); + auto boundaryChecks = loadOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (isContiguousRowMajorAccess(axisInfo, loadOp)) { + return lowerToContiguousRowMajor(loadOp, rewriter); + } + if (useGatherScatter && succeeded(lowerToGather(loadOp, rewriter))) { + return success(); + } + return lowerToScalarLoads(loadOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported load op"); + } + + auto memRef = extractMemRef(loc, ptr, rewriter); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + Value padding = getPaddingValue(loc, resTy.getElementType(), + loadOp.getPadding(), rewriter); + auto vecRead = rewriter.create( + loc, resTy, memRef, indices, padding, inBounds); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } + + LogicalResult + lowerToContiguousRowMajor(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // This is an experimental code that covers only a simple case of axis info + // usage to demostrate load by tensor of pointers transformation into vector + // loads. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer stores? + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto shape = vecTy.getShape(); + + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type subVecTy = VectorType::get(shape.back(), vecTy.getElementType()); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + Value defaultVal = convertOtherVal(loadOp, rewriter); + Value res = defaultVal; + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subIndices(indices.begin(), + indices.begin() + indices.size() - 1); + auto ptr = extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + Value vec; + if (mask) { + Value subMask = mask; + Value passThru = defaultVal; + if (shape.size() > 1) { + subMask = rewriter.create(loc, mask, subIndices); + passThru = + rewriter.create(loc, defaultVal, subIndices); + } + vec = rewriter.create(loc, subVecTy, memRef, + zeroIdx, subMask, passThru); + } else { + vec = rewriter.create(loc, subVecTy, memRef, zeroIdx); + } + + if (shape.size() > 1) { + res = rewriter.create(loc, vec, res, subIndices); + } else { + res = vec; + } + } + + rewriter.replaceOp(loadOp, res); + return success(); + } + + LogicalResult lowerToGather(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + auto loc = loadOp.getLoc(); + auto vecTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto shape = vecTy.getShape(); + + auto [basePtr, offset] = getMemoryBaseOffset(loadOp); + + if (!basePtr || !offset) + return failure(); + + auto pointeeType = + dyn_cast(basePtr.getType()).getPointeeType(); + + auto gatherBase = rewriter.create( + loc, MemRefType::get({}, pointeeType), basePtr); + auto gatherIndices = SmallVector(); + auto gatherIndexVec = rewriter.getRemappedValue(offset); + + Value gatherMask; + if (auto loadMask = loadOp.getMask()) { + gatherMask = rewriter.getRemappedValue(loadMask); + } else { + auto maskType = VectorType::get(shape, rewriter.getI1Type()); + gatherMask = rewriter.create( + loc, maskType, DenseElementsAttr::get(maskType, true)); + } + + auto passThru = convertOtherVal(loadOp, rewriter); + + auto gatherOp = + rewriter.create(loc, vecTy, gatherBase, gatherIndices, + gatherIndexVec, gatherMask, passThru); + rewriter.replaceOp(loadOp, gatherOp); + return success(); + } + + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // Scalar loads and boundary checks are not expected. + assert(loadOp.getBoundaryCheck().empty()); + assert(isa(loadOp.getType())); + + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + + auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); + auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + auto ptrTy = + dyn_cast(loadOp.getPtr().getType()).getElementType(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + auto loadOne = [=, &rewriter](ArrayRef indices, Value dst) { + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = + rewriter.create(loc, ptr, cache, evict, isVolatile); + return rewriter.create(loc, val, dst, indices); + }; + + Value dst = convertOtherVal(loadOp, rewriter); + int64_t numElems = vecTy.getNumElements(); + auto strides = computeStrides(vecTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + if (!mask) { + dst = loadOne(indices, dst); + continue; + } + // Create a conditional block for load if there is a mask. + auto predicate = rewriter.create(loc, mask, indices); + auto ifOp = rewriter.create( + loc, predicate, + [&](OpBuilder &builder, Location loc) { + auto result = loadOne(indices, dst).getResult(); + rewriter.create(loc, result); + }, + [&](OpBuilder &builder, Location loc) { + rewriter.create(loc, dst); + }); + dst = ifOp.getResult(0); + } + + rewriter.replaceOp(loadOp, dst); + + return success(); + } +}; + +struct StoreOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto mask = storeOp.getMask(); + auto ptr = storeOp.getPtr(); + auto boundaryChecks = storeOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (isContiguousRowMajorAccess(axisInfo, storeOp)) { + return lowerToContiguousRowMajor(storeOp, rewriter); + } + if (useGatherScatter && succeeded(lowerToScatter(storeOp, rewriter))) { + return success(); + } + return lowerToScalarStores(storeOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported store op"); + } + + auto value = rewriter.getRemappedValue(storeOp.getValue()); + auto memRef = extractMemRef(loc, ptr, rewriter); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecWrite = rewriter.create(loc, value, memRef, + indices, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } + + LogicalResult + lowerToContiguousRowMajor(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // This is an experimental code that covers only a simple case of axis info + // usage to demostrate load by tensor of pointers transformation into vector + // loads. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer stores instead? + auto loc = storeOp.getLoc(); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto vecTy = dyn_cast(vals.getType()); + auto shape = vecTy.getShape(); + + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + auto ptr = extractScalarPointer(loc, storeOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + indices.pop_back(); + auto val = rewriter.create(loc, vals, indices); + + if (mask) { + Value subMask = mask; + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + subMask = rewriter.create(loc, mask, indices); + } + rewriter.create(loc, memRef, zeroIdx, subMask, + val); + } else { + rewriter.create(loc, val, memRef, zeroIdx); + } + } + + rewriter.eraseOp(storeOp); + return success(); + } + + LogicalResult lowerToScatter(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + auto loc = storeOp.getLoc(); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto vecTy = dyn_cast(vals.getType()); + auto shape = vecTy.getShape(); + + auto [basePtr, offset] = getMemoryBaseOffset(storeOp); + + if (!basePtr || !offset) + return failure(); + + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + indices.pop_back(); + + auto val = rewriter.create(loc, vals, indices); + auto indexVec = rewriter.create( + loc, rewriter.getRemappedValue(offset), indices); + Value scatterMask; + + if (mask) { + scatterMask = rewriter.create(loc, mask, indices); + } else { + // Create a mask with all true values if no mask is provided. + auto maskType = VectorType::get({shape.back()}, rewriter.getI1Type()); + scatterMask = rewriter.create( + loc, maskType, DenseElementsAttr::get(maskType, true)); + } + + auto scatterBase = rewriter.create( + loc, MemRefType::get({}, vecTy.getElementType()), basePtr); + auto scatterIndices = SmallVector(); + + rewriter.create(loc, scatterBase, scatterIndices, + indexVec, scatterMask, val); + } + + rewriter.eraseOp(storeOp); + return success(); + } + + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // Scalar stores and boundary checks are not expected. + assert(storeOp.getBoundaryCheck().empty()); + assert(isa(storeOp.getValue().getType())); + + auto loc = storeOp.getLoc(); + auto tensorTy = dyn_cast(storeOp.getPtr().getType()); + + auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); + auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto ptrTy = tensorTy.getElementType(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + auto storeOne = [=, &rewriter](ArrayRef indices) { + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + rewriter.create(loc, ptr, val, cache, evict); + }; + + int64_t numElems = tensorTy.getNumElements(); + auto strides = computeStrides(tensorTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + if (!mask) { + storeOne(indices); + continue; + } + // Create a conditional block for store if there is a mask. + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, + [&](OpBuilder &builder, Location loc) { + storeOne(indices); + rewriter.create(loc); + }); + } + + rewriter.eraseOp(storeOp); + + return success(); + } +}; + +struct CpuStoreOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::cpu::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto value = rewriter.getRemappedValue(storeOp.getSrc()); + auto memRef = storeOp.getDst(); + auto rank = dyn_cast(memRef.getType()).getRank(); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(rank, zeroIdx); + auto vecWrite = + rewriter.create(loc, value, memRef, + indices); //, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } +}; + +struct CpuLoadOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::cpu::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto memRef = loadOp.getSrc(); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(resTy.getRank(), zeroIdx); + auto vecRead = + rewriter.create(loc, resTy, memRef, indices); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } +}; + +class MemoryOpConversionTarget : public ConversionTarget { +public: + explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + + // Allow only scalar loads and stores. + addDynamicallyLegalOp([](triton::LoadOp loadOp) { + return loadOp.getType().isIntOrIndexOrFloat(); + }); + addDynamicallyLegalOp([](triton::StoreOp storeOp) { + return storeOp.getValue().getType().isIntOrIndexOrFloat(); + }); + } +}; + +struct ConvertMemoryOps + : public triton::cpu::impl::ConvertMemoryOpsBase { + ConvertMemoryOps() = default; + + ConvertMemoryOps(bool useGatherScatter) { + this->useGatherScatter = useGatherScatter; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); + MemoryOpConversionTarget convTarget(*context); + TritonToTritonCPUTypeConverter pointerConverter; + RewritePatternSet patterns(context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context, + useGatherScatter); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertMemoryOps() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertMemoryOps(bool useGatherScatter) { + return std::make_unique(useGatherScatter); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp new file mode 100644 index 000000000000..27f49a3078c1 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -0,0 +1,197 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTPTROPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +unsigned getElemBitWidth(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + if (auto vectorTy = dyn_cast(type)) + return vectorTy.getElementType().getIntOrFloatBitWidth(); + return type.getIntOrFloatBitWidth(); +} + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Scalar pointer operations are translated directly to LLVM. + addDynamicallyLegalOp( + [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); + addDynamicallyLegalOp([](triton::IntToPtrOp op) { + return op.getSrc().getType().isInteger(); + }); + addDynamicallyLegalOp( + [](triton::AddPtrOp op) { return isa(op.getType()); }); + } +}; + +struct MakeRangeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int32_t start = static_cast(op.getStart()); + int32_t end = static_cast(op.getEnd()); + assert(end >= start); + + llvm::SmallVector values; + values.reserve(end - start); + for (int32_t v = start; v < end; ++v) { + values.push_back(v); + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.create( + op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct SplatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value val = op.getSrc(); + // Cast pointer + if (isa(val.getType())) + val = rewriter.create(loc, rewriter.getI64Type(), val) + .getResult(); + Type resType = getTypeConverter()->convertType(op.getType()); + auto cast = rewriter.create(loc, resType, val); + + rewriter.replaceOp(op, cast); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + unsigned offsetBitWidth = getElemBitWidth(offset.getType()); + unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); + // Scalar case is not expected. + assert(isa(offset.getType())); + assert(isa(ptr.getType())); + VectorType offsetTy = cast(offset.getType()); + VectorType ptrTy = cast(ptr.getType()); + // Build scale vector. i1 elements take 1 byte. + Value scale = rewriter.create( + loc, offsetTy, + SplatElementsAttr::get( + offsetTy, rewriter.getIntegerAttr(offsetTy.getElementType(), + (elemBitWidth + 7) / 8))); + offset = rewriter.create(loc, offset, scale); + if (offsetTy.getElementTypeBitWidth() < ptrTy.getElementTypeBitWidth()) + offset = rewriter.create(loc, ptr.getType(), offset); + rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { + using ConvertPtrOpsBase::ConvertPtrOpsBase; + + ConvertPtrOps() : ConvertPtrOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertPtrOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp new file mode 100644 index 000000000000..6f3f8112ca43 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -0,0 +1,318 @@ +#include "ReduceScanCommon.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTREDUCTIONOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ReductionConversionTarget : public ConversionTarget { +public: + explicit ReductionConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct ReduceOpConversion + : public ReduceScanOpConversionBase { + ReduceOpConversion(bool useReductionOp, bool useMultiDimReductionOp, + const TypeConverter &typeConverter, MLIRContext *context) + : ReduceScanOpConversionBase(typeConverter, context) { + + this->useReductionOp = useReductionOp; + this->useMultiDimReductionOp = useMultiDimReductionOp; + } + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // More simple cases with a single input and a single combine operation + // can be mapped to a vector::MultiDimReductionOp. The resulting code + // depends on a quality of LLVM backend and is not always perfect though. + if (succeeded(mapToReductionOp(op, rewriter, useReductionOp, + useMultiDimReductionOp))) + return success(); + + return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); + } + + SmallVector + lower1DInput(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + SmallVector range(vecSize); + std::iota(range.begin(), range.end(), 0); + + SmallVector dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) { + SmallVector shuffleIndices = range; + for (int64_t i = 0; i < stride; ++i) { + std::swap(shuffleIndices[i], shuffleIndices[i + stride]); + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + res = accumulate(shuffledInput, res, combineOp, rewriter); + } + + // The results are in the first element of each produced vector. + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, res[i], zero); + } + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector res; + for (int64_t idx = 0; idx < shape[0]; ++idx) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + res = accumulate(subInputs, res, combineOp, rewriter); + } + return res; + } + + LogicalResult mapToReductionOp(triton::ReduceOp op, + ConversionPatternRewriter &rewriter, + bool useReductionOp, + bool useMultiDimReductionOp) const { + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return failure(); + + Value src = rewriter.getRemappedValue(op.getOperand(0)); + VectorType srcTy = cast(src.getType()); + + Block *block = op.getBody(); + if (block->getNumArguments() != 2) + return failure(); + Value accArg = block->getArgument(0); + Value itArg = block->getArgument(1); + + auto &blockOps = block->getOperations(); + if (blockOps.size() != 2) + return failure(); + + Operation &retOp = blockOps.back(); + if (!isa(retOp) || retOp.getNumOperands() != 1) + return failure(); + + Value retVal = retOp.getOperand(0); + Operation *defOp = retVal.getDefiningOp(); + if (!defOp || defOp->getNumOperands() != 2) + return failure(); + + Value lhs = defOp->getOperand(0); + Value rhs = defOp->getOperand(1); + if ((lhs != itArg || rhs != accArg) && (lhs != accArg || rhs != itArg)) + return failure(); + + vector::CombiningKind reductionKind; + if (failed(detectReductionKind(defOp, reductionKind))) + return failure(); + + Type resTy = getTypeConverter()->convertType(op.getType(0)); + Value acc = buildInitValue(op.getLoc(), resTy, reductionKind, rewriter); + int64_t axis = op.getAxis(); + + if (useReductionOp && srcTy.getShape().size() == 1) { + rewriter.replaceOpWithNewOp(op, resTy, reductionKind, + src, acc); + return success(); + } else if (useMultiDimReductionOp) { + rewriter.replaceOpWithNewOp( + op, resTy, reductionKind, src, acc, axis); + return success(); + } + + return failure(); + } + + LogicalResult detectReductionKind(Operation *op, + vector::CombiningKind &out) const { + if (isa(op)) + out = vector::CombiningKind::ADD; + else if (isa(op)) + out = vector::CombiningKind::MUL; + else if (isa(op)) + out = vector::CombiningKind::MINSI; + else if (isa(op)) + out = vector::CombiningKind::MINUI; + else if (isa(op)) + out = vector::CombiningKind::MINIMUMF; + else if (isa(op)) + out = vector::CombiningKind::MINNUMF; + else if (isa(op)) + out = vector::CombiningKind::MAXSI; + else if (isa(op)) + out = vector::CombiningKind::MAXUI; + else if (isa(op)) + out = vector::CombiningKind::MAXIMUMF; + else if (isa(op)) + out = vector::CombiningKind::MAXNUMF; + else if (isa(op)) + out = vector::CombiningKind::AND; + else if (isa(op)) + out = vector::CombiningKind::OR; + else if (isa(op)) + out = vector::CombiningKind::XOR; + else + return failure(); + return success(); + } + + Value buildInitValue(Location loc, Type resTy, vector::CombiningKind kind, + ConversionPatternRewriter &rewriter) const { + VectorType vecTy = dyn_cast(resTy); + Type elemTy = vecTy ? vecTy.getElementType() : resTy; + + TypedAttr initVal; + if (kind == vector::CombiningKind::ADD || + kind == vector::CombiningKind::OR || + kind == vector::CombiningKind::XOR || + kind == vector::CombiningKind::MAXUI) + initVal = rewriter.getZeroAttr(elemTy); + else if (kind == vector::CombiningKind::MUL) + initVal = rewriter.getOneAttr(elemTy); + else if (kind == vector::CombiningKind::AND || + kind == vector::CombiningKind::MINUI) + initVal = rewriter.getIntegerAttr(elemTy, -1); + else if (kind == vector::CombiningKind::MAXSI) + initVal = rewriter.getIntegerAttr( + elemTy, + static_cast(-(1UL << (elemTy.getIntOrFloatBitWidth() - 1)))); + else if (kind == vector::CombiningKind::MINSI) + initVal = rewriter.getIntegerAttr( + elemTy, static_cast( + (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1)); + else if (kind == vector::CombiningKind::MINIMUMF || + kind == vector::CombiningKind::MAXIMUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(std::numeric_limits::quiet_NaN()); + else + llvm_unreachable("Unsupported type for acc init value."); + } + + else if (kind == vector::CombiningKind::MINNUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(std::numeric_limits::infinity()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(std::numeric_limits::infinity()); + else + llvm_unreachable("Unsupported type for acc init value."); + } else if (kind == vector::CombiningKind::MAXNUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(-std::numeric_limits::infinity()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(-std::numeric_limits::infinity()); + else + llvm_unreachable("Unsupported type for acc init value."); + } + + if (vecTy) + initVal = SplatElementsAttr::get(vecTy, initVal); + + return rewriter.create(loc, resTy, initVal); + } + +private: + bool useMultiDimReductionOp; + bool useReductionOp; +}; + +struct ConvertReductionOp + : public triton::cpu::impl::ConvertReductionOpBase { + ConvertReductionOp() = default; + + ConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp) { + this->useReductionOp = useReductionOp; + this->useMultiDimReductionOp = useMultiDimReductionOp; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ReductionConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(useReductionOp, useMultiDimReductionOp, + typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertReductionOp() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp) { + return std::make_unique(useReductionOp, + useMultiDimReductionOp); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp new file mode 100644 index 000000000000..fef15b046621 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp @@ -0,0 +1,156 @@ +#include "ReduceScanCommon.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTSCANOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ScanConversionTarget : public ConversionTarget { +public: + explicit ScanConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct ScanOpConversion + : public ReduceScanOpConversionBase { + using ReduceScanOpConversionBase::ReduceScanOpConversionBase; + + SmallVector + lower1DInput(ValueRange inputs, ScanOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + Type maskTy = VectorType::get(vecSize, rewriter.getI1Type()); + + SmallVector dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = 1; stride < vecSize; stride *= 2) { + SmallVector shuffleIndices(vecSize, 0); + int64_t start = reverse ? vecSize - 1 - stride : stride; + int64_t end = reverse ? -1 : vecSize; + int64_t step = reverse ? -1 : 1; + for (int64_t i = start; i != end; i += step) { + shuffleIndices[i] = i - step * stride; + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + auto newRes = accumulate(res, shuffledInput, combineOp, rewriter); + + // Number of already computed elements is equal to the current + // stride. Mask them out using a constant mask. + SmallVector maskVals(vecSize, true); + if (reverse) { + std::fill(maskVals.rbegin(), maskVals.rbegin() + stride, false); + } else { + std::fill(maskVals.begin(), maskVals.begin() + stride, false); + } + Value mask = rewriter.create( + loc, maskTy, rewriter.getBoolVectorAttr(maskVals)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = vector::selectPassthru(rewriter, mask, newRes[i], res[i]); + } + } + + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ScanOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector resTypes; + for (const auto &resTy : op.getResultTypes()) { + resTypes.push_back(VectorType::get( + shape, cast(resTy).getElementType())); + } + SmallVector res = makeEmptyResults(loc, resTypes, rewriter); + SmallVector acc; + int64_t start = reverse ? shape[0] - 1 : 0; + int64_t end = reverse ? -1 : shape[0]; + int64_t step = reverse ? -1 : 1; + for (int64_t idx = start; idx != end; idx += step) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + acc = accumulate(subInputs, acc, combineOp, rewriter); + + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, acc[i], res[i], idx); + } + } + return res; + } +}; + +struct ConvertScanOp : public triton::impl::ConvertScanOpBase { + using ConvertScanOpBase::ConvertScanOpBase; + + ConvertScanOp() : ConvertScanOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ScanConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertScanOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h new file mode 100644 index 000000000000..aaac6a27d5e6 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h @@ -0,0 +1,37 @@ +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +// Generic pattern to rewrite operation by converting types +// for operation operands and results using provided type +// converter. +template +struct OpTypeConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + OperationState newState(op.getLoc(), ResOpT::getOperationName()); + // Convert operands. + for (auto operand : op->getOperands()) { + Value newOperand = rewriter.getRemappedValue(operand); + newState.operands.push_back(newOperand); + } + // Convert result types. + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newState.types))) { + return failure(); + } + newState.attributes = op->getAttrs(); + + auto newOp = rewriter.create(newState); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h new file mode 100644 index 000000000000..2a00f087125b --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h @@ -0,0 +1,244 @@ +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +namespace cpu { + +// Base class for converting scans and reductions. +// +// It provides accumulation function that clones operations from the +// original combine region and applies them on provided vectors. +// Also, it handles multi-diumensional cases reducing them to two +// possible options: lowering for a 1-D vector inputs and lowering +// the operation over the leading dimension. +// +// Specialized pattern should implement lower1DInput to handle +// trailing dimension case (commonly through shuffles + accumulate) +// and lowerLeadingDimension to handle the leading dimension case +// through accumulation of sub-vectors. +template +struct ReduceScanOpConversionBase : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + virtual SmallVector + lower1DInput(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + virtual SmallVector + lowerLeadingDimension(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto rank = cast(op.getOperand(0).getType()).getRank(); + if (op.getAxis() == (rank - 1)) + return lowerTrailingDimension(op, rewriter); + + return lowerNonTrailingDimension(op, rewriter); + } + + // To handle the trailing dimension case, we extract all input vectors + // and process them through lower1DInput, then build the resulting + // vector using inserts. + LogicalResult + lowerTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + SmallVector inputTys(inputs.size()); + std::transform(inputs.begin(), inputs.end(), inputTys.begin(), + [](auto val) { return cast(val.getType()); }); + + // 1-D input case. + if (inputTys.front().getRank() == 1) { + auto res = lower1DInput(inputs, op, rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto shape = inputTys[0].getShape(); + int64_t numElems = inputTys[0].getNumElements(); + auto strides = computeStrides(shape); + // Remove the last stride to produce sub-vector indices. + strides.pop_back(); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + + auto resElems = lower1DInput(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, resElems[i], res[i], + indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // In this case we either call lowerLeadingDimension to process the input + // or extract sub-vectors, call lowerLeadingDimension, and then reconstruct + // the result. + LogicalResult + lowerNonTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + uint32_t axis = op.getAxis(); + if (axis == 0) { + rewriter.replaceOp(op, lowerLeadingDimension(inputs, op, rewriter)); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto vecTy = cast(inputs[0].getType()); + auto shape = vecTy.getShape(); + auto strides = computeStrides(shape); + // Remove trailing elems to build indices of required rank. + strides.erase(strides.begin() + axis, strides.end()); + int64_t numElems = vecTy.getNumElements(); + int64_t step = strides.back(); + for (int64_t idx = 0; idx < numElems; idx += step) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + auto resVecs = lowerLeadingDimension(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = + rewriter.create(loc, resVecs[i], res[i], indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // Accumulate inputs and existing accumulators into a new accumaltors + // applying operations from the combine region. + SmallVector accumulate(ValueRange inputs, ValueRange acc, + Region &combineOp, + ConversionPatternRewriter &rewriter) const { + if (acc.empty()) + return inputs; + + auto shape = cast(inputs[0].getType()).getShape(); + auto &block = combineOp.getBlocks().front(); + IRMapping map; + // Map block arguments to the current inputs and accumulators. + for (unsigned i = 0; i < acc.size(); ++i) { + map.map(block.getArgument(i), acc[i]); + map.map(block.getArgument(acc.size() + i), inputs[i]); + } + for (auto &op : block.getOperations()) { + // Returned values are a new accumulator. + if (isa(op)) { + SmallVector res; + for (auto operand : op.getOperands()) { + res.push_back(map.lookup(operand)); + } + return res; + } + + // Clone operation mapping its inputs and building vector + // result types using the input shape. + OperationState newState(op.getLoc(), op.getName()); + for (auto operand : op.getOperands()) { + newState.operands.push_back( + lookupMappedValue(map, operand, shape, rewriter)); + } + for (auto ty : op.getResultTypes()) { + newState.types.push_back(VectorType::get(shape, ty)); + } + newState.attributes = op.getAttrs(); + auto newOp = rewriter.create(newState); + + // Add new values to the map. + for (auto [oldVal, newVal] : + llvm::zip(op.getResults(), newOp->getResults())) { + map.map(oldVal, newVal); + } + } + llvm_unreachable("No return op found in scan/reduce region"); + } + + Value lookupMappedValue(IRMapping &localMap, Value val, + ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + + Value res = localMap.lookupOrNull(val); + if (!res) { + // If value is not found then it's an invariant defined in the outer + // region. We check if it has been already translated and add a splat + // operation if it hasn't. + res = invariantsMap.lookupOrNull(val); + if (!res) { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfterValue(val); + res = rewriter.create( + val.getLoc(), VectorType::get(shape, val.getType()), val); + invariantsMap.map(val, res); + rewriter.restoreInsertionPoint(ip); + } + } + return res; + } + + SmallVector + makeEmptyResults(Location loc, TypeRange resTypes, + ConversionPatternRewriter &rewriter) const { + // Initialize results to zero values. + SmallVector res; + for (auto ty : resTypes) { + res.push_back(rewriter.create( + loc, rewriter.getZeroAttr(getTypeConverter()->convertType(ty)))); + } + return res; + } + + // Dummy vectors are required for shuffles that cannot work on a single + // vector. + SmallVector + createShuffleDummies(Location loc, ValueRange inputs, + ConversionPatternRewriter &rewriter) const { + SmallVector shuffleDummies; + SmallVector dummyShape({1}); + for (auto val : inputs) { + auto ty = cast(val.getType()); + shuffleDummies.push_back(rewriter.create( + loc, + rewriter.getZeroAttr(ty.cloneWith(dummyShape, ty.getElementType())))); + } + return shuffleDummies; + } + +private: + mutable IRMapping invariantsMap; +}; + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp new file mode 100644 index 000000000000..f194d3e195dd --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp @@ -0,0 +1,277 @@ +#include "cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h" + +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; + +#include "cpu/include/ScalarizePass/ScalarizeInterface.cpp.inc" + +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +Value mlir::triton::cpu::computeScalarValue(Operation *scalarizationOp, + Value vals, + ArrayRef indices, + PatternRewriter &rewriter) { + auto scalarized = cast(scalarizationOp); + return scalarized.computeScalarValue(vals, indices, rewriter); +} + +Value mlir::triton::cpu::computeScalarValue(Operation *scalarizationOp, + Value vals, ValueRange indices, + PatternRewriter &rewriter) { + auto scalarized = cast(scalarizationOp); + return scalarized.computeScalarValueForLoop(vals, indices, rewriter); +} + +bool mlir::triton::cpu::canComputeScalarValue(Value vals) { + auto def = vals.getDefiningOp(); + if (!def) + return false; + auto scalarized = dyn_cast(def); + if (!scalarized) + return false; + return scalarized.canComputeScalarValue(vals); +} + +namespace { + +namespace detail { + +template struct value_type_trait { + using type = typename T::value_type; +}; + +template <> struct value_type_trait { + using type = Value; +}; + +template +T createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + llvm_unreachable("Default implementation should be overwritten."); +} + +template <> +int64_t createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + return 0; +} + +template <> +Value createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + return rewriter.create(loc, 0); +} + +} // namespace detail + +// Using ScalariztionFunctor class to partially specialize helper method +template struct ScalariztionFunctor { + template + static Value getScalarValue(OpTy operation, Value vals, T indices, + PatternRewriter &rewriter) { + auto def = vals.getDefiningOp(); + OperationState newState(def->getLoc(), def->getName()); + for (auto operand : def->getOperands()) { + newState.operands.push_back(computeScalarValue( + operand.getDefiningOp(), operand, indices, rewriter)); + } + assert(def->getResults().size() == 1 && + "[Unsupported] Opearation have multiple outputs."); + newState.types.push_back( + cast(def->getResultTypes()[0]).getElementType()); + newState.attributes = def->getAttrs(); + return rewriter.create(newState)->getResult(0); + } +}; + +/// External model implementation of ScalarizeInterface for TritonOps. An +/// external model implementation is used for now till the use of +/// `ScalarizeInterface` is on-par with the current ScalarizeUsingForOp. This +/// allows to register this Interface for all required ops depending on it's +/// type. +template +struct TritonOpScalarizeInterface + : public ScalarizeInterface::ExternalModel, + OpTy> { + bool canComputeScalarValue(Operation *op, Value vals) const { + for (auto operand : op->getOperands()) { + if (isa(operand)) { + return false; + } + auto scalarized = dyn_cast(operand.getDefiningOp()); + if (!scalarized) { + return false; + } + if (!scalarized.canComputeScalarValue(operand)) { + return false; + } + } + return true; + } + + Value computeScalarValue(Operation *op, Value vals, ArrayRef indices, + PatternRewriter &rewriter) const { + OpTy def = vals.getDefiningOp(); + return ScalariztionFunctor().getScalarValue(def, vals, indices, + rewriter); + } + + Value computeScalarValueForLoop(Operation *op, Value vals, ValueRange indices, + PatternRewriter &rewriter) const { + OpTy def = vals.getDefiningOp(); + return ScalariztionFunctor().getScalarValue(def, vals, indices, + rewriter); + } +}; +template <> struct ScalariztionFunctor { + template + Value getScalarValue(SplatOp def, Value vals, T indices, + PatternRewriter &rewriter) { + + return def.getSrc(); + } +}; + +template <> +bool TritonOpScalarizeInterface::canComputeScalarValue( + Operation *op, Value vals) const { + return true; +} + +template <> +struct TritonOpScalarizeInterface + : public ScalarizeInterface::ExternalModel< + TritonOpScalarizeInterface, MakeRangeOp> { + + bool canComputeScalarValue(Operation *op, Value vals) const { return true; } + + Value computeScalarValue(Operation *op, Value vals, ArrayRef indices, + PatternRewriter &rewriter) const { + MakeRangeOp def = vals.getDefiningOp(); + int32_t start = static_cast(def.getStart()); + assert(indices.size() == 1); + Type elemTy = cast(def.getType()).getElementType(); + return rewriter.create( + def.getLoc(), elemTy, + rewriter.getIntegerAttr(elemTy, start + indices[0])); + } + + Value computeScalarValueForLoop(Operation *op, Value vals, ValueRange indices, + PatternRewriter &rewriter) const { + MakeRangeOp def = vals.getDefiningOp(); + assert(indices.size() == 1); + int32_t start = static_cast(def.getStart()); + Type elemTy = cast(def.getType()).getElementType(); + Value startVal = rewriter.create( + def.getLoc(), elemTy, rewriter.getIntegerAttr(elemTy, start)); + Value index = indices[0]; + if (!elemTy.isIndex()) + index = + rewriter.create(def.getLoc(), elemTy, index); + return rewriter.create(def.getLoc(), elemTy, startVal, + index); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(BroadcastOp operation, Value vals, T indices, + PatternRewriter &rewriter) { + BroadcastOp def = operation; + using UnderlyingIndicesType = typename detail::value_type_trait::type; + // Find broadcasted dimensions and replace indices for those + // dimensions with 0 (broadcasted dimension has always size 1). + SmallVector newIndices; + auto sourceTy = cast(def.getSrc().getType()); + auto targetTy = cast(def.getType()); + assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); + for (int64_t i = 0; i < sourceTy.getRank(); ++i) { + if (sourceTy.getShape()[i] != targetTy.getShape()[i]) + newIndices.push_back(detail::createZeroIndex( + std::move(def.getLoc()), rewriter)); + else + newIndices.push_back(indices[i]); + } + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(ExpandDimsOp def, Value vals, T indices, + PatternRewriter &rewriter) { + using UnderlyingIndicesType = typename detail::value_type_trait::type; + // Remove index at expanded dimension. + SmallVector newIndices(indices); + newIndices.erase(newIndices.begin() + def.getAxis()); + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(arith::ConstantOp def, Value vals, T indices, + PatternRewriter &rewriter) { + auto denseVal = cast(def.getValue()); + assert(denseVal.isSplat()); + auto scalarAttr = denseVal.getSplatValue(); + Value res = rewriter.create( + def.getLoc(), scalarAttr.getType(), scalarAttr); + return res; + } +}; + +template <> +bool TritonOpScalarizeInterface::canComputeScalarValue( + Operation *op, Value vals) const { + auto cst = static_cast(op); + if (auto denseVal = dyn_cast(cst.getValue())) { + return denseVal.isSplat(); + } + return false; +} + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(TransOp def, Value vals, T indices, + PatternRewriter &rewriter) { + + using UnderlyingIndicesType = typename detail::value_type_trait::type; + + // Permute indices. + SmallVector newIndices; + auto order = def.getOrder(); + assert(indices.size() == order.size() && "Mismatched rank"); + for (auto idx : order) + newIndices.push_back(indices[idx]); + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +} // namespace + +template static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface>(*ctx); +} + +template static void registerAll(MLIRContext *ctx) { + (registerOne(ctx), ...); +} + +void mlir::triton::cpu::registerTritonOpScalarizeExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, TritonDialect *dialect) { + registerAll(ctx); + }); + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + registerAll(ctx); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp new file mode 100644 index 000000000000..0e8102831e1e --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp @@ -0,0 +1,407 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_SCALARIZEUSINGFOROP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct ScalarizeOpConversion : public OpRewritePattern { + + ScalarizeOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, + MLIRContext *context, bool skipGatherScatter) + : OpRewritePattern(context), axisAnalysis(axisInfoAnalysis) { + this->skipGatherScatter = skipGatherScatter; + } + + Value createAlloca(Location loc, MemRefType ty, Operation *before, + PatternRewriter &rewriter) const { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(before); + return rewriter.create( + loc, ty, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + } + + // If tensor is not null and its element cannot be recomputed in a scalar + // loop, then store it to a temporary buffer. + Value storeIfNonScalarizable(Location loc, Value vals, Value zeroIdx, + Operation *allocaPoint, + PatternRewriter &rewriter) const { + // To skip optional values and scalarizable value, that can be computed + // inside loop + if (!vals || canComputeScalarValue(vals)) + return nullptr; + + auto tensor = vals; + auto tensorTy = cast(vals.getType()); + auto elemTy = tensorTy.getElementType(); + if (isa(elemTy)) { + elemTy = IntegerType::get(elemTy.getContext(), 64); + } + // Memref of i1 assumes one element per byte when we load/store element, + // but vector store (through transfer write) would write 1 bit per element. + if (elemTy.isInteger(1)) { + elemTy = rewriter.getI8Type(); + tensor = rewriter.create( + loc, + RankedTensorType::get(tensorTy.getShape(), elemTy, + tensorTy.getEncoding()), + tensor); + } + auto memRefTy = MemRefType::get(tensorTy.getShape(), elemTy); + Value memRef = createAlloca(vals.getLoc(), memRefTy, allocaPoint, rewriter); + SmallVector indices(tensorTy.getRank(), zeroIdx); + rewriter.create(vals.getLoc(), tensor, memRef); + return memRef; + } + + // Load scalar element from a temporary buffer or recompute it if the + // buffer doesn't exist. + Value loadOrComputeScalarValue(Value vals, Value tmpVals, ValueRange indices, + PatternRewriter &rewriter) const { + // Allow null value for easier handling of optional arguments. + if (!vals) + return nullptr; + + // If nothing loaded, value should be scalar computable + if (!tmpVals) { + if (!canComputeScalarValue(vals)) { + llvm::errs() + << "Passed value was not loaded and can't be computed as scalar: " + << vals << "\n"; + llvm::report_fatal_error("Cannot proceed such value"); + return nullptr; + } + return computeScalarValue(vals.getDefiningOp(), vals, indices, rewriter); + } + + // Load value from a temp buffer if any. + Value val = + rewriter.create(vals.getLoc(), tmpVals, indices); + // If we load a pointer then additional cast is needed because tensor of + // pointers is transformed into a vector of integers. + auto elemTy = dyn_cast(vals.getType()).getElementType(); + if (isa(elemTy)) + val = rewriter.create(vals.getLoc(), elemTy, val); + // We need to transform loaded i8 back to i1. + else if (elemTy.isInteger(1)) + val = rewriter.create(val.getLoc(), rewriter.getI1Type(), + val); + return val; + } + + // This is core methods that generates SCF::For + // We are checking arguments and results of operation + // to scalarize them if possible and load/store if they are dynamical + LogicalResult scalarizeWithLoop(OpTy scalarizeOp, + PatternRewriter &rewriter) const { + llvm_unreachable("nope"); + return failure(); + } + + // Method that describes how to check arguments and results of operation + // for scalarization + bool shouldScalarizeOp(OpTy scalarizeOp) const { + llvm_unreachable("nope"); + return false; + } + + // code for Memory Ops, as requires getPtr method + bool shouldScalarizeOpGeneric(OpTy scalarizeOp) const { + + auto ptr = scalarizeOp.getPtr(); + if (triton::isTensorPointerType(ptr.getType())) { + return false; + } + + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (isContiguousRowMajorAccess(axisInfo, scalarizeOp)) { + return false; + } + + auto [basePtr, offset] = getMemoryBaseOffset(scalarizeOp); + if (skipGatherScatter && basePtr && offset) { + return false; + } + + // Scalar memory ops and boundary checks are not expected. + if (!scalarizeOp.getBoundaryCheck().empty()) { + return false; + } + + return ScalarizeOpConversion::shouldScalarizeOp(scalarizeOp); + } + + LogicalResult matchAndRewrite(OpTy scalarOp, + PatternRewriter &rewriter) const override { + + // We want to avoid a code explosion when scalarize loads of big vectors, + // so try to build a scalar loop. + if (shouldScalarizeOpGeneric(scalarOp) && + succeeded(scalarizeWithLoop(scalarOp, rewriter))) + return success(); + return failure(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysis; + bool skipGatherScatter; +}; + +template <> +LogicalResult ScalarizeOpConversion::scalarizeWithLoop( + triton::StoreOp storeOp, PatternRewriter &rewriter) const { + auto loc = storeOp.getLoc(); + + auto ptrs = storeOp.getPtr(); + auto mask = storeOp.getMask(); + auto vals = storeOp.getValue(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + auto tensorTy = cast(vals.getType()); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // Alloca is inserted similar to the load case. + Operation *allocaPoint = storeOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Store a tensor of pointers, mask, and values into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + storeIfNonScalarizable(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + storeIfNonScalarizable(loc, mask, zeroIdx, allocaPoint, rewriter); + Value tmpVals = + storeIfNonScalarizable(loc, vals, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < tensorTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, tensorTy.getShape()[i]); + auto forOp = rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load scalar args. + Value scalarPtr = loadOrComputeScalarValue(ptrs, tmpPtrs, ivs, rewriter); + Value scalarMask = loadOrComputeScalarValue(mask, tmpMask, ivs, rewriter); + Value scalarVal = loadOrComputeScalarValue(vals, tmpVals, ivs, rewriter); + + if (!mask) { + // Regular store case. + auto store_op = rewriter.create(loc, scalarPtr, scalarVal, + cache, evict); + } else { + // Conditional store case + rewriter.create(loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, scalarPtr, scalarVal, cache, evict); + builder.create(loc); + }); + } + + rewriter.eraseOp(storeOp); + return success(); +} + +template <> +bool ScalarizeOpConversion::shouldScalarizeOp( + triton::StoreOp scalarOp) const { + + if (!isa(scalarOp.getValue().getType())) { + return false; + } + + auto tensorTy = cast(scalarOp.getPtr().getType()); + return tensorTy.getNumElements() >= 16; +} + +template <> +LogicalResult ScalarizeOpConversion::scalarizeWithLoop( + triton::LoadOp loadOp, PatternRewriter &rewriter) const { + auto loc = loadOp.getLoc(); + auto tensorTy = cast(loadOp.getType()); + + auto ptrs = loadOp.getPtr(); + auto mask = loadOp.getMask(); + auto other = loadOp.getOther(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // There is alloca_scope operation to control alloca scopes. But its usage + // in combination with nested SCF and multi-dimensional vectors make it + // impossible to lower scopes to LLVM using existing MLIR passes. For now, + // simply allocate temp memory in the function's region. + // TODO: Use alloc for big buffers and revisit alloca scoping. + Operation *allocaPoint = loadOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Allocate temp buffer for the result. Write the other value there if + // we cannot write it in a loop. + auto resMemRefTy = + MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()); + Value resMemRef = createAlloca(loc, resMemRefTy, allocaPoint, rewriter); + bool storeOtherInLoop = static_cast(mask); + if (other && !canComputeScalarValue(other)) { + rewriter.create(loc, other, resMemRef); + storeOtherInLoop = false; + } + + // Store a tensor of pointers and mask into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + storeIfNonScalarizable(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + storeIfNonScalarizable(loc, mask, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < tensorTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, tensorTy.getShape()[i]); + auto forOp = rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load a scalar arguments. + Value scalarPtr = loadOrComputeScalarValue(ptrs, tmpPtrs, ivs, rewriter); + Value scalarMask = loadOrComputeScalarValue(mask, tmpMask, ivs, rewriter); + Value scalarOther; + if (storeOtherInLoop) { + if (other) { + scalarOther = + computeScalarValue(other.getDefiningOp(), other, ivs, rewriter); + } else { + scalarOther = rewriter.create( + loc, tensorTy.getElementType(), + rewriter.getZeroAttr(tensorTy.getElementType())); + } + } + + if (!mask) { + // Regular load case. + Value val = rewriter.create(loc, scalarPtr, cache, evict, + isVolatile); + rewriter.create(loc, val, resMemRef, ivs); + } else { + // Conditional load case + rewriter.create( + loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + Value val = builder.create(loc, scalarPtr, cache, + evict, isVolatile); + builder.create(loc, val, resMemRef, ivs); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + if (storeOtherInLoop) + builder.create(loc, scalarOther, resMemRef, ivs); + builder.create(loc); + }); + } + + // Load vector from the temp storage and return it from alloca scope. + rewriter.setInsertionPointAfter(forOps.front()); + SmallVector indices(tensorTy.getRank(), zeroIdx); + Value res = rewriter.create(loc, tensorTy, resMemRef); + rewriter.replaceOp(loadOp, res); + return success(); +} + +template <> +bool ScalarizeOpConversion::shouldScalarizeOp( + triton::LoadOp scalarOp) const { + if (!isa(scalarOp.getType())) { + return false; + } + auto tensorTy = cast(scalarOp.getType()); + return tensorTy.getNumElements() >= 16; +} + +struct ScalarizeUsingForOpPass + : public triton::cpu::impl::ScalarizeUsingForOpBase< + ScalarizeUsingForOpPass> { + using ScalarizeUsingForOpBase::ScalarizeUsingForOpBase; + + ScalarizeUsingForOpPass() : ScalarizeUsingForOpBase() {} + + ScalarizeUsingForOpPass(bool skipGatherScatter) : ScalarizeUsingForOpBase() { + this->skipGatherScatter = skipGatherScatter; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + RewritePatternSet patterns(context); + patterns.add, + ScalarizeOpConversion>( + axisInfoAnalysis, context, skipGatherScatter); + + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) { + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createScalarizeUsingForOpPass() { + return std::make_unique(); +} + +std::unique_ptr> +createScalarizeUsingForOpPass(bool skipGatherScatter) { + return std::make_unique(skipGatherScatter); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp new file mode 100644 index 000000000000..728d353592bb --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -0,0 +1,50 @@ +#include "TypeConverter.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([this](RankedTensorType tensorTy) -> Type { + Type elemTy = convertType(tensorTy.getElementType()); + if (isa(elemTy)) + elemTy = IntegerType::get(tensorTy.getContext(), 64); + return VectorType::get(tensorTy.getShape(), elemTy); + }); + + addArgumentMaterialization([&](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) -> Value { + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + llvm::errs() << "Inputs: "; + llvm::interleaveComma(inputs, llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Type: " << type << "\n"; + llvm_unreachable("Unexpected argument materizalization"); + }); + + // Converted ops produce vectors instead of tensors. Provide conversion + // here for users. + addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + return builder.create(loc, type, inputs) + .getResult(0); + }); + + // Provide conversion for vector users. + addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + llvm::errs() << "Inputs: "; + llvm::interleaveComma(inputs, llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Type: " << type << "\n"; + llvm_unreachable("Unexpected target materizalization"); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h new file mode 100644 index 000000000000..cb89f0886c60 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H + +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonToTritonCPUTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + TritonToTritonCPUTypeConverter(); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/lib/Xsmm/CMakeLists.txt b/third_party/cpu/lib/Xsmm/CMakeLists.txt new file mode 100644 index 000000000000..62283f5512d4 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/CMakeLists.txt @@ -0,0 +1,33 @@ +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +add_triton_library(TritonCPUXsmm + ConvertTritonToXsmm.cpp + ConvertVectorToXsmm.cpp + VnniUtils.cpp + ValueUtils.cpp + XsmmEnum.cpp + XsmmUtils.cpp + + DEPENDS + TritonCPUXsmmPassIncGen + TritonCPUXsmmAttrDefIncGen + xsmm + + LINK_LIBS PUBLIC + ${extension_libs} + MLIRIR + MLIRPass + MLIRVectorDialect + MLIRMemRefDialect + MLIRFuncDialect + MLIRLLVMDialect + MLIRInferTypeOpInterface + MLIRLinalgUtils + TritonCPUIR + xsmm +) + +target_include_directories(TritonCPUXsmm + PUBLIC + $ +) diff --git a/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp new file mode 100644 index 000000000000..c3b28736ff56 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp @@ -0,0 +1,552 @@ +//===- ConvertTritonToXsmm.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/Passes.h" + +#include "ValueUtils.h" +#include "VnniUtils.h" +#include "XsmmUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" + +#include +#include + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::func; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTTRITONTOXSMM +#define GEN_PASS_DEF_LOOPTOBRGEMMXSMM +#include "cpu/include/Xsmm/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +namespace { + +// Helper from MemoryOpConversion. +// Extract memref out of block pointer. +static Value extractMemRef(PatternRewriter &rewriter, Value ptr, + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis) { + Location loc = ptr.getLoc(); + MLIRContext *ctx = ptr.getContext(); + + auto tensorTy = dyn_cast( + dyn_cast(ptr.getType()).getPointeeType()); + auto elemTy = tensorTy.getElementType(); + auto shapeInfo = shapeAnalysis.getPtrShapeInfo(ptr); + Type memRefTy; + if (shapeInfo && shapeInfo->getRank() > 0) { + auto layout = StridedLayoutAttr::get(ctx, 0, shapeInfo->getStrides()); + memRefTy = MemRefType::get(shapeInfo->getShape(), elemTy, layout); + } else { + SmallVector dynVals(tensorTy.getRank(), ShapedType::kDynamic); + auto layout = StridedLayoutAttr::get(ctx, 0, dynVals); + memRefTy = MemRefType::get(dynVals, elemTy, layout); + } + return rewriter.create(loc, memRefTy, ptr); +} + +static Value getMemrefSource(PatternRewriter &rewriter, Operation *op, + TypedValue operand, + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis) { + Location loc = op->getLoc(); + MLIRContext *ctx = op->getContext(); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + + RankedTensorType tensorTy = operand.getType(); + + if (auto loadOp = dyn_cast_or_null(operand.getDefiningOp())) { + auto ptr = loadOp.getPtr(); + if (triton::isTensorPointerType(ptr.getType())) { + auto memref = extractMemRef(rewriter, ptr, shapeAnalysis); + auto indices = + rewriter.create(loc, ptr).getResults(); + SmallVector strides(tensorTy.getRank(), 1); + + return rewriter.create( + loc, memref, getAsOpFoldResult(indices), + getAsIndexOpFoldResult(ctx, tensorTy.getShape()), + getAsIndexOpFoldResult(ctx, strides)); + } + } + + MemRefType memTy = + MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()); + auto alloca = rewriter.create(loc, memTy); + rewriter.create(loc, operand, alloca); + + return alloca; +} + +// Helper to move accumulation buffer outside of GEMM reduction loop. +// Returns new accumulation buffer or std::nullopt, otherwise. +// +// Rewrites the following pattern: +// %init = ... tensor<...> +// %0 = scf.for ... iter_args(%acc = %init) +// %res = GEMM(%A, %B, %acc) -> tensor<...> +// scf.yield %res +// consumer(%0) +// into: +// %init = ... tensor<...> +// %hoisted = ... memref<...> +// store %init, %hoisted +// %unused = %scf.for ... iter_args(%acc = %init) +// %res = GEMM(%A, %B, %acc) +// scf.yield %acc +// %0 = load(%hoisted) -> tensor<...> +// consumer(%0) +// +// This rewrite should be used as a part of contraction to memref conversion. +static std::optional +hoistAccumulationBuffer(PatternRewriter &rewriter, Operation *op, + TypedValue operand, + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis) { + Location loc = op->getLoc(); + + // Check if there is any loop around the contraction and if the operand + // comes from loop's arguments. + auto forOp = dyn_cast(op->getParentOp()); + BlockArgument blockArg = dyn_cast(operand); + if (!forOp || !blockArg) + return std::nullopt; + OpOperand *loopArg = forOp.getTiedLoopInit(blockArg); + if (!loopArg) + return std::nullopt; + + // The accumulation iter_arg can be safely moved outside the loop only + // for the following chain: iter_arg -> contraction -> yield + // and there are no other users. + Value res = op->getResults()[0]; + if (!operand.hasOneUse() || !res.hasOneUse() || + !isa(*res.getUsers().begin())) + return std::nullopt; + + // Create a buffer outside the loop. + Value accBuf = getMemrefSource( + rewriter, forOp, dyn_cast>(loopArg->get()), + shapeAnalysis); + + // For simplicity, feed the iter_arg directly into loop yield terminator. + // Canonicalizer will folded them away later. + rewriter.replaceAllUsesWith(res, operand); + + // Replace the corresponding loop result with the latest value read from the + // accumulation buffer. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(forOp); + + auto loadOp = + rewriter.create(loc, operand.getType(), accBuf); + rewriter.replaceAllUsesWith(forOp.getTiedLoopResult(blockArg), + loadOp.getResult()); + + return accBuf; +} + +struct DotToXsmm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + DotToXsmm(MLIRContext *ctx, + ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis) + : OpRewritePattern(ctx), shapeAnalysis(shapeInfoAnalysis) { + } + + LogicalResult matchAndRewrite(triton::DotOp dotOp, + PatternRewriter &rewriter) const override { + Location loc = dotOp.getLoc(); + MLIRContext *ctx = dotOp.getContext(); + + // Dot op computes standard (batch) GEMM. + SmallVector indexingMaps; + TypedValue res = dotOp.getD(); + uint32_t rank = res.getType().getRank(); + if (rank == 2) { + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx)); + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx)); + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx)); + } else if (rank == 3) { + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx)); + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx)); + indexingMaps.push_back( + AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx)); + } + if (indexingMaps.size() == 0) + return rewriter.notifyMatchFailure(dotOp, "unsupported indexing maps"); + + TypedValue lhs = dotOp.getA(); + TypedValue rhs = dotOp.getB(); + TypedValue acc = dotOp.getC(); + + SmallVector flags; + Value lhsBuf = getMemrefSource(rewriter, dotOp, lhs, shapeAnalysis); + Value rhsBuf = getMemrefSource(rewriter, dotOp, rhs, shapeAnalysis); + std::optional hoistedAcc = + hoistAccumulationBuffer(rewriter, dotOp, acc, shapeAnalysis); + Value accBuf = hoistedAcc + ? *hoistedAcc + : getMemrefSource(rewriter, dotOp, acc, shapeAnalysis); + SmallVector inputs{lhsBuf, rhsBuf, accBuf}; + SmallVector outputs{nullptr}; + + // Rewrite matmul into a BRGEMM. + // This allows for additional reduction dimension tiling driven + // by a microkernel. + // + // TODO: Expand heuristics about brgemm rewrite profitability. + // TODO: Allow for batch dimension. + int64_t kDim = lhs.getType().getShape().back(); + auto accShape = acc.getType().getShape(); + constexpr int64_t kTile = 32; + int64_t numTiles = kDim / kTile; + if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) { + // Split reduction dimension into tiles. + // The number of tiles represents the batch dimension. + inputs[0] = rewriter.create( + loc, SmallVector{accShape[0], numTiles, kTile}, inputs[0], + SmallVector{{0}, {1, 2}}); + inputs[1] = rewriter.create( + loc, SmallVector{numTiles, kTile, accShape[1]}, inputs[1], + SmallVector{{0, 1}, {2}}); + + // Update maps with BRGEMM indexing. + auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx); + auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx); + indexingMaps = SmallVector{mapA, mapB, mapC}; + } + + // TODO: Perform this check much earlier before any rewrites. + auto brgemmInfo = xsmm::utils::isMappableToBrgemm(rewriter, dotOp, inputs, + outputs, indexingMaps); + if (failed(brgemmInfo)) { + assert(false); // FIXME: getMemrefSource above already modified IR... + // return rewriter.notifyMatchFailure(dotOp, "not mappable to XSMM"); + } + + auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, dotOp, inputs, + indexingMaps, flags); + + if (hoistedAcc) { + // Hoisting already updated all uses correctly. + // Only remove the original contraction. + rewriter.eraseOp(dotOp); + } else { + // Load back the result to bring it back to tensor semantics. + auto loadOp = + rewriter.create(loc, res.getType(), accBuf); + rewriter.replaceOp(dotOp, loadOp); + } + + return success(); + } + +private: + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; +}; + +// Collapse whole reduction loop with a GEMM into equivalent BRGEMM operation. +// Rewrites the following pattern: +// %0 = tt.make_tensor_ptr %base_ptr0 : tensor +// %1 = tt.make_tensor_ptr %base_ptr1 : tensor +// %res:3 = scf.for %arg3 = %lb to %ub step %step +// iter_args(%acc = %init_val, %ptr_A = %0, %ptr_B = %1) +// %A = tt.load %ptr_A +// %B = tt.load %ptr_B +// %dot = tt.dot %A, %B, %acc +// %ptr_A_next = tt.advance %ptr_A, [0, %stepK] +// %ptr_B_next = tt.advance %ptr_B, [%stepK, %0] +// scf.yield %dot, %ptr_A_next, %ptr_B_next +// into: +// %A = tt.make_tensor_ptr %base_ptr0 : tensor +// %B = tt.make_tensor_ptr %base_ptr1 : tensor +// %res0 = BRGEMM %A, %B, %init_val +// %res1 = tt.advance %A, [0, ((%ub - %lb) / %step) * %stepK] +// %res2 = tt.advance %B, [((%ub - %lb) / %step) * %stepK, 0] +struct DotReductionLoopToBrgemm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + DotReductionLoopToBrgemm(MLIRContext *context, + ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, + PatternBenefit benefit = 10) + : OpRewritePattern(context, benefit), + shapeAnalysis(shapeInfoAnalysis) {} + + LogicalResult matchAndRewrite(triton::DotOp dotOp, + PatternRewriter &rewriter) const override { + Location loc = dotOp.getLoc(); + MLIRContext *ctx = dotOp.getContext(); + + // Check if there is any loop around the contraction and if the accumulation + // value comes from loop's arguments. + TypedValue acc = dotOp.getC(); + if (acc.getType().getRank() != 2) + return rewriter.notifyMatchFailure(dotOp, "expects 2D GEMM"); + + auto forOp = dyn_cast(dotOp->getParentOp()); + BlockArgument accBbArg = dyn_cast(acc); + if (!forOp || !accBbArg) + return rewriter.notifyMatchFailure(dotOp, "not a reduction loop"); + OpOperand *accArg = forOp.getTiedLoopInit(accBbArg); + if (!accArg) + return rewriter.notifyMatchFailure( + dotOp, "expects iter_args accumulation value"); + // TODO: Relax this check. It is needed to collapse whole loop but + // alternatively only BRGEMM could be pulled out. + if (forOp.getNumRegionIterArgs() != 3) + return rewriter.notifyMatchFailure(dotOp, "invalid number of iter_args"); + + // Assume that the loop's range and all pointer advances are known + // statically. Thus, the induction variable should be unused. + Value loopIv = forOp.getInductionVar(); + if (!loopIv.use_empty()) + return rewriter.notifyMatchFailure(dotOp, + "expects unused induction variable"); + + // The subgraph should a simple reduction loop containing a GEMM operation. + // Validate presence of the following chain: + // iter_arg -> contraction -> yield + // and that there are no other users. + TypedValue res = dotOp.getD(); + if (!acc.hasOneUse() || !res.hasOneUse() || + !isa(*res.getUsers().begin())) + return rewriter.notifyMatchFailure(dotOp, "GEMM subgraph does not match"); + + auto loadMatA = dotOp.getA().getDefiningOp(); + auto loadMatB = dotOp.getB().getDefiningOp(); + if (!loadMatA || !loadMatB) + return rewriter.notifyMatchFailure(dotOp, "expect GEMM input loads"); + if (!loadMatA->hasOneUse() || !loadMatB->hasOneUse()) + return rewriter.notifyMatchFailure(dotOp, + "Input loads subgraph does not match"); + + // Constrain input pointers to the following subgraph: + // iter_arg -> (load, increment) -> yield + BlockArgument lhsBbArg = dyn_cast(loadMatA.getPtr()); + BlockArgument rhsBbArg = dyn_cast(loadMatB.getPtr()); + if (!lhsBbArg || !rhsBbArg) + return rewriter.notifyMatchFailure(dotOp, "expect block arg pointers"); + OpOperand *lhsArg = forOp.getTiedLoopInit(lhsBbArg); + OpOperand *rhsArg = forOp.getTiedLoopInit(rhsBbArg); + if (!lhsArg || + std::distance(lhsBbArg.use_begin(), lhsBbArg.use_end()) != 2 || + !rhsArg || std::distance(rhsBbArg.use_begin(), rhsBbArg.use_end()) != 2) + return rewriter.notifyMatchFailure(dotOp, "expect iter_args pointers"); + + // Input sources should be block pointers. + // TODO: Account for transposed GEMM operands. + auto lhsBlockPtr = dyn_cast_or_null( + lhsArg->get().getDefiningOp()); + auto rhsBlockPtr = dyn_cast_or_null( + rhsArg->get().getDefiningOp()); + if (!lhsBlockPtr || lhsBlockPtr.getOrder() != ArrayRef{1, 0} || + !rhsBlockPtr || rhsBlockPtr.getOrder() != ArrayRef{1, 0}) + return rewriter.notifyMatchFailure(dotOp, "expected block pointers"); + + // Check for pointer increments and validate their steps. + // Each input is expected to advance only in its reduction dimension. + auto lhsAdvanceOp = forOp.getTiedLoopYieldedValue(lhsBbArg) + ->get() + .getDefiningOp(); + auto rhsAdvanceOp = forOp.getTiedLoopYieldedValue(rhsBbArg) + ->get() + .getDefiningOp(); + if (!lhsAdvanceOp || !rhsAdvanceOp) + return rewriter.notifyMatchFailure(dotOp, "expected ptr advance"); + if (!lhsAdvanceOp->hasOneUse() || !rhsAdvanceOp->hasOneUse()) + return rewriter.notifyMatchFailure( + dotOp, "Ptr increment subgraph does not match"); + + auto resShape = res.getType().getShape(); + auto lhsShape = dotOp.getA().getType().getShape(); + auto lhsPtrOffsets = lhsAdvanceOp.getOffsets(); + auto lhsStepParallel = getConstantIntValue(lhsPtrOffsets[0]); + auto lhsStepReduction = getConstantIntValue(lhsPtrOffsets[1]); + if (!lhsStepParallel || *lhsStepParallel != 0 || !lhsStepReduction || + *lhsStepReduction != lhsShape[1]) + return rewriter.notifyMatchFailure(dotOp, "invalid lhs increments"); + + auto rhsPtrOffsets = rhsAdvanceOp.getOffsets(); + auto rhsStepReduction = getConstantIntValue(rhsPtrOffsets[0]); + auto rhsStepParallel = getConstantIntValue(rhsPtrOffsets[1]); + if (!rhsStepReduction || *rhsStepReduction != *lhsStepReduction || + !rhsStepParallel || *rhsStepParallel != 0) + return rewriter.notifyMatchFailure(dotOp, "invalid rhs increments"); + + // Collapse the loop and create equivalent BRGEMM operation. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + + // TODO: Validate if number of tiles cleanly divides the source buffer. + auto loopRange = rewriter.create(loc, forOp.getUpperBound(), + forOp.getLowerBound()); + Value numTiles = + rewriter.create(loc, loopRange, forOp.getStep()); + numTiles = rewriter.create(loc, rewriter.getIndexType(), + numTiles); + auto kStepCst = + rewriter.create(loc, *lhsStepReduction); + auto fullKDimLength = + rewriter.create(loc, numTiles, kStepCst); + + // Create new mmeref views spanning the whole reduction dimension. + SmallVector strides(2, 1); + auto lhsMemref = extractMemRef(rewriter, lhsBlockPtr, shapeAnalysis); + auto lhsIndices = + rewriter.create(loc, lhsBlockPtr) + .getResults(); + auto lhsBuf = rewriter.create( + loc, lhsMemref, getAsOpFoldResult(lhsIndices), + SmallVector{getAsIndexOpFoldResult(ctx, resShape[0]), + getAsOpFoldResult(fullKDimLength)}, + getAsIndexOpFoldResult(ctx, strides)); + + auto rhsMemref = extractMemRef(rewriter, rhsBlockPtr, shapeAnalysis); + auto rhsIndices = + rewriter.create(loc, rhsBlockPtr) + .getResults(); + auto rhsBuf = rewriter.create( + loc, rhsMemref, getAsOpFoldResult(rhsIndices), + SmallVector{getAsOpFoldResult(fullKDimLength), + getAsIndexOpFoldResult(ctx, resShape[1])}, + getAsIndexOpFoldResult(ctx, strides)); + + Value accBuf = + getMemrefSource(rewriter, forOp, + dyn_cast>( + accArg->get().getDefiningOp()->getResult(0)), + shapeAnalysis); + + // Split reduction dimension into tiles. + // The number of tiles represents the batch dimension. + SmallVector lhsOutSizes{ + getAsIndexOpFoldResult(ctx, resShape[0]), getAsOpFoldResult(numTiles), + getAsIndexOpFoldResult(ctx, *lhsStepReduction)}; + auto expandA = rewriter.create( + loc, + SmallVector{resShape[0], ShapedType::kDynamic, + *lhsStepReduction}, + lhsBuf, SmallVector{{0}, {1, 2}}, lhsOutSizes); + SmallVector rhsOutSizes{ + getAsOpFoldResult(numTiles), + getAsIndexOpFoldResult(ctx, *rhsStepReduction), + getAsIndexOpFoldResult(ctx, resShape[1])}; + auto expandB = rewriter.create( + loc, + SmallVector{ShapedType::kDynamic, *rhsStepReduction, + resShape[1]}, + rhsBuf, SmallVector{{0, 1}, {2}}, rhsOutSizes); + + // Update maps with BRGEMM indexing. + auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx); + auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx); + SmallVector indexingMaps{mapA, mapB, mapC}; + + // Create single equivalent BRGEMM. + SmallVector inputs{expandA, expandB, accBuf}; + SmallVector flags; + auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, dotOp, inputs, + indexingMaps, flags); + + // Load back the result to bring it back to tensor semantics. + auto loadOp = + rewriter.create(loc, res.getType(), accBuf); + + // Increment the base pointers such that the whole loop can be removed. + // TODO: Revisit this part. + // Only the BRGEMM could be pulled out of the loop and the rest + // could be left as is. + Value zero = rewriter.create(loc, 0); + Value reductionStepConst = + rewriter.create(loc, *lhsStepReduction); + Value reductionOffset = + rewriter.create(loc, reductionStepConst, numTiles); + auto advanceA = rewriter.create( + loc, lhsBlockPtr.getResult().getType(), lhsBlockPtr, + ValueRange{zero, reductionOffset}); + auto advanceB = rewriter.create( + loc, rhsBlockPtr.getResult().getType(), rhsBlockPtr, + ValueRange{reductionOffset, zero}); + + rewriter.replaceOp(forOp, + ValueRange{loadOp.getResult(), advanceA.getResult(), + advanceB.getResult()}); + + return success(); + } + +private: + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; +}; + +struct ConvertTritonToXsmm + : public triton::cpu::impl::ConvertTritonToXsmmBase { + using ConvertTritonToXsmmBase::ConvertTritonToXsmmBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); + + RewritePatternSet patterns(context); + patterns.add(context, shapeInfoAnalysis); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +struct LoopToBrgemmXsmm + : public triton::cpu::impl::LoopToBrgemmXsmmBase { + using LoopToBrgemmXsmmBase::LoopToBrgemmXsmmBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); + + RewritePatternSet patterns(context); + patterns.add(context, shapeInfoAnalysis); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp new file mode 100644 index 000000000000..bbc6412a142b --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp @@ -0,0 +1,253 @@ +//===- ConvertVectorToXsmm.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/Passes.h" + +#include "ValueUtils.h" +#include "VnniUtils.h" +#include "XsmmUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::func; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTVECTORTOXSMM +#include "cpu/include/Xsmm/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +namespace { + +static Value getMemrefSource(PatternRewriter &rewriter, Operation *op, + TypedValue operand) { + Location loc = op->getLoc(); + MLIRContext *ctx = op->getContext(); + + if (isa(operand.getType())) + return operand; + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + + if (auto readOp = + dyn_cast_or_null(operand.getDefiningOp())) { + VectorType vecTy = readOp.getVectorType(); + SmallVector strides(vecTy.getRank(), 1); + return rewriter.create( + loc, readOp.getSource(), getAsOpFoldResult(readOp.getIndices()), + getAsIndexOpFoldResult(ctx, vecTy.getShape()), + getAsIndexOpFoldResult(ctx, strides)); + } + + auto vecTy = dyn_cast(operand.getType()); + assert(vecTy && "Expect vector type operand"); + MemRefType memTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + auto alloca = rewriter.create(loc, memTy); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(memTy.getRank(), zeroIdx); + auto write = + rewriter.create(loc, operand, alloca, indices); + + return alloca; +} + +// Helper to move accumulation buffer outside of GEMM reduction loop. +// Returns new accumulation buffer or std::nullopt, otherwise. +// +// Rewrites the following pattern: +// %init = ... vector<...> +// %0 = scf.for ... iter_args(%acc = %init) +// %res = GEMM(%A, %B, %acc) -> vector<...> +// scf.yield %res +// consumer(%0) +// into: +// %init = ... vector<...> +// %hoisted = ... memref<...> +// store %init, %hoisted +// %unused = %scf.for ... iter_args(%acc = %init) +// %res = GEMM(%A, %B, %acc) +// scf.yield %acc +// %0 = load(%hoisted) -> vector<...> +// consumer(%0) +// +// This rewrite should be used as a part of contraction to memref conversion. +static std::optional hoistAccumulationBuffer(PatternRewriter &rewriter, + Operation *op, + TypedValue operand) { + Location loc = op->getLoc(); + + // Expect the contraction op to still be in vector abstraction. + auto vecTy = dyn_cast(operand.getType()); + if (!vecTy) + return std::nullopt; + + // Check if there is any loop around the contraction and if the operand + // comes from loop's arguments. + auto forOp = dyn_cast(op->getParentOp()); + BlockArgument blockArg = dyn_cast(operand); + if (!forOp || !blockArg) + return std::nullopt; + OpOperand *loopArg = forOp.getTiedLoopInit(blockArg); + if (!loopArg) + return std::nullopt; + + // The accumulation iter_arg can be safely moved outside the loop only + // for the following chain: iter_arg -> contraction -> yield + // and there are no other users. + Value res = op->getResults()[0]; + if (!operand.hasOneUse() || !res.hasOneUse() || + !isa(*res.getUsers().begin())) + return std::nullopt; + + // Create a buffer outside the loop. + Value accBuf = getMemrefSource(rewriter, forOp, loopArg->get()); + + // For simplicity, feed the iter_arg directly into loop yield terminator. + // Canonicalizer will folded them away later. + rewriter.replaceAllUsesWith(res, operand); + + // Replace the corresponding loop result with the latest value read from the + // accumulation buffer. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(forOp); + + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(dyn_cast(accBuf.getType()).getRank(), + zeroIdx); + auto readOp = + rewriter.create(loc, vecTy, accBuf, indices); + rewriter.replaceAllUsesWith(forOp.getTiedLoopResult(blockArg), + readOp.getResult()); + + return accBuf; +} + +struct ContractToXsmm : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + Location loc = contractOp.getLoc(); + MLIRContext *ctx = contractOp.getContext(); + + TypedValue lhs = contractOp.getLhs(); + TypedValue rhs = contractOp.getRhs(); + TypedValue acc = contractOp.getAcc(); + + auto accVecTy = dyn_cast(acc.getType()); + if (!accVecTy) + return rewriter.notifyMatchFailure(contractOp, + "expects to accumulate on vector"); + + SmallVector flags; + Value lhsBuf = getMemrefSource(rewriter, contractOp, lhs); + Value rhsBuf = getMemrefSource(rewriter, contractOp, rhs); + std::optional hoistedAcc = + hoistAccumulationBuffer(rewriter, contractOp, acc); + Value accBuf = + hoistedAcc ? *hoistedAcc : getMemrefSource(rewriter, contractOp, acc); + + SmallVector inputs{lhsBuf, rhsBuf, accBuf}; + SmallVector outputs{nullptr}; + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + + // Rewrite matmul into a BRGEMM. + // This allows for additional reduction dimension tiling driven + // by a microkernel. + // + // TODO: Expand heuristics about brgemm rewrite profitability. + // TODO: Allow for batch dimension. + int64_t kDim = lhs.getType().getShape().back(); + auto accShape = accVecTy.getShape(); + constexpr int64_t kTile = 32; + int64_t numTiles = kDim / kTile; + uint32_t rank = accVecTy.getRank(); + if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) { + // Split reduction dimension into tiles. + // The number of tiles represents the batch dimension. + inputs[0] = rewriter.create( + loc, SmallVector{accShape[0], numTiles, kTile}, inputs[0], + SmallVector{{0}, {1, 2}}); + inputs[1] = rewriter.create( + loc, SmallVector{numTiles, kTile, accShape[1]}, inputs[1], + SmallVector{{0, 1}, {2}}); + + // Update maps with BRGEMM indexing. + auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx); + auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx); + indexingMaps = SmallVector{mapA, mapB, mapC}; + } + + auto brgemmInfo = xsmm::utils::isMappableToBrgemm( + rewriter, contractOp, inputs, outputs, indexingMaps); + if (failed(brgemmInfo)) { + assert(false); // FIXME: getMemrefSource above already modified IR... + // return rewriter.notifyMatchFailure(contractOp, "not mappable to XSMM"); + } + + auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, contractOp, inputs, + indexingMaps, flags); + + if (hoistedAcc) { + // Hoisting already updated all uses correctly. + // Only remove the original contraction. + rewriter.eraseOp(contractOp); + } else { + // Load back the result to bring it back to vector semantics. + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices( + dyn_cast(accBuf.getType()).getRank(), zeroIdx); + auto readOp = rewriter.create(loc, accVecTy, + accBuf, indices); + rewriter.replaceOp(contractOp, readOp); + } + + return success(); + } +}; + +struct ConvertVectorToXsmm + : public triton::cpu::impl::ConvertVectorToXsmmBase { + using ConvertVectorToXsmmBase::ConvertVectorToXsmmBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.cpp b/third_party/cpu/lib/Xsmm/ValueUtils.cpp new file mode 100644 index 000000000000..566665dbc7f0 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ValueUtils.cpp @@ -0,0 +1,146 @@ +//===- ValueUtils.cpp --------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ValueUtils.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace utils { + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val) { + return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()); +} + +// Returns true if the attribute represent "all zeros" +static bool isZeroAttr(Attribute attribute) { + return TypeSwitch(attribute) + .Case([](auto attr) { return attr.getValueAsDouble() == 0.0; }) + .Case([](auto attr) { return attr.getInt() == 0; }) + .Case([](auto attr) { + if (!attr.getElementType().isIntOrFloat()) + return false; + if (!attr.isSplat()) + return false; + auto splat = attr.template getSplatValue(); + return isZeroAttr(splat); + }) + .Default([](auto attr) { return false; }); +} + +// Prototypes +bool isZeroOp(Operation *); + +// Returns true if the value represents a zero filled tensor. +// Recurse into isZeroOp for defining ops if not immediately obvious +// Looks past linalg generic's argument (which don't have defining ops) +bool isZeroTensor(Value val) { + if (!val) + return false; + if (isValConstZero(val)) + return true; + + Operation *defOp = nullptr; + + // Block arguments don't have a defining op, but they do have an op arg + if (auto arg = dyn_cast(val)) { + // We need to find the argument to the linalg on the same order as this one + auto *linalgOp = arg.getParentRegion()->getParentOp(); + if (!isa(linalgOp)) + return false; + auto index = arg.getArgNumber(); + auto linalgArg = linalgOp->getOperand(index); + defOp = linalgArg.getDefiningOp(); + } else { + defOp = val.getDefiningOp(); + } + return isZeroOp(defOp); +} + +// Returns true if the operation represents a zero filled tensor +// Recurses into isZeroTensor for operands and isZeroAttr for attributes +bool isZeroOp(Operation *defOp) { + if (!defOp) + return false; + + return TypeSwitch(defOp) + .Case([&](auto op) { + // Dense attributes don't match APFloat.isZero() + auto attr = op.getValue(); + return isZeroAttr(attr); + }) + .Case([&](auto op) { + if (op.getInputs().size() != 1) + return false; + return isZeroTensor(op.getInputs()[0]); + }) + .Case( + [&](auto op) { return isZeroTensor(op.getSource()); }) + .Case([&](auto op) { + auto name = op.getName(); + auto module = defOp->getParentOfType(); + auto global = module.lookupSymbol(name); + auto attr = global.getInitialValueAttr(); + return isZeroAttr(attr); + }) + .Default([&](Operation *op) { return false; }); +} + +FailureOr> getStaticStrides(Value value) { + auto valueType = value.getType(); + if (!isa(valueType)) + return failure(); + auto memrefType = cast(valueType); + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +std::pair getPtrAndOffset(OpBuilder &builder, Value operand, + Location loc) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. + Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + +} // namespace utils +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/ValueUtils.h b/third_party/cpu/lib/Xsmm/ValueUtils.h new file mode 100644 index 000000000000..8cd50146d41c --- /dev/null +++ b/third_party/cpu/lib/Xsmm/ValueUtils.h @@ -0,0 +1,50 @@ +//===- ValueUtils.h - -------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_VALUEUTILS_H +#define TPP_TRANSFORMS_UTILS_VALUEUTILS_H + +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class OpBuilder; +class Operation; +class Location; +namespace utils { + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val); + +// Returns true if the op defining `val` represents a zero filled tensor. +bool isZeroTensor(Value val); + +// Returns true if the operation represents a zero filled tensor. +bool isZeroOp(Operation *); + +// Returns the strides of `val`. The method returns something usefull +// only if the `val` type is a strided memref and the strides are statically +// known. +FailureOr> getStaticStrides(Value val); + +// Return the offset and ptr for `val`. Assert if `val` +// is not a memref. +std::pair getPtrAndOffset(OpBuilder &builder, Value val, + Location loc); + +} // namespace utils +} // namespace mlir + +#endif // TPP_TRANSFORMS_UTILS_VALUEUTILS_H diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.cpp b/third_party/cpu/lib/Xsmm/VnniUtils.cpp new file mode 100644 index 000000000000..6df29f993c60 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/VnniUtils.cpp @@ -0,0 +1,89 @@ +//===- VNNIUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "VnniUtils.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" + +#include "libxsmm.h" + +namespace mlir { +namespace vnni { +namespace utils { + +std::optional getVnniBlockingFactor(Type type) { + auto elementType = getElementTypeOrSelf(type); + if (elementType.isBF16()) + return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + return std::nullopt; +} + +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the callee to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) { + if (memref.getRank() != static_cast(expectedRank) || + !memref.getElementType().isBF16()) { + return false; + } + return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); +} + +bool isInVnniLayout(int64_t expectedRank, VectorType vector) { + if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) { + return false; + } + return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector); +} + +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the callee to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) { + return isInVnniLayout((int64_t)expectedRank, vector); +} + +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap map, int64_t blockingFactor) { + ArrayRef results = map.getResults(); + SmallVector iteratorTypes = + linalgOp.getIteratorTypesArray(); + + AffineExpr vnniDim = results.back(); + auto dimExpr = dyn_cast(vnniDim); + if (!dimExpr || iteratorTypes[dimExpr.getPosition()] != + mlir::utils::IteratorType::reduction) { + return failure(); + } + + for (auto result : results) { + auto blockeDim = dyn_cast(result); + if (!blockeDim) + continue; + if (blockeDim.getKind() != AffineExprKind::FloorDiv) + continue; + auto lhsDim = dyn_cast(blockeDim.getLHS()); + auto rhsCst = dyn_cast(blockeDim.getRHS()); + if (!lhsDim || !rhsCst) + continue; + if (iteratorTypes[lhsDim.getPosition()] != + mlir::utils::IteratorType::reduction) + continue; + if (rhsCst.getValue() != blockingFactor) + continue; + return lhsDim; + } + return failure(); +} + +} // namespace utils +} // namespace vnni +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/VnniUtils.h b/third_party/cpu/lib/Xsmm/VnniUtils.h new file mode 100644 index 000000000000..e8517a5d23e1 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/VnniUtils.h @@ -0,0 +1,62 @@ +//===- VnniUtils.h -----------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_TRANSFORMS_UTILS_VNNIUTILS_H +#define TPP_TRANSFORMS_UTILS_VNNIUTILS_H + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Support/LogicalResult.h" + +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class MemRefType; +class OpOperand; +class AffineDimExpr; +class AffineMap; + +namespace linalg { +class GenericOp; +} // namespace linalg + +namespace vnni { +namespace utils { + +enum class VnniOperandRank { + TRANSPOSE = 3, + GEMM = 3, + BRGEMM_INS = 4, + BRGEMM_OUTS = 3 +}; + +// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8. +std::optional getVnniBlockingFactor(Type type); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); + +bool isInVnniLayout(int64_t expectedRank, VectorType vector); + +// Return true if the memref is in VNNI layout with rank `expectedRank`. +bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector); + +// Return the first AffineDimExpr in the map `affineMap` +// with a VNNI layout pattern (AffineDimExpr floordiv VNNI). +FailureOr isInVnniLayout(linalg::GenericOp linalgOp, + AffineMap affineMap, + int64_t blockingFactor); + +} // namespace utils +} // namespace vnni +} // namespace mlir + +#endif diff --git a/third_party/cpu/lib/Xsmm/XsmmEnum.cpp b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp new file mode 100644 index 000000000000..85766e5272f0 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmEnum.cpp @@ -0,0 +1,15 @@ +//===- XsmmEnum.cpp - Xsmm dialect enum -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "cpu/include/Xsmm/XsmmEnum.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::xsmm; + +#include "cpu/include/Xsmm/XsmmEnum.cpp.inc" diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.cpp b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp new file mode 100644 index 000000000000..5def81e089fa --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.cpp @@ -0,0 +1,1069 @@ +//===- XsmmUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "XsmmUtils.h" +#include "ValueUtils.h" +#include "VnniUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Compiler.h" + +#include +#include + +#define DEBUG_TYPE "xsmm-utils" + +using namespace mlir; +using namespace mlir::linalg; + +namespace mlir { +namespace xsmm { +namespace utils { + +// Callable object to verify if `operand` has static shape. +struct HasStaticShape { + HasStaticShape() = default; + HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + if (auto shapedType = dyn_cast_or_null(operandType)) { + if (!shapedType.hasStaticShape()) + return false; + if (shape) { + for (int64_t shapeOnDim : shapedType.getShape()) + shape->push_back(shapeOnDim); + } + } + return true; + } + SmallVectorImpl *shape = nullptr; +}; + +// Callable object to verify if `operand` has static strides. +// If `operand` is a tensor type or a scalar, return true. +struct HasStaticStrides { + HasStaticStrides() = default; + HasStaticStrides(SmallVector *strides) : strides(strides){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + SmallVector strides; + if (auto memRefType = dyn_cast_or_null(operandType)) { + int64_t offset; + if (failed(getStridesAndOffset(memRefType, strides, offset))) + return false; + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return false; + } + if (this->strides) + this->strides->append(strides.begin(), strides.end()); + } + return true; + } + SmallVectorImpl *strides = nullptr; +}; + +// Return the position of `dim` in the codomain of `operand`. +std::optional getPosInCodomain(unsigned dim, AffineMap map, + MLIRContext *ctx) { + return map.getResultPosition(getAffineDimExpr(dim, ctx)); +} + +static SmallVector +createFlatListOfOperandStaticDims(Operation *contractOp) { + SmallVector res; + for (OpOperand &opOperand : contractOp->getOpOperands()) + llvm::append_range( + res, dyn_cast(opOperand.get().getType()).getShape()); + return res; +} + +static SmallVector +computeStaticLoopSizes(Operation *contractOp, ArrayRef maps) { + AffineMap map = concatAffineMaps(maps); + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + SmallVector allShapeSizes = + createFlatListOfOperandStaticDims(contractOp); + SmallVector res(numDims, 0); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = dyn_cast(result)) + res[d.getPosition()] = allShapeSizes[idx]; + } + return res; +} + +static FailureOr> +getVNNIStaticStrides(MemRefType valueType) { + SmallVector strides; + int64_t offset; + SmallVector shape; + for (size_t i = 0; i < valueType.getShape().size(); i++) { + shape.push_back(valueType.getShape()[i]); + } + auto temp = shape[shape.size() - 1]; + shape[shape.size() - 1] = shape[shape.size() - 2]; + shape[shape.size() - 2] = temp; + auto memrefType = MemRefType::get(shape, valueType.getElementType()); + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } + return strides; +} + +FailureOr>> +dimPositionsMNKBatch(ArrayRef indexingMaps) { + auto contractionDims = inferContractionDims(indexingMaps); + if (failed(contractionDims)) + return failure(); + + unsigned posM = contractionDims->m.back(); + unsigned posN = contractionDims->n.back(); + unsigned posK; + std::optional posBatch = std::nullopt; + + auto pos1stContractionDimInIterSpace = contractionDims->k[0]; + if (contractionDims->k.size() == 1) { + posK = pos1stContractionDimInIterSpace; + } else if (contractionDims->k.size() == 2) { + auto pos2ndContractionDimInIterSpace = contractionDims->k[1]; + + if (pos1stContractionDimInIterSpace < pos2ndContractionDimInIterSpace) { + posBatch = pos1stContractionDimInIterSpace; + posK = pos2ndContractionDimInIterSpace; + } else { + posK = pos1stContractionDimInIterSpace; + posBatch = pos2ndContractionDimInIterSpace; + } + } else { // i.e., when contractionDims->k.size() == [0] or in [3,...] + LLVM_DEBUG(llvm::dbgs() << "too many/few contraction dims\n"); + return failure(); + } + + return std::tuple(posM, posN, posK, posBatch); +} + +// Check if the given +// generic is mappable to a +// brgemm xsmm op. +// - It is a contraction, +// with: +// -- 1 m and 1 n and 2 k +// dimensions. +// -- m appears on the LHS +// and OUT but not in RHS. +// -- n appears on the RHS +// and OUT but not in LHS. +// -- k and k' appear on the +// RHS and LHS but not OUT. +// -- the stride of the +// minor dimension for A, k +// is 1. +// -- the stride of the +// minor dimension for B, n +// is 1. +// -- the stride of the +// minor dimension for C, n +// is 1. +LogicalResult isMappableToBrgemm(PatternRewriter &rewriter, + Operation *contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap) { + auto ctx = contractOp->getContext(); + + auto numDims = indexingMap[0].getNumDims(); + auto contractionDims = inferContractionDims(indexingMap); + if (failed(contractionDims)) { + LLVM_DEBUG(llvm::dbgs() + << "[isMappableToBrgemm] Failed to infer dim kinds"); + return failure(); + } + + assert(inputs.size() == 3); + Value A = inputs[0]; + Value B = inputs[1]; + Value C = inputs[2]; + + unsigned posM = contractionDims->m.back(); + unsigned posN = contractionDims->n.back(); + unsigned posK; + std::optional posBatch = std::nullopt; + + { + auto pos1stContractionDimInIterSpace = contractionDims->k[0]; + if (contractionDims->k.size() == 1) { + posK = pos1stContractionDimInIterSpace; + } else if (contractionDims->k.size() == 2) { + auto pos2ndContractionDimInIterSpace = contractionDims->k[1]; + + if (pos1stContractionDimInIterSpace < pos2ndContractionDimInIterSpace) { + posBatch = pos1stContractionDimInIterSpace; + posK = pos2ndContractionDimInIterSpace; + } else { + posK = pos1stContractionDimInIterSpace; + posBatch = pos2ndContractionDimInIterSpace; + } + } else { // i.e., when contractionDims->k.size() == [0] or in [3,...] + LLVM_DEBUG(llvm::dbgs() << "too many contraction dims\n"); + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims: \n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] m: " << posM << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] n: " << posN << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] k: " << posK << "\n"); + if (posBatch) + LLVM_DEBUG(llvm::dbgs() + << "[isMappableToBrgemm] batch: " << posBatch << "\n"); + else + LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] no batch dim\n"); + + // Assume that if the last two dimensions are reductions, it is VNNI format. + // TODO: Add proper checks for VNNI. + bool isVnni = contractionDims->k.back() == (numDims - 1) && + contractionDims->k.end()[-2] == (numDims - 2); + + if (isVnni) { + auto dataTypeA = getDataType(rewriter, A.getType()); + auto stridesOnA = getVNNIStaticStrides(dyn_cast(A.getType())); + auto minorDimPosInAsCodomain = getPosInCodomain(posK, indexingMap[0], ctx); + if (failed(stridesOnA) || ((dataTypeA.getValue() != DataType::BF16) && + (*stridesOnA)[*minorDimPosInAsCodomain] != 1)) + return failure(); + + auto dataTypeB = getDataType(rewriter, B.getType()); + auto stridesOnB = getVNNIStaticStrides(dyn_cast(B.getType())); + auto minorDimPosInBsCodomain = getPosInCodomain(posN, indexingMap[1], ctx); + if (failed(stridesOnB) || ((dataTypeB.getValue() != DataType::BF16) && + (*stridesOnB)[*minorDimPosInBsCodomain] != 1)) + return failure(); + + auto dataTypeC = getDataType(rewriter, C.getType()); + auto stridesOnC = getVNNIStaticStrides(dyn_cast(C.getType())); + auto minorDimPosInCsCodomain = getPosInCodomain(posN, indexingMap[2], ctx); + if (failed(stridesOnC) || ((dataTypeC.getValue() != DataType::BF16) && + (*stridesOnC)[*minorDimPosInCsCodomain] != 1)) + return failure(); + } + + return success(); +} + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type) { + auto elemType = getElementTypeOrSelf(type); + if (elemType.isFloat8E5M2()) + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF8); + if (elemType.isBF16()) + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16); + return DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::F32); +} + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + UnaryInfo unaryInfo; + unaryInfo.m = outputShapedType.getShape()[0]; + unaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldi = 1; + if (ShapedType inputShapedType = dyn_cast(input.getType())) { + auto stridesOnInput = mlir::utils::getStaticStrides(input); + if (failed(stridesOnInput) || stridesOnInput->back() != 1 || + !inputShapedType.hasStaticShape()) { + return failure(); + } + + // If we are broascasting a row into cols, the leading + // dimension is 1, same for scalar broadcast. + if (inputFlag == UnaryFlags::BCAST_ROW || + inputFlag == UnaryFlags::BCAST_SCALAR) { + ldi = 1; + } + // If we are broascasting a col into rows, the leading + // dimension is the size of the tensor. + else if (inputFlag == UnaryFlags::BCAST_COL) { + ldi = inputShapedType.getShape().back(); + } else { + ldi = stridesOnInput->front(); + } + } + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + + unaryInfo.ldi = ldi; + unaryInfo.ldo = stridesOnOutput->front(); + return unaryInfo; +} + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output) { + Type outputType = output.getType(); + + assert(isa(outputType)); + auto outputShapedType = cast(outputType); + if (outputShapedType.getRank() != 2 || !outputShapedType.hasStaticShape() || + !isa(outputShapedType.getElementType())) { + return failure(); + } + + BinaryInfo binaryInfo; + binaryInfo.m = outputShapedType.getShape()[0]; + binaryInfo.n = outputShapedType.getShape()[1]; + + int64_t ldiLhs = 1; + if (ShapedType lhsShapedType = dyn_cast(lhs.getType())) { + auto stridesOnLhs = mlir::utils::getStaticStrides(lhs); + if (failed(stridesOnLhs) || stridesOnLhs->back() != 1 || + !lhsShapedType.hasStaticShape()) { + return failure(); + } + + if (lhsFlag == BinaryFlags::BCAST_SCALAR_IN_0 || + lhsFlag == BinaryFlags::BCAST_ROW_IN_0) { + ldiLhs = 1; + } else if (lhsFlag == BinaryFlags::BCAST_COL_IN_0) { + ldiLhs = lhsShapedType.getShape().back(); + } else { + ldiLhs = stridesOnLhs->front(); + } + } + + int64_t ldiRhs = 1; + if (ShapedType rhsShapedType = dyn_cast(rhs.getType())) { + auto stridesOnRhs = mlir::utils::getStaticStrides(rhs); + if (failed(stridesOnRhs) || stridesOnRhs->back() != 1 || + !rhsShapedType.hasStaticShape()) { + return failure(); + } + + if (rhsFlag == BinaryFlags::BCAST_SCALAR_IN_1 || + rhsFlag == BinaryFlags::BCAST_ROW_IN_1) { + ldiRhs = 1; + } else if (rhsFlag == BinaryFlags::BCAST_COL_IN_1) { + ldiRhs = rhsShapedType.getShape().back(); + } else { + ldiRhs = stridesOnRhs->front(); + } + } + + binaryInfo.ldiLhs = ldiLhs; + binaryInfo.ldiRhs = ldiRhs; + + auto stridesOnOutput = mlir::utils::getStaticStrides(output); + if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) + return failure(); + binaryInfo.ldo = stridesOnOutput->front(); + return binaryInfo; +} + +// Examples: +// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], higher=[a, a], [a] reshaped into [1, a]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +static void +computeBcastShapeInput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Initialize new shapes with [1] * higherRank. + int64_t higherRank = higherRankShape.size(); + int64_t lowerRank = lowerRankShape.size(); + + reshapeOutputShape.assign(higherRank, 1); + + int64_t higherRankDim; + int64_t lowerRankDim; + + for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; + i--, j--) { + higherRankDim = higherRankShape[i]; + lowerRankDim = lowerRankShape[j]; + + if (lowerRankDim == 1 && higherRankDim > 1) + reshapeOutputShape[i] = 1; + else if ((lowerRankDim > 1 && higherRankDim == 1) || + (lowerRankDim == higherRankDim)) + reshapeOutputShape[i] = lowerRankDim; + else if (higherRankDim != lowerRankDim) + assert(false && "bCast semantics for identity op broken"); + } +} + +FailureOr getUnaryFlags(Type inputType, Type outputType) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(inputType) || + cast(inputType).getRank() == 0) { + return xsmm::UnaryFlags::BCAST_SCALAR; + } + + ArrayRef shapeOutput = cast(outputType).getShape(); + ArrayRef shapeInput = cast(inputType).getShape(); + assert(shapeOutput.size() >= shapeInput.size() && + "output rank must be >= input rank"); + SmallVector bShapeInput; + computeBcastShapeInput(shapeOutput, shapeInput, bShapeInput); + assert(shapeOutput.size() == bShapeInput.size()); + shapeInput = bShapeInput; + + // Same shape for input and output, no bcast. + if (shapeInput == shapeOutput) + return xsmm::UnaryFlags::NONE; + + // Input is a memref but it is all ones, bcast = scalar. + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(shapeInput, isOne)) + return xsmm::UnaryFlags::BCAST_SCALAR; + + if (shapeInput[1] == 1 && shapeOutput[1] > 1) + return xsmm::UnaryFlags::BCAST_ROW; + + if (shapeInput[0] == 1 && shapeOutput[0] > 1) + return xsmm::UnaryFlags::BCAST_COL; + + return failure(); +} + +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber) { + assert(shapeOutput.size() >= shapeOperand.size() && + "Output rank must be >= operand rank"); + SmallVector bOperandShape; + computeBcastShapeInput(shapeOutput, shapeOperand, bOperandShape); + assert(shapeOutput.size() == bOperandShape.size()); + assert(shapeOutput.size() == 2); + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto getBCastEnum = [](BCastType bCastType, + OperandPos operandPos) -> xsmm::BinaryFlags { + switch (bCastType) { + case BCastType::NONE: + return xsmm::BinaryFlags::NONE; + case BCastType::SCALAR: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + else + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + case BCastType::ROW: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_ROW_IN_0; + else + return xsmm::BinaryFlags::BCAST_ROW_IN_1; + case BCastType::COL: + if (operandPos == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_COL_IN_0; + else + return xsmm::BinaryFlags::BCAST_COL_IN_1; + } + assert(false && "unrechable"); + abort(); + }; + + if (bOperandShape == shapeOutput) + return getBCastEnum(BCastType::NONE, operandNumber); + + auto isOne = [](int64_t val) { return val == 1; }; + if (llvm::all_of(bOperandShape, isOne)) + return getBCastEnum(BCastType::SCALAR, operandNumber); + + if (bOperandShape[1] == 1 && shapeOutput[1] > 1) + return getBCastEnum(BCastType::ROW, operandNumber); + + if (bOperandShape[0] == 1 && shapeOutput[0] > 1) + return getBCastEnum(BCastType::COL, operandNumber); + + return failure(); +} + +FailureOr getBinaryFlags(Type operandType, Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getLeadingDim(Type type, size_t pos) { + // Not shaped type, the leading dimension is the single scalar. + auto memref = dyn_cast(type); + if (!memref) + return 1; + // For 1d memref we cannot use the stride as leading dimension, but the + // leading dimension is the dimension itself. + if (memref.getRank() == 1) + return memref.getShape()[0]; + + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memref, strides, offset))) + return failure(); + // fail if the strides are non-constant + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) + return failure(); + return strides[pos]; +} + +static bool isInnerMostDim(OpOperand *operand, unsigned minorDim, + vector::ContractionOp contractOp, DataTypeAttr dtype, + int operandNumber) { + auto shapedType = cast(operand->get().getType()); + int64_t rank = shapedType.getRank(); + if (dtype == + DataTypeAttr::get(contractOp.getContext(), xsmm::DataType::BF16) && + (operandNumber == 1 || operandNumber == 0)) { + return minorDim == rank - 2; + } + return minorDim == rank - 1; +} + +// Emit a transpose operation for `operand` by swapping `dim` with `newDim`. +// Emit a transpose operation for `operand` by swapping the dimensions at index +// `dim` with `newDim`. +static void emitTransposeOnOperand(RewriterBase &rewriter, + vector::ContractionOp contractOp, + Value operand, unsigned dim, unsigned newDim, + int operandNumber) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(contractOp); + + Location loc = contractOp.getLoc(); + auto operandType = cast(operand.getType()); + auto rank = operandType.getRank(); + SmallVector shape = llvm::to_vector(operandType.getShape()); + auto permutation = llvm::to_vector(llvm::seq(0, rank)); + std::swap(permutation[dim], permutation[newDim]); + assert(isPermutationVector(permutation)); + LLVM_DEBUG(llvm::interleaveComma( + permutation, llvm::dbgs() << "[emitTransposeOnOperand] Perm: ")); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + applyPermutationToVector(shape, permutation); + auto vectorType = VectorType::get( + shape, cast(operand.getType()).getElementType()); + Value transposeResult = rewriter.create( + loc, vectorType, operand, permutation); + + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + AffineMap operandMap = indexingMaps[operandNumber]; + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] Old map: " << operandMap + << "\n"); + SmallVector mapResults = llvm::to_vector(operandMap.getResults()); + applyPermutationToVector(mapResults, permutation); + AffineMap newMap = + AffineMap::get(operandMap.getNumDims(), operandMap.getNumSymbols(), + mapResults, contractOp.getContext()); + LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] New map: " << newMap + << "\n"); + indexingMaps[operandNumber] = newMap; + // TODO: We probably cannot update the result in place. + rewriter.modifyOpInPlace(contractOp, [&]() { + contractOp->setOperand(operandNumber, transposeResult); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); +} + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, DataTypeAttr type) { + MLIRContext *ctx = rewriter.getContext(); + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + OpOperand &operandC = contractOp->getOpOperand(2); + + // C(m,n) += A(m,k) * B(k,n) + // n is expected to be the innermost for C + // k is expected to be the innermost for A + // n is expected to be the innermost for B + auto minorKInCodomainOpA = xsmm::utils::getPosInCodomain( + k, contractOp.getIndexingMapsArray()[0], ctx); + auto minorMInCodomainOpA = xsmm::utils::getPosInCodomain( + m, contractOp.getIndexingMapsArray()[0], ctx); + if (!minorKInCodomainOpA || !minorMInCodomainOpA) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for A\n"); + return failure(); + } + + auto minorNInCodomainOpB = xsmm::utils::getPosInCodomain( + n, contractOp.getIndexingMapsArray()[1], ctx); + auto minorKInCodomainOpB = xsmm::utils::getPosInCodomain( + k, contractOp.getIndexingMapsArray()[1], ctx); + if (!minorNInCodomainOpB || !minorKInCodomainOpB) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for B\n"); + return failure(); + } + + auto minorNInCodomainOpC = xsmm::utils::getPosInCodomain( + n, contractOp.getIndexingMapsArray()[2], ctx); + auto minorMInCodomainOpC = xsmm::utils::getPosInCodomain( + m, contractOp.getIndexingMapsArray()[2], ctx); + if (!minorNInCodomainOpC || !minorMInCodomainOpC) { + LLVM_DEBUG( + llvm::dbgs() + << "[makeMinorDimensionsInnerMost] did not find minor dims for C\n"); + return failure(); + } + + if (!isInnerMostDim(&operandC, *minorNInCodomainOpC, contractOp, type, 2)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for C\n"); + assert( + isInnerMostDim(&operandC, *minorMInCodomainOpC, contractOp, type, 2)); + if (isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorNInCodomainOpB, *minorKInCodomainOpB, 1); + } + // Avoid transpose on the output by swapping A and B. + OpOperand *operandA = &contractOp->getOpOperand(0); + OpOperand *operandB = &contractOp->getOpOperand(1); + SmallVector indexingMaps = contractOp.getIndexingMapsArray(); + std::swap(indexingMaps[0], indexingMaps[1]); + rewriter.modifyOpInPlace(contractOp, [&]() { + Value operandATmp = operandA->get(); + contractOp->setOperand(0, operandB->get()); + contractOp->setOperand(1, operandATmp); + contractOp.setIndexingMapsAttr( + ArrayAttr::get(contractOp.getContext(), + llvm::to_vector(llvm::map_range( + indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + })))); + }); + return contractOp; + } + + if (!isInnerMostDim(operandA, *minorKInCodomainOpA, contractOp, type, 0)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for A\n"); + assert(isInnerMostDim(operandA, *minorMInCodomainOpA, contractOp, type, 0)); + emitTransposeOnOperand(rewriter, contractOp, operandA->get(), + *minorKInCodomainOpA, *minorMInCodomainOpA, 0); + } + if (!isInnerMostDim(operandB, *minorNInCodomainOpB, contractOp, type, 1)) { + LLVM_DEBUG(llvm::dbgs() + << "[makeMinorDimensionsInnerMost] emit transpose for B\n"); + assert(isInnerMostDim(operandB, *minorKInCodomainOpB, contractOp, type, 1)); + emitTransposeOnOperand(rewriter, contractOp, operandB->get(), + *minorKInCodomainOpB, *minorNInCodomainOpB, 1); + } + return contractOp; +} + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain) { + for (size_t i = 0; i < operations.size(); i++) { + auto input = op->getOperand(i); + if (!operations[i](input.getDefiningOp())) + return false; + if (input.getDefiningOp()->getOperand(0).getDefiningOp() != nullptr) { + if (!(isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()) || + isa( + input.getDefiningOp()->getOperand(0).getDefiningOp()))) + return false; + } + inputs.push_back(input.getDefiningOp()->getOpOperand(0).get()); + opChain.push_back(input.getDefiningOp()); + } + return true; +} + +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain) { + // Check on the inner chain of operations in the right order. + // Make sure all operands are used and chained + for (auto use : op->getResult(0).getUsers()) { + if (use != op && operation(use)) { + if (!isa(use->getOperand(1).getDefiningOp())) + return false; + output.push_back(use->getOpOperand(1).get()); + opChain.push_back(use); + return true; + } + } + return false; +} + +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain) { + auto &block = region->front(); + + llvm::SmallSetVector chainedValues; + + auto start = block.begin(); + for (auto opItr = block.begin(); opItr != block.end(); opItr++) { + if (&*opItr != currentOp || !operations[0](&*opItr)) + continue; + start = opItr; + opChain.push_back(&*opItr); + break; + } + // Check on the inner chain of operations in the right order. + // Make sure all operands are used and chained + for (auto check : operations) { + Operation *innerOp = &*start; + // Must be right op in right order + if (start == block.end() || !check(innerOp)) + return false; + start++; + // At least one operand must come from args or a previous op + bool consumesValueFromChain = false; + if (chainedValues.empty()) { + consumesValueFromChain = true; + } else { + for (auto operand : innerOp->getOperands()) { + if (chainedValues.contains(operand)) { + chainedValues.remove(operand); + consumesValueFromChain = true; + } + } + } + + // Operation isn't in the chain + if (!consumesValueFromChain) + return false; + + for (auto ret : innerOp->getResults()) { + chainedValues.insert(ret); + } + } + return true; +} + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp) { + if (!(dyn_cast(transposeOp.getOperand().getType()).getRank() == + 2 && + dyn_cast(transposeOp.getResult().getType()).getRank() == + 2) || + !(isa(transposeOp->getParentOp()) && + dyn_cast(transposeOp->getParentOp()).getRank() == 2)) + return false; + return true; +} + +// Extract the operands to be used in the function call. For each memref operand +// extract the aligned pointer and the offset. +SmallVector getOperands(OpBuilder &builder, Location loc, + ValueRange operands, IntegerAttr dataTypeAttr, + std::optional outDataTypeAttr) { + SmallVector res; + IntegerType integer64 = IntegerType::get(builder.getContext(), 64); + res.push_back( + builder.create(loc, integer64, dataTypeAttr)); + if (outDataTypeAttr) + res.push_back( + builder.create(loc, integer64, *outDataTypeAttr)); + + for (Value operand : operands) { + auto memrefType = dyn_cast(operand.getType()); + if (!memrefType) { + res.push_back(operand); + continue; + } + auto [ptr, offset] = ::mlir::utils::getPtrAndOffset(builder, operand, loc); + res.push_back(ptr); + res.push_back(offset); + } + return res; +} + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands) { + SmallVector results; + for (Value operand : operands) { + Type operandType = operand.getType(); + if (auto memrefType = dyn_cast(operandType)) { + // TODO: non-POD will require an LLVMTypeConverter. + Type basePtrType = LLVM::LLVMPointerType::get(builder.getContext()); + results.push_back(basePtrType); + results.push_back(builder.getIndexType()); // offset + } else { + results.push_back(operand.getType()); + } + } + return results; +} + +int64_t getOredFlags(ArrayAttr flags) { + int64_t oredFlag = 0; + for (auto flag : flags) { + int64_t intAttr = dyn_cast(flag).getInt(); + // LIBXSMM is col-major, swap A and B flags. + if (auto gemmFlag = dyn_cast_or_null(flag)) { + if (gemmFlag.getValue() == GemmFlags::VNNI_A) + intAttr = static_cast(GemmFlags::VNNI_B); + if (gemmFlag.getValue() == GemmFlags::VNNI_B) + intAttr = static_cast(GemmFlags::VNNI_A); + } + oredFlag |= intAttr; + } + return oredFlag; +} + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName) { + auto libFnType = rewriter.getFunctionType( + dispatchOperandTypes, IntegerType::get(rewriter.getContext(), 64)); + + if (!module.lookupSymbol(fnName.getAttr())) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, fnName.getValue(), libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName.getValue(), IntegerType::get(rewriter.getContext(), 64), + dispatchOperands); + return call; +} + +func::CallOp buildInvokeCall(RewriterBase &rewriter, Location loc, + ModuleOp module, SmallVector operandRange, + StringRef invokeName, DataTypeAttr dtype, + std::optional outDtype) { + SmallVector operandTypes; + // Extra operands for datatypes. + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + operandTypes.push_back(integer64); + if (outDtype) + operandTypes.push_back(integer64); + operandTypes.append( + xsmm::utils::extractInvokeOperandTypes(rewriter, operandRange)); + auto libFnType = rewriter.getFunctionType(operandTypes, {}); + FlatSymbolRefAttr fnName = + SymbolRefAttr::get(rewriter.getContext(), invokeName); + + if (!module.lookupSymbol(fnName)) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, invokeName, libFnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create( + loc, fnName, TypeRange(), + xsmm::utils::getOperands(rewriter, loc, operandRange, dtype, outDtype)); + + return call; +} + +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + ArrayRef indexingMaps, + SmallVector flags) { + MLIRContext *ctx = op->getContext(); + Location loc = op->getLoc(); + + Type indexType = rewriter.getIndexType(); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + + auto posMNKBatch = *dimPositionsMNKBatch(indexingMaps); + unsigned posM = std::get<0>(posMNKBatch); + unsigned posN = std::get<1>(posMNKBatch); + unsigned posK = std::get<2>(posMNKBatch); + std::optional posBatch = std::get<3>(posMNKBatch); + + assert(inputs.size() == 3 && "Expect three inputs for BRGEMM call"); + Value A = inputs[0], B = inputs[1], C = inputs[2]; + + auto getMemrefMetadata = [&](Value operand) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + MemRefType baseMemrefType = + MemRefType::get({}, memrefType.getElementType()); + SmallVector sizesTypes(memrefType.getRank(), indexType); + SmallVector stridesTypes(memrefType.getRank(), indexType); + return rewriter.create( + loc, baseMemrefType, /*offsetType=*/indexType, sizesTypes, stridesTypes, + operand); + }; + + auto metadataA = getMemrefMetadata(A); + auto metadataB = getMemrefMetadata(B); + auto metadataC = getMemrefMetadata(C); + + auto posMInA = *getPosInCodomain(posM, indexingMaps[0], ctx); + auto posNInB = *getPosInCodomain(posN, indexingMaps[1], ctx); + auto posKInA = *getPosInCodomain(posK, indexingMaps[0], ctx); + auto posKInB = *getPosInCodomain(posK, indexingMaps[1], ctx); + + auto m = metadataA.getSizes()[posMInA]; + auto n = metadataB.getSizes()[posNInB]; + auto k = metadataA.getSizes()[posKInA]; + + auto posLeadingDimA = posMInA; // TODO: account for transposes... + auto lda = metadataA.getStrides()[posLeadingDimA]; + auto posLeadingDimB = posKInB; // TODO: account for transposes... + auto ldb = metadataB.getStrides()[posLeadingDimB]; + auto posLeadingDimC = *getPosInCodomain( + posM, indexingMaps[2], ctx); // TODO: account for transposes... + auto ldc = metadataC.getStrides()[posLeadingDimC]; + + Value strideA, strideB; + std::optional batchSize; + if (posBatch) { + auto posBatchInA = *getPosInCodomain(*posBatch, indexingMaps[0], ctx); + auto posBatchInB = *getPosInCodomain(*posBatch, indexingMaps[1], ctx); + batchSize = metadataA.getSizes()[posBatchInA]; + strideA = metadataA.getStrides()[posBatchInA]; + strideB = metadataB.getStrides()[posBatchInB]; + } + + auto dtype = xsmm::utils::getDataType(rewriter, inputs[0].getType()); + auto outDtype = xsmm::utils::getDataType(rewriter, inputs[2].getType()); + SmallVector dispatchOperands; + SmallVector dispatchOperandTypes; + // Dispatch the data type. + dispatchOperands.push_back(rewriter.create( + loc, integer64, cast(dtype))); + dispatchOperandTypes.push_back(integer64); + dispatchOperands.push_back(rewriter.create( + loc, integer64, cast(outDtype))); + dispatchOperandTypes.push_back(integer64); + + ArrayAttr brgemmFlags = rewriter.getArrayAttr(flags); + SmallVector invokeOperands; + std::string dispatchName = "xsmm_gemm_dispatch"; + std::string invokeName = "xsmm_gemm_invoke"; + if (posBatch) { + dispatchName = "xsmm_brgemm_dispatch"; + invokeName = "xsmm_brgemm_invoke"; + } + + auto sizesAndStrides = SmallVector{m, n, k, lda, ldb, ldc}; + if (posBatch) + sizesAndStrides.append({strideA, strideB}); + for (auto sizeOrStride : sizesAndStrides) { + auto sizeOrStrideInt64 = getValueOrCreateCastToIndexLike( + rewriter, op->getLoc(), integer64, sizeOrStride); + + dispatchOperands.push_back(sizeOrStrideInt64); + dispatchOperandTypes.push_back(integer64); + } + + // Dispatch the flags. Pass to the library the already ored-flag to + // avoid changing the interface every time we add a new flag. Flags + // are assumed to be verified before (i.e., op verifier). + int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags); + + dispatchOperands.push_back(rewriter.create( + loc, integer64, IntegerAttr::get(integer64, oredFlag))); + dispatchOperandTypes.push_back(integer64); + ModuleOp module = op->getParentOfType(); + auto dispatched = xsmm::utils::buildDispatchCall( + rewriter, loc, dispatchOperands, dispatchOperandTypes, module, + SymbolRefAttr::get(op->getContext(), dispatchName)); + SmallVector operandRange; + operandRange.push_back(dispatched.getResult(0)); + for (auto operand : inputs) + operandRange.push_back(operand); + if (posBatch) + operandRange.push_back(*batchSize); + auto invokeCall = xsmm::utils::buildInvokeCall( + rewriter, loc, module, operandRange, invokeName, dtype, outDtype); + return std::make_pair(&*dispatched, &*invokeCall); +} + +} // namespace utils +} // namespace xsmm +} // namespace mlir diff --git a/third_party/cpu/lib/Xsmm/XsmmUtils.h b/third_party/cpu/lib/Xsmm/XsmmUtils.h new file mode 100644 index 000000000000..65b54fd81817 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/XsmmUtils.h @@ -0,0 +1,162 @@ +//===- XsmmUtils.h - --------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_DIALECT_XSMM_XSMMUTILS_H +#define TPP_DIALECT_XSMM_XSMMUTILS_H + +#include "cpu/include/Xsmm/XsmmEnum.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include + +using namespace mlir; + +namespace mlir { +class Type; +class RewriterBase; +class ArrayAttr; +class Operation; +class ValueRange; +class Attribute; + +namespace xsmm { + +struct BrgemmInfo { + OpFoldResult m; + OpFoldResult n; + OpFoldResult k; + OpFoldResult batch; + + OpFoldResult lda; + OpFoldResult ldb; + OpFoldResult ldc; + OpFoldResult strideA; + OpFoldResult strideB; + + bool isVnni = false; +}; + +template +std::function FuncType = + [](Operation *op) { return isa(op); }; + +class UnaryKindAttr; + +struct UnaryInfo { + unsigned m; + unsigned n; + + int64_t ldi; + int64_t ldo; +}; + +struct BinaryInfo { + unsigned m; + unsigned n; + + int64_t ldiLhs; + int64_t ldiRhs; + int64_t ldo; +}; + +namespace utils { + +DataTypeAttr getDataType(RewriterBase &rewriter, Type type); + +FailureOr getUnaryInfo(Value input, Value output, + UnaryFlags inputFlag); + +FailureOr getBinaryInfo(Value lhs, BinaryFlags lhsFlag, Value rhs, + BinaryFlags rhsFlag, Value output); + +// Compute the broadcasting flags for 'inputType' based 'outputType'. +// Rules for broadcasting follows Numpy-style, and are only allowed in +// 'inputType'. see: https://numpy.org/doc/stable/user/basics.broadcasting.html +FailureOr getUnaryFlags(Type inputType, Type outputType); + +// Compute the broadcasting flags for 'operandType' based on 'outputType'. +enum class OperandPos { LHS = 0, RHS = 1 }; +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber); +FailureOr getBinaryFlags(Type operandType, Type outputType, + OperandPos operandNumber); + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber); + +FailureOr getLeadingDim(Type type, size_t pos = 0); + +int64_t getOredFlags(ArrayAttr flags); + +SmallVector extractInvokeOperandTypes(OpBuilder &builder, + ValueRange operands); +SmallVector +getOperands(OpBuilder &builder, Location loc, ValueRange operands, + IntegerAttr dataTypeAttr, + std::optional outDataTypeAttr = std::nullopt); + +LogicalResult isMappableToBrgemm(PatternRewriter &rewriter, + Operation *contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap); + +FailureOr +makeMinorDimensionsInnerMost(RewriterBase &rewriter, + vector::ContractionOp contractOp, unsigned m, + unsigned n, unsigned k, IntegerAttr type); +std::optional getPosInCodomain(unsigned dim, AffineMap map, + MLIRContext *ctx); +FailureOr +checkAccess(PatternRewriter &rewriter, Operation *contractOp, unsigned m, + unsigned n, SmallVector kVector, + std::optional batchPos, SmallVector inputs, + ArrayRef indexingMap); + +bool WithInputs(PatternRewriter &rewriter, Operation *op, + SmallVector> operations, + SmallVector &inputs, SmallVector &opChain); +bool WithOutput(Operation *op, std::function operation, + SmallVector &output, SmallVector &opChain); +bool WithOps(Region *region, Operation *op, Operation *currentOp, + SmallVector> operations, + SmallVector &opChain); + +bool isTwoDTransposeOp(vector::TransposeOp transposeOp); + +func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, + ArrayRef dispatchOperands, + ArrayRef dispatchOperandTypes, + ModuleOp module, FlatSymbolRefAttr fnName); +func::CallOp +buildInvokeCall(RewriterBase &rewriter, Location loc, ModuleOp module, + SmallVector operands, StringRef invokeName, + DataTypeAttr dtype, + std::optional outDtype = std::nullopt); + +// Create a pair of XSMM dispatch and invoke (BR)GEMM calls. +std::pair +buildBrgemmCalls(PatternRewriter &rewriter, Operation *op, ValueRange inputs, + ArrayRef indexingMaps, + SmallVector flags); + +} // namespace utils +} // namespace xsmm +} // namespace mlir + +#endif // TPP_DIALECT_XSMM_XSMMUTILS_H diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp new file mode 100644 index 000000000000..c0bfe28e87d8 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.cpp @@ -0,0 +1,543 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. +// +//===----------------------------------------------------------------------===// + +#include "XsmmRunnerUtils.h" +#include "libxsmm.h" // NOLINT [build/include_subdir] +#include "libxsmm_utils.h" + +// Helper function prototypes. +static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile = stderr); +static void printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmShape, + FILE *outfile = stderr); + +static bool hasImplicitComputeDtypeUnary(const libxsmm_meltw_unary_type dtype) { + switch (dtype) { + // Zero + case LIBXSMM_MELTW_TYPE_UNARY_XOR: + // Copy + case LIBXSMM_MELTW_TYPE_UNARY_IDENTITY: + // Transpose + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_NORMT: + // VNNI2 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI2_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI2_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD2: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD2: + // VNNI4 + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4T: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_NORM_TO_VNNI4_PAD: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADN_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_PADNM_MOD4: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_NORM: + case LIBXSMM_MELTW_TYPE_UNARY_TRANSFORM_VNNI4_TO_VNNI2: + return true; + default: + return false; + } +} + +namespace { + +void *get_base_ptr(const libxsmm_datatype dType, void *alignedPtr, + int64_t offset) { + if (dType == LIBXSMM_DATATYPE_F32) { + float *base_ptr = (float *)alignedPtr + offset; + return (void *)base_ptr; + } else if (dType == LIBXSMM_DATATYPE_BF16) { + bf16 *base_ptr = (bf16 *)alignedPtr + offset; + return (void *)base_ptr; + } else if (dType == LIBXSMM_DATATYPE_BF8) { + uint8_t *base_ptr = (uint8_t *)alignedPtr + offset; + return (void *)base_ptr; + } + fprintf(stderr, "Unhandled data type in get_data_pointer_from_memref_desc:%d\n", + dType); + return nullptr; +} + +} // namespace + +extern "C" void xsmm_gemm_invoke(const libxsmm_datatype dType, + const libxsmm_datatype out_dtype, int64_t addr, + void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, + void *alignedPtrC, int64_t offsetC) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + // LIBXSMM col-major change A with B. + gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(out_dtype, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_gemm_dispatch(const libxsmm_datatype dtype, + const libxsmm_datatype out_dtype, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "ldb: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + + // See: + // https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication + // LIBXSMM col-major change m with n. + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = out_dtype; + assert((out_dtype == LIBXSMM_DATATYPE_F32 || + out_dtype == LIBXSMM_DATATYPE_BF16 || + out_dtype == LIBXSMM_DATATYPE_BF8) && + "no support for selecting comp_type for non-F32/BF16/BF8 dtypes"); + // Libxsmm has limited support for comp_types w.r.t. A & B & C's dtype. + // F32 is supported in case of F32, BF16, and BF8 out_dtypes. + l_shape.comp_type = LIBXSMM_DATATYPE_F32; + + auto sgemm = libxsmm_dispatch_gemm(l_shape, l_flags, l_prefetch_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate matmul func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "out_dtype: %u\n", out_dtype); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" int64_t +xsmm_unary_dispatch(const libxsmm_meltw_unary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldi, int64_t ldo, + const libxsmm_meltw_unary_flags unary_flags) { + // std::cout << "ldi: " << ldi << "\n"; + // std::cout << "ldo: " << ldo << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "type: " << type << "\n"; + // std::cout << "bcast_type: " << bcast_type << "\n"; + + libxsmm_meltw_unary_shape unary_shape; + // Row major to col major swap m with n. + unary_shape.m = static_cast(n); + unary_shape.n = static_cast(m); + unary_shape.in0_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + // Copy and Zero should remain in BF16 to avoid useless up/down casts + auto force_fp32 = (dtype == LIBXSMM_DATATYPE_BF16 && + !hasImplicitComputeDtypeUnary(op_type)); + unary_shape.comp_type = force_fp32 ? LIBXSMM_DATATYPE_F32 : dtype; + unary_shape.out_type = dtype; + unary_shape.ldi = static_cast(ldi); + unary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_unary kernel = + libxsmm_dispatch_meltw_unary(op_type, unary_shape, unary_flags); + if (!kernel) { + fprintf(stderr, "failed to generate unary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", unary_flags); + printXsmmStruct(unary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t +xsmm_binary_dispatch(const libxsmm_meltw_binary_type op_type, + const libxsmm_datatype dtype, int64_t m, int64_t n, + int64_t ldiLhs, int64_t ldiRhs, int64_t ldo, + const libxsmm_meltw_binary_flags flags) { + libxsmm_meltw_binary_shape binary_shape; + // Row major to col major swap m with n. + binary_shape.m = static_cast(n); + binary_shape.n = static_cast(m); + binary_shape.in0_type = dtype; + binary_shape.in1_type = dtype; + // Retarget computation type from bf16 to f32 due to missing hardware support. + binary_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + binary_shape.out_type = dtype; + binary_shape.ldi = static_cast(ldiLhs); + binary_shape.ldi2 = static_cast(ldiRhs); + binary_shape.ldo = static_cast(ldo); + + libxsmm_meltwfunction_binary kernel = + libxsmm_dispatch_meltw_binary(op_type, binary_shape, flags); + if (!kernel) { + fprintf(stderr, "failed to generate binary func\n"); + fprintf(stderr, "op_type: %u\n", op_type); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(binary_shape); + exit(-1); + } + + return reinterpret_cast(kernel); +} + +extern "C" int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype dtype, int64_t m, int64_t n, int64_t k, int64_t lda, + int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_cfg_flags = flags; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb; + l_shape.ldb = lda; + l_shape.ldc = ldc; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = dtype; + l_shape.comp_type = + dtype == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : dtype; + + auto sgemm = libxsmm_dispatch_tilecfg_gemm(l_shape, l_cfg_flags); + if (!sgemm) { + fprintf(stderr, "failed to generate tileconfig func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "flags: %u\n", flags); + printXsmmStruct(l_shape); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_unary_param param; + + param.in.primary = get_base_ptr(dType, alignedPtrIn, offsetIn); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, + void *alignedPtrRhs, int64_t offsetRhs, + void *alignedPtrOut, int64_t offsetOut) { + libxsmm_meltw_binary_param param; + + param.in0.primary = get_base_ptr(dType, alignedPtrLhs, offsetLhs); + param.in1.primary = get_base_ptr(dType, alignedPtrRhs, offsetRhs); + param.out.primary = get_base_ptr(dType, alignedPtrOut, offsetOut); + + libxsmm_meltwfunction_binary kernel = + reinterpret_cast(addr); + kernel(¶m); +} + +extern "C" void xsmm_unary_scalar_invoke(const libxsmm_datatype dType, + int64_t addr, float input, + void *alignedOut, int64_t offsetOut) { + libxsmm_meltwfunction_unary kernel = + reinterpret_cast(addr); + libxsmm_meltw_unary_param param; + + param.in.primary = (void *)&input; + param.out.primary = get_base_ptr(dType, alignedOut, offsetOut); + kernel(¶m); +} + +extern "C" void xsmm_brgemm_invoke(const libxsmm_datatype dType, + const libxsmm_datatype out_dtype, + int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, + int64_t offsetC, int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. + gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(out_dtype, alignedPtrC, offsetC); + + sgemm.gemm = reinterpret_cast(addr); + sgemm.gemm(&gemm_param); +} + +extern "C" int64_t xsmm_brgemm_dispatch(const libxsmm_datatype dtype, + const libxsmm_datatype out_dtype, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, + int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags flags) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = flags; + libxsmm_bitfield l_prefetch_flags = 0; + libxsmm_gemm_batch_reduce_config l_brconfig; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = dtype; + l_shape.b_in_type = dtype; + l_shape.out_type = out_dtype; + assert((out_dtype == LIBXSMM_DATATYPE_F32 || + out_dtype == LIBXSMM_DATATYPE_BF16 || + out_dtype == LIBXSMM_DATATYPE_BF8) && + "no support for selecting comp_type for non-F32/BF16/BF8 dtypes"); + // Libxsmm has limited support for comp_types w.r.t. A & B & C's dtype. + // F32 is supported in case of F32, BF16, and BF8 out_dtypes. + l_shape.comp_type = LIBXSMM_DATATYPE_F32; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + + size_t typeSize; + if (dtype == LIBXSMM_DATATYPE_F32) + typeSize = sizeof(float); + else if (dtype == LIBXSMM_DATATYPE_BF16) + typeSize = sizeof(bf16); + else if (dtype == LIBXSMM_DATATYPE_BF8) + typeSize = sizeof(uint8_t); + else + assert(false && "unsupported datatype"); + + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + auto sgemm = + libxsmm_dispatch_brgemm(l_shape, l_flags, l_prefetch_flags, l_brconfig); + if (!sgemm) { + fprintf(stderr, "failed to generate brgemm func\n"); + fprintf(stderr, "dtype: %u\n", dtype); + fprintf(stderr, "out_dtype: %u\n", out_dtype); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" void xsmm_fused_brgemm_invoke(const libxsmm_datatype dType, + int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, + int64_t offsetD, int64_t numBatches) { + libxsmm_xmmfunction sgemm; + libxsmm_gemm_ext_param gemm_param; + + unsigned long long numBatchesVar = numBatches; + gemm_param.op.tertiary = (void *)&numBatchesVar; + + // LIBXSMM col-major change A with B. + gemm_param.a.primary = get_base_ptr(dType, alignedPtrB, offsetB); + gemm_param.b.primary = get_base_ptr(dType, alignedPtrA, offsetA); + gemm_param.c.primary = get_base_ptr(dType, alignedPtrC, offsetC); + gemm_param.d.primary = get_base_ptr(dType, alignedPtrD, offsetD); + + sgemm.gemm_ext = reinterpret_cast(addr); + sgemm.gemm_ext(&gemm_param); +} + +extern "C" int64_t +xsmm_fused_brgemm_dispatch(const libxsmm_datatype data_type, int64_t m, + int64_t n, int64_t k, int64_t lda, int64_t ldb, + int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type) { + // std::cout << "lda: " << lda << "\n"; + // std::cout << "lbd: " << ldb << "\n"; + // std::cout << "ldc: " << ldc << "\n"; + // std::cout << "m: " << m << "\n"; + // std::cout << "n: " << n << "\n"; + // std::cout << "k: " << k << "\n"; + + libxsmm_blasint lda_int = lda; + libxsmm_blasint ldb_int = ldb; + libxsmm_blasint ldc_int = ldc; + libxsmm_blasint m_int = m; + libxsmm_blasint n_int = n; + libxsmm_blasint k_int = k; + libxsmm_gemm_shape l_shape; + libxsmm_bitfield l_flags = gemm_flags; + libxsmm_bitfield l_prefetch_flags = 0; + + l_shape.m = n_int; + l_shape.n = m_int; + l_shape.k = k_int; + l_shape.lda = ldb_int; + l_shape.ldb = lda_int; + l_shape.ldc = ldc_int; + l_shape.a_in_type = data_type; + l_shape.b_in_type = data_type; + l_shape.out_type = data_type; + // Retarget computation type from bf16 to f32 due to missing hardware support. + l_shape.comp_type = + data_type == LIBXSMM_DATATYPE_BF16 ? LIBXSMM_DATATYPE_F32 : data_type; + + libxsmm_gemm_batch_reduce_config l_brconfig; + l_brconfig.br_type = LIBXSMM_GEMM_BATCH_REDUCE_STRIDE; + auto typeSize = + data_type == LIBXSMM_DATATYPE_F32 ? sizeof(float) : sizeof(bf16); + l_brconfig.br_stride_a_hint = stride_b * typeSize; + l_brconfig.br_stride_b_hint = stride_a * typeSize; + l_brconfig.br_unroll_hint = 0; + + libxsmm_gemm_ext_unary_argops l_argops; + memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops)); + l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE; + l_argops.ldcp = ldc; + l_argops.cp_unary_type = unary_op_type; + + libxsmm_gemm_ext_binary_postops l_postops; + memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops)); + l_postops.d_in_type = data_type; + + l_postops.d_binary_flags = binary_flags; + l_postops.d_binary_type = binary_op_type; + l_postops.ldd = ldc; + + auto sgemm = libxsmm_dispatch_brgemm_ext(l_shape, l_flags, l_prefetch_flags, + l_brconfig, l_argops, l_postops); + if (!sgemm) { + fprintf(stderr, "failed to generate fused brgemm func\n"); + fprintf(stderr, "data_type: %u\n", data_type); + printXsmmStruct(l_shape); + printXsmmStruct(l_brconfig); + exit(-1); + } + + return reinterpret_cast(sgemm); +} + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *tileState, int64_t offset) { + libxsmm_xmmfunction cfg_tr; + + libxsmm_tilecfg_state *l_tilestate = + reinterpret_cast(tileState); + + cfg_tr.tilecfg = reinterpret_cast(addr); + cfg_tr.tilecfg(l_tilestate); +} + +static void printXsmmStruct(const libxsmm_gemm_shape &gemmShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", gemmShape.m); + fprintf(outfile, "N: %d\n", gemmShape.n); + fprintf(outfile, "K: %d\n", gemmShape.k); + fprintf(outfile, "lda: %d\n", gemmShape.lda); + fprintf(outfile, "ldb: %d\n", gemmShape.ldb); + fprintf(outfile, "ldc: %d\n", gemmShape.ldc); + fprintf(outfile, "a_in_type: %d\n", gemmShape.a_in_type); + fprintf(outfile, "b_in_type: %d\n", gemmShape.b_in_type); + fprintf(outfile, "comp_type: %d\n", gemmShape.comp_type); + fprintf(outfile, "out_type: %d\n", gemmShape.out_type); +} + +static void printXsmmStruct(const libxsmm_meltw_unary_shape &unaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", unaryShape.m); + fprintf(outfile, "N: %d\n", unaryShape.n); + fprintf(outfile, "in0_type: %d\n", unaryShape.in0_type); + fprintf(outfile, "comp_type: %d\n", unaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", unaryShape.out_type); + fprintf(outfile, "ldi: %d\n", unaryShape.ldi); + fprintf(outfile, "ldo: %d\n", unaryShape.ldo); +} + +static void printXsmmStruct(const libxsmm_meltw_binary_shape &binaryShape, + FILE *outfile) { + fprintf(outfile, "M: %d\n", binaryShape.m); + fprintf(outfile, "N: %d\n", binaryShape.n); + fprintf(outfile, "in0_type: %d\n", binaryShape.in0_type); + fprintf(outfile, "in1_type: %d\n", binaryShape.in1_type); + fprintf(outfile, "comp_type: %d\n", binaryShape.comp_type); + fprintf(outfile, "out_type: %d\n", binaryShape.out_type); + fprintf(outfile, "ldi: %d\n", binaryShape.ldi); + fprintf(outfile, "ldi2: %d\n", binaryShape.ldi2); + fprintf(outfile, "ldo: %d\n", binaryShape.ldo); +} + +static void +printXsmmStruct(const libxsmm_gemm_batch_reduce_config &brgemmConfig, + FILE *outfile) { + fprintf(outfile, "br_type: %d\n", brgemmConfig.br_type); + fprintf(outfile, "br_stride_a_hint: %d\n", brgemmConfig.br_stride_a_hint); + fprintf(outfile, "br_stride_b_hint: %d\n", brgemmConfig.br_stride_b_hint); + fprintf(outfile, "br_unroll_hint: %d\n", brgemmConfig.br_unroll_hint); +} diff --git a/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h new file mode 100644 index 000000000000..d2b03fb1db39 --- /dev/null +++ b/third_party/cpu/lib/Xsmm/runtime/XsmmRunnerUtils.h @@ -0,0 +1,85 @@ +//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic classes and functions to manipulate structured MLIR +// types at runtime. Entities in this file must be compliant with C++11 and be +// retargetable, including on targets without a C++ runtime. +// +//===----------------------------------------------------------------------===// + +#ifndef TPP_EXECUTIONENGINE_CRUNNERUTILS_H +#define TPP_EXECUTIONENGINE_CRUNNERUTILS_H + +#include "libxsmm.h" +#include "mlir/ExecutionEngine/Float16bits.h" +#include "mlir/ExecutionEngine/RunnerUtils.h" + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_gemm_dispatch( + const libxsmm_datatype, const libxsmm_datatype, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_unary_dispatch( + const libxsmm_meltw_unary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, const libxsmm_meltw_unary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_binary_dispatch( + const libxsmm_meltw_binary_type, const libxsmm_datatype, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_meltw_binary_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_brgemm_dispatch( + const libxsmm_datatype, const libxsmm_datatype, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_fused_brgemm_dispatch( + const libxsmm_datatype data_type, int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, int64_t stride_a, int64_t stride_b, + const libxsmm_gemm_flags gemm_flags, + const libxsmm_meltw_unary_flags unary_flags, + const libxsmm_meltw_unary_type unary_op_type, + const libxsmm_meltw_binary_flags binary_flags, + const libxsmm_meltw_binary_type binary_op_type); + +extern "C" MLIR_RUNNERUTILS_EXPORT int64_t xsmm_intel_amx_tile_config_dispatch( + const libxsmm_datatype, int64_t, int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int64_t, const libxsmm_gemm_flags); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_gemm_invoke(const libxsmm_datatype dType, const libxsmm_datatype out_dtype, + int64_t addr, void *alignedPtrA, int64_t offsetA, + void *alignedPtrB, int64_t offsetB, void *alignedPtrC, + int64_t offsetC); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrIn, int64_t offsetIn, void *alignedPtrOut, + int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_unary_scalar_invoke(const libxsmm_datatype, int64_t addr, float scalar, + void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_binary_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrLhs, int64_t offsetLhs, void *alignedPtrRhs, + int64_t offsetRhs, void *alignedPtrOut, int64_t offsetOut); + +extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_brgemm_invoke( + const libxsmm_datatype dType, const libxsmm_datatype out_dtype, + int64_t addr, void *alignedPtrA, int64_t offsetA, void *alignedPtrB, + int64_t offsetB, void *alignedPtrC, int64_t offsetC, int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void xsmm_fused_brgemm_invoke( + const libxsmm_datatype dType, int64_t addr, void *alignedPtrA, + int64_t offsetA, void *alignedPtrB, int64_t offsetB, void *alignedPtrC, + int64_t offsetC, void *alignedPtrD, int64_t offsetD, int64_t numBatches); + +extern "C" MLIR_RUNNERUTILS_EXPORT void +xsmm_intel_amx_tile_config_invoke(const libxsmm_datatype dType, int64_t addr, + void *alignedPtrA, int64_t offset); + +#endif // TPP_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/third_party/cpu/python/setup.py b/third_party/cpu/python/setup.py new file mode 100644 index 000000000000..c3c963b92daf --- /dev/null +++ b/third_party/cpu/python/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension +import os + +xsmm_root = os.getenv("XSMM_ROOT_DIR") +xsmm_lib = os.getenv("XSMM_LIB_DIR") +print(f'Using LIBXSMM root: {xsmm_root}') +print(f'LIBXSMM lib location: {xsmm_lib}') + +setup(name='xsmm_py', + ext_modules=[ + cpp_extension.CppExtension('xsmm_py', ['xsmm_utils.cpp'], + include_dirs=[ + f'{xsmm_root}/include', + f'{xsmm_root}/src/template' + ], + library_dirs=[f'{xsmm_lib}'], + libraries=['xsmm', 'omp'], + extra_compile_args=['-fopenmp'] + )], + cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/third_party/cpu/python/xsmm_utils.cpp b/third_party/cpu/python/xsmm_utils.cpp new file mode 100644 index 000000000000..368c7f830181 --- /dev/null +++ b/third_party/cpu/python/xsmm_utils.cpp @@ -0,0 +1,68 @@ +#include + +#include "libxsmm.h" +#include + +#include + +void fastZeroPad2D(const at::Tensor &input, torch::Tensor &output) { + const auto inSizes = input.sizes(); + const auto outSizes = output.sizes(); + const auto byteSize = input.element_size(); + assert(input.is_floating_point() && inSizes.size() == 2 || + outSizes.size() == 2 && outSizes[0] >= inSizes[0] && + outSizes[1] >= inSizes[1] && byteSize == output.element_size() && + "Invalid fastZeroPad2D tensors"); + + libxsmm_datatype dtype; + if (byteSize == 4) + dtype = LIBXSMM_DATATYPE_F32; + else if (byteSize == 2) + dtype = LIBXSMM_DATATYPE_BF16; + else if (byteSize == 1) + dtype = LIBXSMM_DATATYPE_BF8; + else + assert(false && "unsupported datatype"); + + libxsmm_meltw_unary_shape shape; + // Fliped to libxsmm's column-major convention. + shape.m = inSizes[1]; + shape.n = 1; + shape.ldi = inSizes[1]; + shape.ldo = outSizes[1]; + shape.in0_type = dtype; + shape.out_type = dtype; + shape.comp_type = dtype; + libxsmm_bitfield flags = LIBXSMM_MELTW_FLAG_UNARY_NONE; + libxsmm_meltwfunction_unary identityFn = libxsmm_dispatch_meltw_unary( + LIBXSMM_MELTW_TYPE_UNARY_IDENTITY, shape, flags); + + void *baseIn = input.data_ptr(); + void *outIn = output.data_ptr(); + const int padRight = outSizes[1] - inSizes[1]; + +#pragma omp parallel for schedule(static) + for (int i = 0; i < inSizes[0]; ++i) { +#if 0 + libxsmm_meltw_unary_param param; + param.in.primary = baseIn + i * inSizes[1] * byteSize; + param.out.primary = outIn + i * outSizes[1] * byteSize; + identityFn(¶m); +#else + std::memcpy( outIn + i * outSizes[1] * byteSize, baseIn + i * inSizes[1] * byteSize, inSizes[1]*byteSize ); +#endif + // Zero out right padding. + std::memset(outIn + i * outSizes[1] * byteSize + inSizes[1] * byteSize, 0, + byteSize * padRight); + } + + // Zero out bottom padding. +#pragma omp parallel for schedule(static) + for (int i = inSizes[0]; i < outSizes[0]; ++i) { + std::memset(outIn + i * outSizes[1] * byteSize, 0, byteSize * outSizes[1]); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fastZeroPad2D", &fastZeroPad2D, "Fast 2D tensor zero padding"); +} diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp new file mode 100644 index 000000000000..68b7efa78f01 --- /dev/null +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -0,0 +1,394 @@ +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include + +#if defined(_MSC_VER) +#define EXPORT __declspec(dllexport) +#elif defined(__GNUC__) +#define EXPORT __attribute__((visibility("default"))) +#else +#define EXPORT +#endif + +namespace { + +// A poor man's Torch-like pretty print for tensors and vectors. +const int MAX_FLOAT_WIDTH = 8; +const int FLOAT_PREC = 4; +const int ELEMS_PER_LINE = 8; + +using FLOAT16 = struct _FLOAT16 { +#ifdef FLT16_MAX + _Float16 x; +#else + uint16_t x; +#endif + + float toFloat32() const { +#ifdef FLT16_MAX + return static_cast(x); +#else + // Based on https://gist.github.com/zhuker/b4bd1fb306c7b04975b712c37c4c4075 + uint32_t t1; + uint32_t t2; + uint32_t t3; + + t1 = x & 0x7fffu; // Non-sign bits + t2 = x & 0x8000u; // Sign bit + t3 = x & 0x7c00u; // Exponent + + t1 <<= 13u; // Align mantissa on MSB + t2 <<= 16u; // Shift sign bit into position + + t1 += 0x38000000; // Adjust bias + + t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero + + t1 |= t2; // Re-insert sign bit + + float out; + *((uint32_t *)&out) = t1; + return out; +#endif + } +}; + +struct FormatInfo { + bool isInt; + bool isSigned; + int bitWidth; + int maxIntDigits; + bool hasNegative; + bool scientific; + bool isHex; +}; + +template struct RawMemRefDescriptor { + const T *allocated; + const T *aligned; + intptr_t offset; + intptr_t sizesAndStrides[]; +}; + +template class MemRefDescriptor { +private: + const T *data_; + std::vector sizes_; + std::vector strides_; + + MemRefDescriptor(const T *data, std::vector sizes, + std::vector strides) + : data_(data), sizes_(std::move(sizes)), strides_(std::move(strides)) {} + +public: + MemRefDescriptor(int32_t rank, void *rawDescriptor) { + auto *rawDesc = static_cast *>(rawDescriptor); + data_ = rawDesc->aligned + rawDesc->offset; + sizes_.insert(sizes_.begin(), rawDesc->sizesAndStrides, + rawDesc->sizesAndStrides + rank); + strides_.insert(strides_.begin(), rawDesc->sizesAndStrides + rank, + rawDesc->sizesAndStrides + rank * 2); + } + + const T *data() const { return data_; } + + int64_t rank() const { return static_cast(sizes_.size()); } + + int64_t size(int64_t dim) const { return sizes_[dim]; } + + int64_t stride(int64_t dim) const { return strides_[dim]; } + + MemRefDescriptor subView(int64_t idx) const { + assert(rank() > 1); + return {data_ + idx * stride(0), + {sizes_.begin() + 1, sizes_.end()}, + {strides_.begin() + 1, strides_.end()}}; + } +}; + +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + +template +std::pair computeDigitInfo(T val) { + if (val == 0) + return {1, false}; + int digits = + std::max(static_cast(std::log10(val >= 0 ? val : -val)), 0) + 1; + return {digits, val < 0}; +} + +template <> +std::pair +computeDigitInfo(FLOAT16 val) { + return computeDigitInfo(val.toFloat32()); +} + +template +std::tuple computeDigitStats(const MemRefDescriptor &desc) { + int maxIntDigits = 0; + int minIntDigits = std::numeric_limits::max(); + bool hasNegative = false; + + if (desc.rank() == 1) { + const T *data = desc.data(); + int64_t stride = desc.stride(0); + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [digits, negative] = computeDigitInfo(data[i * stride]); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, digits); + minIntDigits = std::min(minIntDigits, digits); + } + } else { + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [maxDigits, minDigits, negative] = + computeDigitStats(desc.subView(i)); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, maxDigits); + minIntDigits = std::min(minIntDigits, minDigits); + } + } + + return std::make_tuple(maxIntDigits, minIntDigits, hasNegative); +} + +template +FormatInfo getFormatInfo(const MemRefDescriptor &desc, bool isInt, + bool isSigned, int32_t bitWidth, bool isHex) { + if (isHex) { + assert(bitWidth >= 8 && bitWidth <= 64 && bitWidth % 8 == 0); + return {isInt, isSigned, bitWidth, bitWidth / 4, false, false, true}; + } + auto [maxIntDigits, minIntDigits, hasNegative] = computeDigitStats(desc); + // Fallback to the scientific format for certain cases. + bool scientific; + if (isInt) { + scientific = false; + } else { + scientific = maxIntDigits + 2 + (hasNegative ? 1 : 0) > MAX_FLOAT_WIDTH; + scientific |= maxIntDigits - minIntDigits > 3; + } + return {isInt, isSigned, bitWidth, maxIntDigits, + hasNegative, scientific, false}; +} + +template +void printFormattedElement(std::stringstream &ss, T val, + const FormatInfo &formatInfo) { + // Right now, the GPU's hex float doesn't work correctly. C++ has std:: + // hexfloat, but let's consider only hex integers for now. + if (formatInfo.isHex && formatInfo.isInt) { + ss << "0x" << std::hex << std::setw(formatInfo.maxIntDigits) + << std::setfill('0') << val; + return; + } + + int padding = 0; + auto [digits, negative] = computeDigitInfo(val); + if (!negative && formatInfo.hasNegative) + padding++; + if (formatInfo.scientific) { + ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) + << std::setprecision(FLOAT_PREC) << std::string(padding, ' ') << val; + } else { + padding += formatInfo.maxIntDigits - digits; + ss << std::fixed << std::setprecision(FLOAT_PREC) + << std::string(padding, ' ') << val; + } +} + +// int8_t is printed as char, so use int16_t instead. +template <> +void printFormattedElement(std::stringstream &ss, int8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} + +template <> +void printFormattedElement(std::stringstream &ss, uint8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} + +template <> +void printFormattedElement(std::stringstream &ss, FLOAT16 val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val.toFloat32(), formatInfo); +} + +template +void printToStreamRecursive(const MemRefDescriptor &desc, + std::stringstream &ss, const FormatInfo &formatInfo, + const std::string &linePrefix) { + if (desc.rank() > 1) { + ss << "["; + for (int64_t i = 0; i < desc.size(0); ++i) { + printToStreamRecursive(desc.subView(i), ss, formatInfo, linePrefix + " "); + if (i != desc.size(0) - 1) + ss << ",\n" << linePrefix << " "; + } + ss << "]"; + return; + } + + const T *data = desc.data(); + int64_t stride = desc.stride(0); + int64_t numElems = desc.size(0); + + ss << "["; + if (numElems <= ELEMS_PER_LINE) { + for (int i = 0; i < numElems; i++) { + printFormattedElement(ss, data[i * stride], formatInfo); + if (i != numElems - 1) + ss << ", "; + } + } else { + // TODO: Too many lines? Omit the middle lines. + for (int i = 0; i < numElems; i++) { + printFormattedElement(ss, data[i * stride], formatInfo); + if (i == numElems - 1) + break; + if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { + ss << ",\n" << linePrefix << " "; + } else { + ss << ", "; + } + } + } + ss << "]"; +} + +template +void printToStream(const MemRefDescriptor &desc, std::stringstream &ss, + const FormatInfo &partialFormatInfo, + const std::string &linePrefix) { + FormatInfo formatInfo = getFormatInfo( + desc, partialFormatInfo.isInt, partialFormatInfo.isSigned, + partialFormatInfo.bitWidth, partialFormatInfo.isHex); + printToStreamRecursive(desc, ss, formatInfo, linePrefix); +} + +void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, + int32_t btw, bool isInteger, bool isSignedInteger, bool asHex, + const std::string &linePrefix) { + + FormatInfo partialFormat{.isInt = isInteger, + .isSigned = isSignedInteger, + .bitWidth = btw, + .isHex = asHex}; + if (!isInteger) { + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } + } + if (isSignedInteger) { + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 8: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 1: + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } + } + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 8: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; + case 1: + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } +} + +} // namespace + +extern "C" { + +EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, + const char *message, const char *file, int32_t line, + const char *function) { + if (cond) + return; + fprintf(stderr, "%s:%u: %s: block: [%u, %u, %u] Assertion `%s` failed.\n", + file, line, function, pid0, pid1, pid2, message); + abort(); +} + +// Print the pid prefix like the GPU and interpreter. And vectors are printed +// similar to Torch's printing like the following: +// (1, 0, 0) x: [ -0.4963, -1.7682, 2.0885, 3.1320, -4.3074, 5.6341, +// -6.4901, 7.8964, -8.4556, -9.6323, -10.3489, -11.4017, +// -12.0223, 13.1689, 14.2939, -15.5185] +EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, + int32_t pid2, const char *prefix, + UnrankedMemRefType memref, int32_t btw, + bool isInteger, bool isSigned, + bool asHex) { + std::stringstream ss; + ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix; + std::string linePrefix(ss.str().size(), ' '); + printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, isSigned, + asHex, linePrefix); + ss << "\n"; + std::cout << ss.str() << std::flush; +} + +} // extern "C" diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc new file mode 100644 index 000000000000..ca5b7f0be13b --- /dev/null +++ b/third_party/cpu/triton_cpu.cc @@ -0,0 +1,229 @@ +#include "ScalarizePass/ScalarizeInterfaceImpl.h" +#include "TritonCPUToLLVM/Passes.h" +#include "TritonCPUTransforms/Passes.h" +#include "TritonRaiseBlockPointer/Passes.h" +#include "TritonToTritonCPU/Passes.h" +#include "Xsmm/Passes.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/TargetSelect.h" + +#include +#include +#include + +#if defined(__x86_64__) || defined(__i386__) +#include +#endif +#include +#include + +namespace py = pybind11; + +void init_triton_cpu_passes_ttcpuir(py::module &&m) { + using namespace mlir::triton; + + py::enum_(m, "VecLib") + .value("libsleef", cpu::VecLib::Sleef) + .value("libmvec", cpu::VecLib::Mvec); + + m.def("add_scalarize", [](mlir::PassManager &pm, bool skip_gather_scatter) { + pm.addPass( + mlir::triton::cpu::createScalarizeUsingForOpPass(skip_gather_scatter)); + }); + m.def("add_raise_block_pointer", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createTritonRaiseBlockPointer()); + }); + m.def("add_convert_memory_ops", [](mlir::PassManager &pm, + bool use_gather_scatter) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps(use_gather_scatter)); + }); + m.def("add_convert_ptr_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertPtrOps()); + }); + m.def("add_convert_elementwise_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + }); + m.def("add_convert_elem_manip_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertElemManipOps()); + }); + m.def("add_convert_dot_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotOp()); + }); + m.def("add_convert_histogram_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); + }); + m.def("add_convert_reduction_op", + [](mlir::PassManager &pm, bool use_reduction_op, + bool use_multidim_reduction_op) { + pm.addPass(mlir::triton::cpu::createConvertReductionOp( + use_reduction_op, use_multidim_reduction_op)); + }); + m.def("add_convert_scan_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertScanOp()); + }); + m.def("add_convert_cf_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); + }); + m.def("add_convert_atomic_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); + }); + m.def("add_convert_debug_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDebugOps()); + }); + m.def("add_triton_cpu_canonicalizer", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createCanonicalize()); + }); + m.def("add_optimize_masks", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createOptimizeMasks()); + }); + m.def("add_convert_dot_product", [](mlir::PassManager &pm, + bool useHorizontalSum) { + pm.addPass(mlir::triton::cpu::createConvertDotProduct(useHorizontalSum)); + }); + m.def("add_convert_dot_to_amx", [](mlir::PassManager &pm, bool convertInt8, + bool convertFp16, bool convertBf16) { + pm.addPass(mlir::triton::cpu::createConvertDotToAMX( + convertInt8, convertFp16, convertBf16)); + }); + m.def("add_convert_dot_to_fma", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotToFMA()); + }); + m.def("add_convert_dot_generic", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotGeneric()); + }); + m.def("add_convert_unsupported_ops", + [](mlir::PassManager &pm, bool promote_bf16_to_fp32, + bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { + pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps( + promote_bf16_to_fp32, convert_mixed_precision_matmul, + promote_lib_math_to_fp32)); + }); + m.def("add_decompose_fp_conversions", + [](mlir::PassManager &pm, bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + pm.addPass(mlir::triton::cpu::createDecomposeFpConversions( + decomposeBf16Conversions, decomposeFp8Conversions)); + }); + m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, + unsigned target_rank, bool lower_tensors) { + mlir::VectorTransferToSCFOptions opts; + opts.setTargetRank(target_rank); + opts.enableFullUnroll(full_unroll); + opts.enableLowerTensors(lower_tensors); + pm.addPass(mlir::createConvertVectorToSCFPass(opts)); + }); + m.def("add_lower_vector_multi_dim", [](mlir::PassManager &pm) { + pm.addNestedPass( + mlir::triton::cpu::createLowerMultiReductionPass()); + }); + m.def("add_func_op_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); + }); + m.def("add_program_id_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); + }); + m.def("add_memory_op_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + }); + m.def("add_atomic_ops_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass()); + }); + m.def("add_debug_ops_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); + }); + m.def("add_vector_to_llvmir", + [](mlir::PassManager &pm, bool reassoc_fp_reduction) { + mlir::ConvertVectorToLLVMPassOptions opts; + opts.reassociateFPReductions = reassoc_fp_reduction; + // opts.force32BitVectorIndices = true; + opts.amx = true; + // opts.armNeon = false; + // opts.armSVE = false; + opts.x86Vector = true; + pm.addPass(mlir::createConvertVectorToLLVMPass(opts)); + }); + m.def("add_lower_affine", [](mlir::PassManager &pm) { + pm.addPass(mlir::createLowerAffinePass()); + }); + m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + }); + m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib, + std::set cpu_features) { + pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib, cpu_features)); + }); + m.def("add_math_to_libm", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertMathToLibmPass()); + }); + m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertFuncToLLVMPass()); + }); + m.def("add_convert_vector_to_xsmm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertVectorToXsmm()); + }); + m.def("add_expand_strided_metadata", [](mlir::PassManager &pm) { + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + }); + m.def("add_convert_triton_to_xsmm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertTritonToXsmm()); + }); + m.def("add_loop_to_brgemm_xsmm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createLoopToBrgemmXsmm()); + }); +} + +void init_triton_cpu(py::module &&m) { + auto passes = m.def_submodule("passes"); + init_triton_cpu_passes_ttcpuir(passes.def_submodule("ttcpuir")); + + m.def("enable_amx", []() -> bool { +#if defined(__linux__) && defined(ARCH_REQ_XCOMP_PERM) + // AMX usage requires extended XSTATE which is disabled by default. We + // need to request access to AMX so that XSTATE was dynamically extended + // on the first AMX usage instead of issuing SIGILL. + // See https://www.kernel.org/doc/Documentation/x86/xstate.rst for more + // details. + constexpr int XFEATURE_XTILEDATA = 18; + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) + return false; + return true; +#else + return false; +#endif // __linux__ && ARCH_REQ_XCOMP_PERM + }); + + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); + mlir::registerAMXDialectTranslation(registry); + mlir::func::registerAllExtensions(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + m.def("find_kernel_names", [](mlir::ModuleOp &mod) { + std::vector res; + mod.walk([&](mlir::FunctionOpInterface funcOp) { + // Kernel functions are public and have a body. + if (!funcOp.getFunctionBody().empty() && + funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) + res.push_back(funcOp.getName().str()); + }); + return res; + }); +} diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt index 75f98fa8f73a..bab189bcbdd0 100644 --- a/third_party/nvidia/CMakeLists.txt +++ b/third_party/nvidia/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonNVIDIA ${CMAKE_CURRENT_SOURCE_DIR}/triton_nvidia.cc LINK_LIBS TritonNVIDIAGPUToLLVM NVGPUToLLVM) + target_link_libraries(TritonNVIDIA PRIVATE Python3::Module pybind11::headers) endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 36e73d6b882d..d94be93872de 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -1,5 +1,6 @@ from triton.backends.compiler import BaseBackend, GPUTarget from triton._C.libtriton import ir, passes, llvm, nvidia +from triton.runtime.errors import PTXASError from dataclasses import dataclass import functools @@ -12,6 +13,7 @@ import os import subprocess from pathlib import Path +import sysconfig def min_dot_size(target: GPUTarget): @@ -20,18 +22,19 @@ def min_dot_size(target: GPUTarget): @functools.lru_cache() def _path_to_binary(binary: str): + binary += sysconfig.get_config_var("EXE") paths = [ os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(os.path.dirname(__file__), "bin", binary), ] - for bin in paths: - if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + for path in paths: + if os.path.exists(path) and os.path.isfile(path): + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) if result is not None: version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: - return bin, version.group(1) + return path, version.group(1) raise RuntimeError(f"Cannot find {binary}") @@ -60,12 +63,17 @@ def ptx_get_version(cuda_version) -> int: raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version) -@functools.lru_cache() -def get_features(options): +def get_ptx_version_from_options(options): ptx_version = options.ptx_version if ptx_version is None: _, cuda_version = _path_to_binary("ptxas") ptx_version = ptx_get_version(cuda_version) + return ptx_version + + +@functools.lru_cache() +def get_features(options): + ptx_version = get_ptx_version_from_options(options) # PTX 8.3 is the max version supported by llvm 3a83162168. # @@ -181,8 +189,8 @@ def make_ttir(mod, metadata, opt): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_rewrite_tensor_pointer(pm) - passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) passes.common.add_licm(pm) @@ -222,9 +230,11 @@ def make_ttgir(mod, metadata, opt, capability): if capability // 10 >= 8: passes.ttgpuir.add_optimize_accumulator_init(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.ttgpuir.add_coalesce_async_copy(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) passes.ttgpuir.add_reorder_instructions(pm) @@ -240,8 +250,10 @@ def make_ttgir(mod, metadata, opt, capability): @staticmethod def make_llir(src, metadata, options, capability): + ptx_version = get_ptx_version_from_options(options) + # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta") if num_warp_groups is not None: metadata["num_warps"] *= num_warp_groups mod = src @@ -258,7 +270,8 @@ def make_llir(src, metadata, options, capability): passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) passes.ttgpuir.add_allocate_shared_memory(pm) - nvidia.passes.ttgpuir.add_to_llvmir(pm, capability) + passes.ttgpuir.add_allocate_global_scratch_memory(pm) + nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) passes.convert.add_arith_to_llvmir(pm) passes.common.add_canonicalizer(pm) @@ -291,7 +304,9 @@ def make_llir(src, metadata, options, capability): llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) # Get some metadata - metadata["shared"] = src.get_int_attr("triton_gpu.shared") + metadata["shared"] = src.get_int_attr("ttg.shared") + metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size") + metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment") ret = str(llvm_mod) del llvm_mod del context @@ -299,10 +314,7 @@ def make_llir(src, metadata, options, capability): @staticmethod def make_ptx(src, metadata, opt, capability): - ptx_version = opt.ptx_version - if ptx_version is None: - _, cuda_version = _path_to_binary("ptxas") - ptx_version = ptx_get_version(cuda_version) + ptx_version = get_ptx_version_from_options(opt) triple = 'nvptx64-nvidia-cuda' proc = 'sm_90a' if capability == 90 else f'sm_{capability}' @@ -357,9 +369,9 @@ def make_cubin(src, metadata, opt, capability): else: error = f'`ptxas` failed with error code {e.returncode}' - raise RuntimeError(f'{error}\n' - f'`ptxas` stderr:\n{log}\n' - f'Repro command: {" ".join(ptxas_cmd)}\n') + raise PTXASError(f"{error}\n" + f"`ptxas` stderr:\n{log}\n" + f'Repro command: {" ".join(ptxas_cmd)}\n') with open(fbin, 'rb') as f: cubin = f.read() diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index bb0d86888120..12deb0d1e7a3 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -3,7 +3,6 @@ #include #define PY_SSIZE_T_CLEAN #include -#include // Raises a Python exception and returns false if code is not CUDA_SUCCESS. static bool gpuAssert(CUresult code, const char *file, int line) { diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 38ce62b0c2a2..196f189caa4a 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,11 +1,13 @@ import functools import os +import sysconfig import hashlib import subprocess import tempfile from pathlib import Path from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager +from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver @@ -48,7 +50,8 @@ def library_dirs(): def compile_module_from_src(src, name): key = hashlib.sha256(src.encode("utf-8")).hexdigest() cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + cache_path = cache.get_file(f"{name}.{ext}") if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, "main.c") @@ -56,7 +59,7 @@ def compile_module_from_src(src, name): f.write(src) so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) + cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True) import importlib.util spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) @@ -136,7 +139,7 @@ def format_of(ty): "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", @@ -144,7 +147,7 @@ def format_of(ty): }[ty] args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiKKOOOO" + args_format + format = "iiiKKOOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' internal_args_list = [] @@ -158,7 +161,8 @@ def format_of(ty): internal_args_list.append(f"_arg{i}") # generate glue code - params = [i for i in signature.keys() if i not in constants] + params = [f"&arg{i}" for i in signature.keys() if i not in constants] + params.append("&global_scratch") src = f""" #include \"cuda.h\" #include @@ -205,8 +209,8 @@ def format_of(ty): return cuLaunchKernelExHandle; }} -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(params)} }}; if (gridX*gridY*gridZ > 0) {{ if (num_ctas == 1) {{ CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); @@ -275,6 +279,9 @@ def format_of(ty): PyErr_Format(PyExc_ValueError, "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); ptr_info.valid = false; + }} else if (status != CUDA_SUCCESS) {{ + CUDA_CHECK(status); // Catch any other cuda API errors + ptr_info.valid = false; }} ptr_info.dev_ptr = dev_ptr; Py_DECREF(ret); // Thanks ChatGPT! @@ -331,7 +338,22 @@ def format_of(ty): return (CUtensorMap*)(ptr_as_uint); }} +static void ensureCudaContext() {{ + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) {{ + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + }} +}} + static PyObject* launch(PyObject* self, PyObject* args) {{ + // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes + ensureCudaContext(); + int gridX, gridY, gridZ; uint64_t _stream; uint64_t _function; @@ -339,10 +361,12 @@ def format_of(ty): PyObject *launch_exit_hook = NULL; PyObject *kernel_metadata = NULL; PyObject *launch_metadata = NULL; + PyObject *global_scratch_obj = NULL; {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, + &_stream, &_function, &global_scratch_obj, &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {args_list})) {{ + &launch_enter_hook, &launch_exit_hook{args_list})) {{ return NULL; }} @@ -361,11 +385,20 @@ def format_of(ty): return NULL; }} + CUdeviceptr global_scratch = 0; + if (global_scratch_obj != Py_None) {{ + DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1); + if (!global_scratch_info.valid) {{ + return NULL; + }} + global_scratch = global_scratch_info.dev_ptr; + }} + // raise exception asap {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL; @@ -380,9 +413,7 @@ def format_of(ty): }} - // return None - Py_INCREF(Py_None); - return Py_None; + Py_RETURN_NONE; }} static PyMethodDef ModuleMethods[] = {{ @@ -421,9 +452,17 @@ def __init__(self, src, metadata): src = make_launcher(constants, signature, ids) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch - - def __call__(self, *args, **kwargs): - self.launch(*args, **kwargs) + self.global_scratch_size = metadata.global_scratch_size + self.global_scratch_align = metadata.global_scratch_align + + def __call__(self, gridX, gridY, gridZ, stream, function, *args): + if self.global_scratch_size > 0: + grid_size = gridX * gridY * gridZ + alloc_size = grid_size * self.global_scratch_size + global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream) + else: + global_scratch = None + self.launch(gridX, gridY, gridZ, stream, function, global_scratch, *args) class CudaDriver(GPUDriver): diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h index 30bfaea7d9eb..8cd8a180ca49 100644 --- a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h @@ -26,6 +26,8 @@ createDecomposeUnsupportedConversionsPass(); std::unique_ptr> createConvertTritonGPUToLLVMPass(); std::unique_ptr> createConvertTritonGPUToLLVMPass(int32_t computeCapability); +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, int32_t ptxVersion); #define GEN_PASS_REGISTRATION #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h.inc" diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td index 07624c72d760..fd6f8e0a280f 100644 --- a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td @@ -20,7 +20,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" "mlir::gpu::GPUDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect", - "mlir::tensor::TensorDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", @@ -30,6 +29,9 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, + Option<"ptxVersion", "ptx-version", + "int32_t", /*default*/"80", + "PTX version">, ]; } diff --git a/third_party/nvidia/language/cuda/_experimental_tma.py b/third_party/nvidia/language/cuda/_experimental_tma.py index 5677810194d9..94cc5355db23 100644 --- a/third_party/nvidia/language/cuda/_experimental_tma.py +++ b/third_party/nvidia/language/cuda/_experimental_tma.py @@ -29,7 +29,7 @@ def experimental_device_tensormap_create1d( load_size: core.tensor, global_size: core.tensor, element_ty: core.dtype, - _builder: ir.builder, + _builder: ir.builder = None, ): load_size = core._constexpr_to_value(load_size) global_size = semantic.to_tensor(global_size, _builder) @@ -58,7 +58,7 @@ def experimental_device_tensormap_create2d( load_size: Sequence[core.constexpr], global_size: Sequence[core.tensor], element_ty: core.dtype, - _builder: ir.builder, + _builder: ir.builder = None, ): assert len(load_size) == 2 assert len(global_size) == 2 @@ -104,5 +104,5 @@ def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size): @core.builtin -def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder): +def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder = None): semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 746b910e1e52..b6532836ebc1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -33,28 +33,6 @@ using namespace mlir; using namespace mlir::triton; namespace { -struct BarrierOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(mlir::gpu::BarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - if (op->hasAttr("bar_id")) { - // llvm.nvvm.barrier0 doesn't support bar_id and num_threads attributes, - // so we have to lower it to ptx manually. - auto barId = op->getAttrOfType("bar_id").getInt(); - auto numThreads = op->getAttrOfType("num_threads").getInt(); - barSync(rewriter, op, barId, numThreads); - rewriter.eraseOp(op); - return success(); - } - // Otherwise we let the default lowering handle it - return failure(); - } -}; - struct FenceAsyncSharedOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -193,7 +171,6 @@ struct WaitBarrierOpConversion void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index a944da1c83f1..91ddfc2700c8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,8 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp - DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp DotOpToLLVM/WGMMA.cpp DotOpToLLVM.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 54371d063fb1..d4613fef4321 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -10,7 +10,6 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -using mlir::isLayoutMmaV1; using ::mlir::LLVM::getMultiDimOffset; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getWrappedMultiDimOffset; @@ -18,29 +17,19 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; // Forward declarations -namespace SharedToDotOperandMMAv1 { - -Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, - Value thread, Location loc, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Type resultTy); - -} // namespace SharedToDotOperandMMAv1 - -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); -} +} // namespace SharedToDotOperandMMAv2OrV3 namespace { @@ -88,25 +77,20 @@ struct LocalLoadOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( + + if (isOuter) { + assert(false && "MMA Layout does not support outer product"); + return res; + } + + if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 + if (mmaLayout.isHopper()) + assert(dotOperandLayout.getOpIdx() == 0 && + "MMAv3 can only have operand $b on shared memory"); + + res = SharedToDotOperandMMAv2OrV3::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); - } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 - bool isMMAv1Row = mmaLayout.getMMAv1IsRow(dotOperandLayout.getOpIdx()); - auto srcSharedLayout = - cast(src.getType().getEncoding()); - - // Can only convert [1, 0] to row or [0, 1] to col for now - if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) || - (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) { - llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n"; - return Value(); - } - - res = SharedToDotOperandMMAv1::convertLayout( - dotOperandLayout.getOpIdx(), src, smemObj, getThreadId(rewriter, loc), - loc, typeConverter, rewriter, dst.getType()); } else { assert(false && "Unsupported mma layout found"); } @@ -161,8 +145,6 @@ struct ConvertLayoutOpConversion dstLayout)) { if (shouldUseDistSmem(srcLayout, dstLayout)) return lowerDistToDistWithDistSmem(op, adaptor, rewriter, targetInfo); - if (isLayoutMmaV1(srcLayout) || isLayoutMmaV1(dstLayout)) - return lowerDistributedToDistributed(op, adaptor, rewriter, targetInfo); } if (isa(srcLayout) && isa(dstLayout)) { @@ -173,194 +155,6 @@ struct ConvertLayoutOpConversion } private: - // shared memory rd/st for blocked or mma layout with data padding - void processReplica(Location loc, ConversionPatternRewriter &rewriter, - bool stNotRd, RankedTensorType type, - ArrayRef numCTAsEachRep, - ArrayRef multiDimRepId, unsigned vec, - ArrayRef paddedRepShape, - ArrayRef origRepShape, - ArrayRef outOrd, SmallVector &vals, - Value smemBase) const { - auto accumNumCTAsEachRep = product(numCTAsEachRep); - auto layout = type.getEncoding(); - auto rank = type.getRank(); - auto sizePerThread = getSizePerThread(layout); - auto accumSizePerThread = product(sizePerThread); - SmallVector numCTATiles(rank); - auto shapePerCTATile = getShapePerCTATile(layout); - auto shapePerCTA = getShapePerCTA(layout, type.getShape()); - auto order = getOrder(layout); - for (unsigned d = 0; d < rank; ++d) { - numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); - } - auto elemTy = type.getElementType(); - bool isInt1 = elemTy.isInteger(1); - bool isPtr = isa(elemTy); - auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); - if (isInt1) - elemTy = IntegerType::get(elemTy.getContext(), 8); - else if (isPtr) - elemTy = IntegerType::get(elemTy.getContext(), 64); - - auto llvmElemTy = getTypeConverter()->convertType(elemTy); - - for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { - auto multiDimCTAInRepId = - getMultiDimIndex(ctaId, numCTAsEachRep, order); - SmallVector multiDimCTAId(rank); - for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { - auto d = it.index(); - multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); - } - - auto linearCTAId = - getLinearIndex(multiDimCTAId, numCTATiles, order); - // TODO: This is actually redundant index calculation, we should - // consider of caching the index calculation result in case - // of performance issue observed. - for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { - SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, - multiDimCTAInRepId, shapePerCTATile); - SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( - rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, - shapePerCTA); - Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, - paddedRepShape, outOrd); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); - auto vecTy = vec_ty(llvmElemTy, vec); - ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { - auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; - if (isInt1) - currVal = zext(llvmElemTy, currVal); - else if (isPtr) - currVal = ptrtoint(llvmElemTy, currVal); - valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); - } - store(valVec, ptr); - } else { - Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); - if (isInt1) - currVal = icmp_ne(currVal, - rewriter.create( - loc, i8_ty, rewriter.getI8IntegerAttr(0))); - else if (isPtr) - currVal = inttoptr(llvmElemTyOrig, currVal); - vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; - } - } - } - } - } - - // The MMAV1's result is quite different from the existing "Replica" - // structure, add a new simple but clear implementation for it to avoid - // modifying the logic of the existing one. - void processReplicaForMMAV1(Location loc, ConversionPatternRewriter &rewriter, - bool stNotRd, RankedTensorType type, - ArrayRef multiDimRepId, unsigned vec, - ArrayRef paddedRepShape, - ArrayRef outOrd, - SmallVector &vals, Value smemBase, - ArrayRef shape, - bool isDestMma = false) const { - unsigned accumNumCTAsEachRep = 1; - auto typeConverter = getTypeConverter(); - auto layout = type.getEncoding(); - NvidiaMmaEncodingAttr mma = dyn_cast(layout); - auto sliceLayout = dyn_cast(layout); - if (sliceLayout) - mma = cast(sliceLayout.getParent()); - - auto order = getOrder(layout); - auto rank = type.getRank(); - int accumSizePerThread = vals.size(); - - SmallVector numCTAs(rank, 1); - SmallVector numCTAsEachRep(rank, 1); - SmallVector shapePerCTATile = getShapePerCTATile(layout, shape); - SmallVector shapePerCTA = getShapePerCTA(layout, shape); - auto elemTy = typeConverter->convertType(type.getElementType()); - - int ctaId = 0; - - auto multiDimCTAInRepId = - getMultiDimIndex(ctaId, numCTAsEachRep, order); - SmallVector multiDimCTAId(rank); - for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { - auto d = it.index(); - multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); - } - - std::vector, Value>> coord2valT( - accumSizePerThread); - bool needTrans = outOrd[0] != 0; - if (sliceLayout || isDestMma) - needTrans = false; - - vec = needTrans ? 2 : 1; - { - // We need to transpose the coordinates and values here to enable vec=2 - // when store to smem. - std::vector, Value>> coord2val( - accumSizePerThread); - for (unsigned elemId = 0; elemId < accumSizePerThread; ++elemId) { - // TODO[Superjomn]: Move the coordinate computation out of loop, it is - // duplicate in Volta. - SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, - multiDimCTAInRepId, shapePerCTATile); - coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]); - } - - if (needTrans) { - // do transpose - int numM = mma.getMMAv1NumOuter(shapePerCTA, 0); - int numN = accumSizePerThread / numM; - - for (int r = 0; r < numM; r++) { - for (int c = 0; c < numN; c++) { - coord2valT[r * numN + c] = std::move(coord2val[c * numM + r]); - } - } - } else { - coord2valT = std::move(coord2val); - } - } - - // Now the coord2valT has the transposed and contiguous elements(with - // vec=2), the original vals is not needed. - for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { - auto coord = coord2valT[elemId].first; - Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - Value ptr = gep(elemPtrTy, elemTy, smemBase, offset); - auto vecTy = vec_ty(elemTy, vec); - ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); - if (stNotRd) { - Value valVec = undef(vecTy); - for (unsigned v = 0; v < vec; ++v) { - auto currVal = coord2valT[elemId + v].second; - valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); - } - store(valVec, ptr); - } else { - Value valVec = load(vecTy, ptr); - for (unsigned v = 0; v < vec; ++v) { - Value currVal = extract_element(elemTy, valVec, i32_val(v)); - vals[elemId + v] = currVal; - } - } - } - } - LogicalResult lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -447,96 +241,6 @@ struct ConvertLayoutOpConversion return success(); } - // blocked/mma -> blocked/mma. - // Data padding in shared memory to avoid bank conflict. - LogicalResult - lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) const { - auto loc = op.getLoc(); - auto typeConverter = getTypeConverter(); - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); - auto shape = dstTy.getShape(); - unsigned rank = dstTy.getRank(); - SmallVector numReplicates(rank); - SmallVector inNumCTAsEachRep(rank); - SmallVector outNumCTAsEachRep(rank); - SmallVector inNumCTAs(rank); - SmallVector outNumCTAs(rank); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); - auto shapePerCTA = getShapePerCTA(srcLayout, shape); - - for (unsigned d = 0; d < rank; ++d) { - unsigned inPerCTA = - std::min(shapePerCTA[d], srcShapePerCTATile[d]); - unsigned outPerCTA = - std::min(shapePerCTA[d], dstShapePerCTATile[d]); - unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); - numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); - inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; - outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; - assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); - inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); - outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); - } - // Potentially we need to store for multiple CTAs in this replication - auto accumNumReplicates = product(numReplicates); - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto scratchConfig = - getScratchConfigForCvt(op.getSrc().getType(), op.getType()); - unsigned inVec = scratchConfig.inVec; - unsigned outVec = scratchConfig.outVec; - const auto &origRepShape = scratchConfig.repShape; - const auto &paddedRepShape = scratchConfig.paddedRepShape; - - unsigned outElems = getTotalElemsPerThread(dstTy); - auto outOrd = getOrder(dstLayout); - SmallVector outVals(outElems); - - for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { - auto multiDimRepId = - getMultiDimIndex(repId, numReplicates, outOrd); - if (repId != 0) { - barrier(); - } - - if (isLayoutMmaV1(srcLayout)) - processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy, - multiDimRepId, inVec, paddedRepShape, outOrd, - vals, smemBase, shape); - else - processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, - multiDimRepId, inVec, paddedRepShape, origRepShape, - outOrd, vals, smemBase); - - barrier(); - - if (isLayoutMmaV1(dstLayout)) - processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy, - multiDimRepId, outVec, paddedRepShape, outOrd, - outVals, smemBase, shape, /*isDestMma=*/true); - else - processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, - outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, - origRepShape, outOrd, outVals, smemBase); - } - - Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); - rewriter.replaceOp(op, result); - - return success(); - } - // Convert from accumulator MMA layout to 8bit dot operand layout. // The conversion logic is taken from: // https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45 @@ -631,64 +335,6 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } - - if (isMmaToDotShortcut(srcTy, dstTy)) { - // get source values - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned elems = getTotalElemsPerThread(srcTy); - Type elemTy = - this->getTypeConverter()->convertType(srcTy.getElementType()); - // for the destination type, we need to pack values together - // so they can be consumed by tensor core operations - SmallVector vecVals; - SmallVector types; - // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer - // instructions to pack & unpack sub-word integers. A workaround is to - // store the results of ldmatrix in i32 - auto elemSize = elemTy.getIntOrFloatBitWidth(); - if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { - auto fold = 32 / elemSize; - for (unsigned i = 0; i < elems; i += fold) { - Value val = i32_val(0); - for (unsigned j = 0; j < fold; j++) { - auto ext = - shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); - val = or_(i32_ty, val, ext); - } - vecVals.push_back(val); - } - elems = elems / (32 / elemSize); - types = SmallVector(elems, i32_ty); - } else { - unsigned vecSize = std::max(32 / elemSize, 1); - Type vecTy = vec_ty(elemTy, vecSize); - types = SmallVector(elems / vecSize, vecTy); - for (unsigned i = 0; i < elems; i += vecSize) { - Value packed = rewriter.create(loc, vecTy); - for (unsigned j = 0; j < vecSize; j++) - packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(packed); - } - } - - // This needs to be ordered the same way that - // ldmatrix.x4 would order it - // TODO: this needs to be refactor so we don't - // implicitly depends on how emitOffsetsForMMAV2 - // is implemented - SmallVector reorderedVals; - for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); - } - - Value view = packLLElements(loc, getTypeConverter(), reorderedVals, - rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } return failure(); } @@ -734,11 +380,12 @@ struct LocalAllocOpConversion SmallVector shape = convertType(srcTy.getShape()); auto order = sharedLayout.getOrder(); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, + swizzleByteSize)) { + return failure(); + } auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, shape, order, swizzleByteSize); - if (!layout.has_value()) - return failure(); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); auto smemPtrTy = ptr_ty(ctx, 3); @@ -748,23 +395,22 @@ struct LocalAllocOpConversion auto kBlock = str_attr("block"); Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout->getInDimSize(kLane)); + Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); Value laneId = urem(threadId, threadsPerWarp); Value warpId = udiv(threadId, threadsPerWarp); - auto regBase = applyLinearLayout(loc, rewriter, *layout, + auto regBase = applyLinearLayout(loc, rewriter, layout, {{kRegister, i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, {kBlock, i32_val(0)}})[0] .second; auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto srcVec = layout->getNumConsecutiveInOut(); + auto srcVec = layout.getNumConsecutiveInOut(); Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); for (int i = 0; i < srcVals.size(); i += srcVec) { auto regIdx = - layout - ->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] .second; Value offset = xor_(regBase, i32_val(regIdx)); auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); @@ -813,9 +459,6 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { // For now give ConvertLayoutOpConversion higher benefit, I can split before // merging - // - // TODO(jlebar): lowerDistributedToDistributed does not get hit in any - // testcases. Is this dead code? Does the benefit need to be increased? patterns.add(typeConverter, targetInfo, benefit); // Same default benefit patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp deleted file mode 100644 index 6847c0550cbb..000000000000 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp +++ /dev/null @@ -1,354 +0,0 @@ -#include "Utility.h" - -using CoordTy = SmallVector; -using ValueTable = std::map, std::pair>; -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; -using ::mlir::LLVM::getStridesFromShapeAndOrder; -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getContigPerThread; -using ::mlir::triton::gpu::getOrder; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getSizePerThread; -using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::SharedEncodingAttr; - -// Compute the offset of the matrix to load. -// Returns offsetAM, offsetAK, offsetBN, offsetBK. -// NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at -// the same time in the usage in convert_layout[shared->dot_op], we leave -// the noexist info to be 0 and only use the desired argument from the -// composed result. In this way we want to retain the original code -// structure in convert_mma884 method for easier debugging. -static std::tuple -computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef fpw, - ArrayRef spw, ArrayRef rep, - ConversionPatternRewriter &rewriter, Location loc, - Type resultTy) { - auto *ctx = rewriter.getContext(); - auto wpt = cast( - cast( - cast(resultTy).getEncoding()) - .getParent()) - .getWarpsPerCTA(); - - Value _1 = i32_val(1); - Value _3 = i32_val(3); - Value _4 = i32_val(4); - Value _16 = i32_val(16); - Value _32 = i32_val(32); - - Value lane = urem(threadId, _32); - Value warp = udiv(threadId, _32); - - // warp offset - Value warp0 = urem(warp, i32_val(wpt[0])); - Value warp12 = udiv(warp, i32_val(wpt[0])); - Value warp1 = urem(warp12, i32_val(wpt[1])); - Value warpMOff = mul(warp0, i32_val(spw[0])); - Value warpNOff = mul(warp1, i32_val(spw[1])); - // Quad offset - Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0])); - Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1])); - // Pair offset - Value pairMOff = udiv(urem(lane, _16), _4); - pairMOff = urem(pairMOff, i32_val(fpw[0])); - pairMOff = mul(pairMOff, _4); - Value pairNOff = udiv(urem(lane, _16), _4); - pairNOff = udiv(pairNOff, i32_val(fpw[0])); - pairNOff = urem(pairNOff, i32_val(fpw[1])); - pairNOff = mul(pairNOff, _4); - // scale - pairMOff = mul(pairMOff, i32_val(rep[0] / 2)); - quadMOff = mul(quadMOff, i32_val(rep[0] / 2)); - pairNOff = mul(pairNOff, i32_val(rep[1] / 2)); - quadNOff = mul(quadNOff, i32_val(rep[1] / 2)); - // Quad pair offset - Value laneMOff = add(pairMOff, quadMOff); - Value laneNOff = add(pairNOff, quadNOff); - // A offset - Value offsetAM = add(warpMOff, laneMOff); - Value offsetAK = and_(lane, _3); - // B offset - Value offsetBN = add(warpNOff, laneNOff); - Value offsetBK = and_(lane, _3); - // i indices - Value offsetCM = add(and_(lane, _1), offsetAM); - if (isARow) { - offsetAM = add(offsetAM, urem(threadId, _4)); - offsetAK = i32_val(0); - } - if (!isBRow) { - offsetBN = add(offsetBN, urem(threadId, _4)); - offsetBK = i32_val(0); - } - - return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK); -} - -static Value loadA(Value tensor, const SharedMemoryObject &smemObj, - Value thread, Location loc, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Type resultTy) { - static constexpr std::array fpw{{2, 2, 1}}; - auto mmaEncoding = cast( - cast( - cast(resultTy).getEncoding()) - .getParent()); - auto wpt = mmaEncoding.getWarpsPerCTA(); - - auto *ctx = rewriter.getContext(); - auto tensorTy = cast(tensor.getType()); - auto sharedLayout = cast(tensorTy.getEncoding()); - auto shape = tensorTy.getShape(); - auto order = sharedLayout.getOrder(); - - Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); - Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); - - bool isARow = order[0] != 0; - auto resultEncoding = cast( - cast(resultTy).getEncoding()); - auto [offsetAM, offsetAK, _3, _4] = computeOffsets( - thread, isARow, false, fpw, - mmaEncoding.getMMAv1ShapePerWarp(resultEncoding.getOpIdx()), - mmaEncoding.getMMAv1Rep(resultEncoding.getOpIdx()), rewriter, loc, - resultTy); - - int vecA = sharedLayout.getVec(); - - auto strides = smemObj.strides; - Value strideAM = isARow ? strides[0] : i32_val(1); - Value strideAK = isARow ? i32_val(1) : strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; - Value strideA1 = isARow ? strideAM : strideAK; - - int strideRepM = wpt[0] * fpw[0] * 8; - int strideRepK = 1; - - // swizzling - int perPhaseA = sharedLayout.getPerPhase(); - int maxPhaseA = sharedLayout.getMaxPhase(); - int stepA0 = isARow ? strideRepK : strideRepM; - int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1); - int NK = shape[1]; - - // pre-compute pointer lanes - Value offA0 = isARow ? offsetAK : offsetAM; - Value offA1 = isARow ? offsetAM : offsetAK; - Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA)); - offA0 = add(offA0, cSwizzleOffset); - SmallVector offA(numPtrA); - for (int i = 0; i < numPtrA; i++) { - Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM))); - offA0I = udiv(offA0I, i32_val(vecA)); - offA0I = xor_(offA0I, phaseA); - offA0I = mul(offA0I, i32_val(vecA)); - offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1)); - } - - Type elemX2Ty = vec_ty(f16_ty, 2); - Type elemTy = f16_ty; - if (tensorTy.getElementType().isBF16()) { - elemX2Ty = vec_ty(bf16_ty, 2); - elemTy = bf16_ty; - } - - // prepare arguments - SmallVector ptrA(numPtrA); - - std::map, std::pair> has; - for (int i = 0; i < numPtrA; i++) - ptrA[i] = gep(ptr_ty(ctx, 3), f16_ty, smemBase, offA[i]); - - auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; - }; - auto loadA = [&](int m, int k) { - int offidx = (isARow ? k / 4 : m) % numPtrA; - Value thePtrA = gep(ptr_ty(ctx, 3), elemTy, smemBase, offA[offidx]); - - int stepAM = isARow ? m : m / numPtrA * numPtrA; - int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k; - Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), - mul(i32_val(stepAK), strideAK)); - Value pa = gep(ptr_ty(ctx, 3), elemTy, thePtrA, offset); - Type vecTy = vec_ty(i32_ty, std::max(vecA / 2, 1)); - Type aPtrTy = ptr_ty(ctx, 3); - Value ha = load(vecTy, bitcast(pa, aPtrTy)); - // record lds that needs to be moved - Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty); - Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty); - ld(has, m, k, ha00, ha01); - - if (vecA > 4) { - Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty); - Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty); - if (isARow) - ld(has, m, k + 4, ha10, ha11); - else - ld(has, m + 1, k, ha10, ha11); - } - }; - - bool isARow_ = mmaEncoding.getMMAv1IsRow(resultEncoding.getOpIdx()); - bool isAVec4 = mmaEncoding.getMMAv1IsVec4(resultEncoding.getOpIdx()); - unsigned numM = - mmaEncoding.getMMAv1NumOuter(shape, resultEncoding.getOpIdx()); - for (unsigned k = 0; k < NK; k += 4) - for (unsigned m = 0; m < numM / 2; ++m) - if (!has.count({m, k})) - loadA(m, k); - - SmallVector elems; - elems.reserve(has.size() * 2); - for (auto item : has) { // has is a map, the key should be ordered. - elems.push_back(bitcast(item.second.first, i32_ty)); - elems.push_back(bitcast(item.second.second, i32_ty)); - } - - Value res = packLLElements(loc, typeConverter, elems, rewriter, resultTy); - return res; -} - -static Value loadB(Value tensor, const SharedMemoryObject &smemObj, - Value thread, Location loc, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Type resultTy) { - static constexpr std::array fpw{{2, 2, 1}}; - auto mmaEncoding = cast( - cast( - cast(resultTy).getEncoding()) - .getParent()); - auto wpt = mmaEncoding.getWarpsPerCTA(); - // smem - auto strides = smemObj.strides; - - auto *ctx = rewriter.getContext(); - auto tensorTy = cast(tensor.getType()); - auto sharedLayout = cast(tensorTy.getEncoding()); - - auto shape = tensorTy.getShape(); - auto order = sharedLayout.getOrder(); - - Value smem = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); - bool isBRow = order[0] != 0; // is row-major in shared memory layout - // isBRow_ indicates whether B is row-major in DotOperand layout - auto resultEncoding = cast( - cast(resultTy).getEncoding()); - - int vecB = sharedLayout.getVec(); - Value strideBN = isBRow ? i32_val(1) : strides[1]; - Value strideBK = isBRow ? strides[0] : i32_val(1); - Value strideB0 = isBRow ? strideBN : strideBK; - Value strideB1 = isBRow ? strideBK : strideBN; - int strideRepN = wpt[1] * fpw[1] * 8; - int strideRepK = 1; - - auto [_3, _4, offsetBN, offsetBK] = computeOffsets( - thread, false, isBRow, fpw, - mmaEncoding.getMMAv1ShapePerWarp(resultEncoding.getOpIdx()), - mmaEncoding.getMMAv1Rep(resultEncoding.getOpIdx()), rewriter, loc, - resultTy); - - // swizzling - int perPhaseB = sharedLayout.getPerPhase(); - int maxPhaseB = sharedLayout.getMaxPhase(); - int stepB0 = isBRow ? strideRepN : strideRepK; - int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1); - int NK = shape[0]; - - Value offB0 = isBRow ? offsetBN : offsetBK; - Value offB1 = isBRow ? offsetBK : offsetBN; - Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB)); - Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); - - offB0 = add(offB0, cSwizzleOffset); - SmallVector offB(numPtrB); - for (int i = 0; i < numPtrB; ++i) { - Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4))); - offB0I = udiv(offB0I, i32_val(vecB)); - offB0I = xor_(offB0I, phaseB); - offB0I = mul(offB0I, i32_val(vecB)); - offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); - } - - Type elemTy = f16_ty; - Type elemX2Ty = vec_ty(f16_ty, 2); - if (tensorTy.getElementType().isBF16()) { - elemTy = bf16_ty; - elemX2Ty = vec_ty(bf16_ty, 2); - } - - SmallVector ptrB(numPtrB); - ValueTable hbs; - for (int i = 0; i < numPtrB; ++i) - ptrB[i] = gep(ptr_ty(ctx, 3), f16_ty, smem, offB[i]); - - auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) { - vals[{m, k}] = {val0, val1}; - }; - - auto loadB = [&](int n, int K) { - int offidx = (isBRow ? n : K / 4) % numPtrB; - Value thePtrB = ptrB[offidx]; - - int stepBN = isBRow ? n / numPtrB * numPtrB : n; - int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB); - Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), - mul(i32_val(stepBK), strideBK)); - Value pb = gep(ptr_ty(ctx, 3), elemTy, thePtrB, offset); - - Type vecTy = vec_ty(i32_ty, std::max(vecB / 2, 1)); - Value hb = load(vecTy, bitcast(pb, ptr_ty(ctx, 3))); - // record lds that needs to be moved - Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty); - Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty); - ld(hbs, n, K, hb00, hb01); - if (vecB > 4) { - Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty); - Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty); - if (isBRow) - ld(hbs, n + 1, K, hb10, hb11); - else - ld(hbs, n, K + 4, hb10, hb11); - } - }; - - bool isBRow_ = mmaEncoding.getMMAv1IsRow(resultEncoding.getOpIdx()); - assert(isBRow == isBRow_ && "B need smem isRow"); - bool isBVec4 = mmaEncoding.getMMAv1IsVec4(resultEncoding.getOpIdx()); - unsigned numN = - mmaEncoding.getMMAv1NumOuter(shape, resultEncoding.getOpIdx()); - for (unsigned k = 0; k < NK; k += 4) - for (unsigned n = 0; n < numN / 2; ++n) { - if (!hbs.count({n, k})) - loadB(n, k); - } - - SmallVector elems; - for (auto &item : hbs) { // has is a map, the key should be ordered. - elems.push_back(bitcast(item.second.first, i32_ty)); - elems.push_back(bitcast(item.second.second, i32_ty)); - } - - Value res = packLLElements(loc, typeConverter, elems, rewriter, resultTy); - return res; -} - -namespace SharedToDotOperandMMAv1 { - -Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, - Value thread, Location loc, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, Type resultTy) { - if (opIdx == 0) - return loadA(tensor, smemObj, thread, loc, typeConverter, rewriter, - resultTy); - else { - assert(opIdx == 1); - return loadB(tensor, smemObj, thread, loc, typeConverter, rewriter, - resultTy); - } -} - -} // namespace SharedToDotOperandMMAv1 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp similarity index 84% rename from third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp rename to third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 73c21cae6de2..b9aac96cbf91 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -14,6 +14,7 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; // Data loader for mma.16816 instruction. @@ -25,6 +26,7 @@ class MMA16816SmemLoader { ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, int elemBytes, + int mmaElemBytes, bool isHopper, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc); @@ -67,6 +69,8 @@ class MMA16816SmemLoader { int perPhase; int maxPhase; int elemBytes; + int mmaElemBytes; + bool isHopper; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -203,10 +207,10 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // vecWidth // <-------> // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\ -// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height -// ... | -// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ +// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | +// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height +// ... | +// t28 ... t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ // --------------------------------------------- || -------------------------------------------- // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 // t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 @@ -223,11 +227,6 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, Value cSwizzleOffset) { Value warpB = multiDimWarpId[0]; Value warpOff = kOrder == 2 ? multiDimWarpId[1] : multiDimWarpId[2]; - int cTileShape = tileShape[order[0]]; - int sTileShape = tileShape[order[1]]; - if (!needTrans) { - std::swap(cTileShape, sTileShape); - } SmallVector offs(numPtrs); @@ -236,7 +235,6 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value lane, int laneHeight = 8; int quadWidth = laneWidth * kWidth; int quadHeight = laneHeight; - int numQuadI = 2; // outer index base Value iBase = udiv(lane, i32_val(laneWidth)); @@ -364,6 +362,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { // base pointers + // ptrs[k][...] holds `vec` pointers each for (quadK == k) std::array, 2> ptrs; for (int i = 0; i < vecWidth; i++) ptrs[0][i] = getPtr(ptrIdx + i); @@ -383,11 +382,13 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } + // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) SmallVector> vptrs(4, SmallVector(vecWidth)); + // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); @@ -402,7 +403,10 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + // Hopper may not contain 32b contiguously along k-dimension + int kBits = isHopper ? (8 * elemBytes * kWidth) : 32; + int vecSize = kBits / canonBits; + retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; @@ -421,8 +425,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (isActualTrans) std::swap(retElems[1], retElems[2]); - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; + + auto iTy = isHopper ? int_ty(kBits) : i32_ty; + + return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), + bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; } } @@ -432,8 +439,9 @@ MMA16816SmemLoader::MMA16816SmemLoader( ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, - const LLVMTypeConverter *typeConverter, const Location &loc) + int elemBytes, int mmaElemBytes, bool isHopper, + ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, + const Location &loc) : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), @@ -441,17 +449,29 @@ MMA16816SmemLoader::MMA16816SmemLoader( matShape(matShape.begin(), matShape.end()), multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), - rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { + mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), + loc(loc), ctx(rewriter.getContext()) { + // If the current elemType width is different from the MMA elemType width, + // i.e. width-changing casting is done later in DotOp Layout... then, in the + // case of Hopper, the number of bytes held by each thread after loading will + // no longer be 32B. Hence this flag is required to stipulate different logic. + bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; stridedSmemOffset = smemStrides[order[1]]; smemBatchOffset = smemStrides[order[2]]; - vecWidth = 4 / elemBytes; + if (isHopperWidthChange) { + vecWidth = 4 / mmaElemBytes; + } else { + vecWidth = 4 / elemBytes; + } // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; nonKOrder = (kOrder == 2) ? 1 : 2; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); + canUseLdmatrix = canUseLdmatrix && !isHopperWidthChange; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, @@ -505,24 +525,60 @@ Type getSharedMemTy(Type argType) { } Value composeValuesToDotOperandLayoutStruct( - const ValueTable &vals, int batch, int n0, int n1, + const ValueTable &vals, int batch, int repOuter, int repK, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, Type eltTy, int kWidth, bool isHopper, + bool isA) { + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + assert(32 >= bitwidth && "only support 32-bit or less"); + auto numElemsPerVec = isHopper ? kWidth : 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + // FIXME: [DOT LL] + // `kWidth` specifies the number of contiguous elements each thread will load. + // Loaded elements are packed into a vector of int32, which will then be + // unpacked into individual elements. + // `kIters` specifies the number of contiguous int32 elements each thread + // should load. + // `kSize` specifies the total number of int32 elements each thread should + // load. + int kIters = isHopper ? 1 : kWidth / (32 / bitwidth); + int kSize = repK >= kIters ? repK * 2 : kIters; + std::vector elems; - for (int b = 0; b < batch; ++b) - for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + auto unpackVec = [&](int b, int m, int k) { + for (int kIter = 0; kIter < kIters; ++kIter) { + auto val = vals.at({b, m, (k + kIter) % kSize}); + auto vec = bitcast(val, vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + elems.push_back(extract_element(eltTy, vec, i32_val(i))); } + } + }; + + // Loading A tile is different from loading B tile since each tile of A is + // 16x16 while B is 16x8. + if (isA) { + for (int b = 0; b < batch; ++b) + for (int m = 0; m < repOuter; ++m) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, 2 * m, kIters * 2 * k); + unpackVec(b, 2 * m + 1, kIters * 2 * k); + unpackVec(b, 2 * m, kIters * (2 * k + 1)); + unpackVec(b, 2 * m + 1, kIters * (2 * k + 1)); + } + } else { + for (int b = 0; b < batch; ++b) + for (int n = 0; n < repOuter; ++n) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, n, kIters * 2 * k); + unpackVec(b, n, kIters * (2 * k + 1)); + } + } assert(!elems.empty()); - Type elemTy = elems[0].getType(); - MLIRContext *ctx = elemTy.getContext(); + MLIRContext *ctx = eltTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems.size(), elemTy)); + ctx, SmallVector(elems.size(), eltTy)); auto result = packLLElements(loc, typeConverter, elems, rewriter, structTy); return result; } @@ -544,18 +600,20 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, const int maxPhase = sharedLayout.getMaxPhase(); const int vecPhase = sharedLayout.getVec(); const int elemBytes = descTy.getElementTypeBitWidth() / 8; + const int mmaElemBytes = 4 / kWidth; + const bool isHopper = mmaLayout.getVersionMajor() == 3; auto order = sharedLayout.getOrder(); int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); - // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int batch, int a, int b) { - MMA16816SmemLoader loader( - nPerWarp, warpsPerTile, sharedLayout.getOrder(), - mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, - shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId, - perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); + MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(), + mmaLayout.getWarpsPerCTA(), kOrder, kWidth, + smemObj.strides, shapePerCTA /*tileShape*/, + instrShape, matShape, multiDimWarpId, perPhase, + maxPhase, elemBytes, mmaElemBytes, isHopper, + rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(lane, cSwizzleOffset); @@ -573,6 +631,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, auto [ha0, ha1, ha2, ha3] = loader.loadX4( batch, (kOrder == 2) ? a : b /*mat0*/, (kOrder == 2) ? b : a /*mat1*/, ptrs, matTy, getSharedMemTy(eltTy)); + if (!isA) std::swap(ha1, ha2); // the following is incorrect @@ -595,28 +654,32 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto mmaLayout = mlir::cast(encoding.getParent()); + bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - auto mmaLayout = mlir::cast(encoding.getParent()); + // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should + // add up to 32. We use kWidth to compute the element bitwidth of the input to + // WGMMA, which could be different from `bitwidth` due to later casting. + int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth; ValueTable vals; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; + int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; + int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; int kWidth = encoding.getKWidth(); - auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth, - encoding.getOpIdx()); + auto numRep = mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, kWidth, + encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto order = triton::gpu::getOrder(mmaLayout); + auto warpOrder = mmaLayout.getWarpOrder(); Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32)); SmallVector multiDimWarpId = - delinearize(rewriter, loc, warp, warpsPerCTA, order); + delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; - auto rank = shapePerCTA.size(); Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); if (isA) @@ -651,8 +714,10 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, loadFn(b, 2 * m, 2 * k); // Format the values to LLVM::Struct to passing to mma codegen. + Type eltTy = typeConverter->convertType(descTy.getElementType()); return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter); + vals, numRepBatch, isA ? numRep[1] : numRep[2], numRepK, typeConverter, + loc, rewriter, eltTy, kWidth, isHopper, isA); } template @@ -764,7 +829,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -785,4 +850,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv2OrV3 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index cf0ddc248dd1..e8ddf871045e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -50,7 +50,7 @@ class DecomposeLocalLoadToDotOperand blockEncoding); Value load = rewriter.create(op.getLoc(), tmpType, op.getSrc()); - auto newSharedDescTy = triton::MemDescType::get( + auto newSharedDescTy = MemDescType::get( type.getShape(), type.getElementType(), triton::gpu::SharedEncodingAttr::get( op.getContext(), dstDotOp, type.getShape(), @@ -70,10 +70,15 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { + // FIXME [Dot LL] + // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after + // we have enabled the new layout conversion for all the cases. + auto nvidiaShortCutFn = [&](RankedTensorType srcTy, + RankedTensorType dstTy) { return true; }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMmaToDotShortcut); + nvidiaShortCutFn); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 3e915a577c54..76b2984126c8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -11,10 +11,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter); - LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); @@ -48,8 +44,6 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { NvidiaMmaEncodingAttr mmaLayout = dyn_cast( cast(D.getType()).getEncoding()); if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { - if (mmaLayout.isVolta()) - return convertMMA884(op, adaptor, getTypeConverter(), rewriter); if (mmaLayout.isTuring()) return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); if (mmaLayout.isAmpere()) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv1.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv1.cpp deleted file mode 100644 index 9d40f10729a6..000000000000 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv1.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" - -#include "Utility.h" - -using namespace mlir; -using namespace mlir::triton; - -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; - -using ValueTable = std::map, std::pair>; - -static Type getMmaRetType(TensorType operand) { - auto *ctx = operand.getContext(); - Type fp32Ty = type::f32Ty(ctx); - // f16*f16+f32->f32 - return struct_ty(SmallVector{8, fp32Ty}); -} - -static ValueTable extractLoadedOperand(Value llStruct, int NK, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter *typeConverter, - Type type) { - ValueTable rcds; - SmallVector elems = - unpackLLElements(llStruct.getLoc(), llStruct, rewriter); - - int offset = 0; - for (int i = 0; offset < elems.size(); ++i) { - for (int k = 0; k < NK; k += 4) { - rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]); - offset += 2; - } - } - - return rcds; -} - -LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto *ctx = op.getContext(); - auto loc = op.getLoc(); - - Value A = op.getA(); - Value B = op.getB(); - Value D = op.getResult(); - auto mmaLayout = cast( - cast(D.getType()).getEncoding()); - auto ALayout = cast( - cast(A.getType()).getEncoding()); - auto BLayout = cast( - cast(B.getType()).getEncoding()); - - auto ATensorTy = cast(A.getType()); - auto BTensorTy = cast(B.getType()); - auto DTensorTy = cast(D.getType()); - auto AShape = ATensorTy.getShape(); - auto BShape = BTensorTy.getShape(); - - bool isARow = mmaLayout.getMMAv1IsRow(ALayout.getOpIdx()); - bool isBRow = mmaLayout.getMMAv1IsRow(BLayout.getOpIdx()); - auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] = - mmaLayout.decodeVoltaLayoutStates(); - assert(isARow == isARow_); - assert(isBRow == isBRow_); - - unsigned numM = mmaLayout.getMMAv1NumOuter(AShape, ALayout.getOpIdx()); - unsigned numN = mmaLayout.getMMAv1NumOuter(BShape, BLayout.getOpIdx()); - unsigned NK = AShape[1]; - - auto has = extractLoadedOperand(adaptor.getA(), NK, rewriter, typeConverter, - ATensorTy); - auto hbs = extractLoadedOperand(adaptor.getB(), NK, rewriter, typeConverter, - BTensorTy); - - // Initialize accumulators with external values, the acc holds the - // accumulator value that is shared between the MMA instructions inside a - // DotOp, we can call the order of the values the accumulator-internal - // order. - SmallVector acc = unpackLLElements(loc, adaptor.getC(), rewriter); - size_t resSize = acc.size(); - - // The resVals holds the final result of the DotOp. - // NOTE The current order of resVals is different from acc, we call it the - // accumulator-external order. and - SmallVector resVals(resSize); - - auto getIdx = [&](int m, int n) { - std::vector idx{{ - (m * 2 + 0) + (n * 4 + 0) * numM, // row0 - (m * 2 + 0) + (n * 4 + 1) * numM, - (m * 2 + 1) + (n * 4 + 0) * numM, // row1 - (m * 2 + 1) + (n * 4 + 1) * numM, - (m * 2 + 0) + (n * 4 + 2) * numM, // row2 - (m * 2 + 0) + (n * 4 + 3) * numM, - (m * 2 + 1) + (n * 4 + 2) * numM, // row3 - (m * 2 + 1) + (n * 4 + 3) * numM, - }}; - return idx; - }; - - auto callMMA = [&](unsigned m, unsigned n, unsigned k) { - auto ha = has.at({m, k}); - auto hb = hbs.at({n, k}); - - PTXBuilder builder; - auto idx = getIdx(m, n); - - // note: using "=f" for float leads to cleaner PTX - bool isIntMMA = DTensorTy.getElementType().isInteger(32); - auto *resOprs = builder.newListOperand(8, isIntMMA ? "=r" : "=f"); - auto *AOprs = builder.newListOperand({ - {ha.first, "r"}, - {ha.second, "r"}, - }); - - auto *BOprs = builder.newListOperand({ - {hb.first, "r"}, - {hb.second, "r"}, - }); - auto *COprs = builder.newListOperand(); - for (int i = 0; i < 8; ++i) - COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i))); - - auto mma = builder.create("mma.sync.aligned.m8n8k4") - ->o(isARow ? "row" : "col") - .o(isBRow ? "row" : "col") - .o("f32.f16.f16.f32"); - - mma(resOprs, AOprs, BOprs, COprs); - - Value res = builder.launch(rewriter, loc, getMmaRetType(ATensorTy)); - - for (auto i = 0; i < 8; i++) { - Value elem = extract_val(f32_ty, res, i); - acc[idx[i]] = elem; - } - }; - - for (unsigned k = 0; k < NK; k += 4) - for (unsigned m = 0; m < numM / 2; ++m) - for (unsigned n = 0; n < numN / 2; ++n) { - callMMA(m, n, k); - } - - // res holds the same layout of acc - for (size_t i = 0; i < acc.size(); ++i) { - resVals[i] = acc[i]; - } - - Value res = packLLElements(loc, typeConverter, resVals, rewriter, DTensorTy); - rewriter.replaceOp(op, res); - return success(); -} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 79ccb57206ae..0f7613c51f2a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -9,6 +9,7 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrderForDotOperand; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ValueTableV2 = std::map, Value>; @@ -59,26 +60,151 @@ Value loadC(Value tensor, Value llTensor, ValueTableV2 getValuesFromDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter, Value value, int batch, int n0, int n1, - RankedTensorType type) { + ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter, + int repK, RankedTensorType type) { auto elems = unpackLLElements(loc, value, rewriter); + auto eltTy = typeConverter->convertType(type.getElementType()); int offset{}; ValueTableV2 vals; + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + + auto packVec = [&](std::array dstIdx) { + Value vec = undef(vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i)); + } + vals[dstIdx] = bitcast(vec, i32_ty); + offset += numElemsPerVec; + }; - // FIXME [Dot LL] - // [ez] Generalize the logic below for kWidth * elemBitWidth > 32 auto dot = cast(type.getEncoding()); - auto largeK = dot.getKWidth() == 8 && - cast(dot.getParent()).isAmpere(); + auto kWidth = dot.getKWidth(); + auto largeK = bitwidth * kWidth > 32; if (largeK) { + // For layouts with a large K dimension, the original register layout needs + // to be divided into multiple MMAs, where each MMA has contiguous 32 bits + // along the K dimension per thread. + // Using kWidth = 8 and bitwidth = 2 as an example, + // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the + // K dimension. llvm::SmallVector si; + auto kIters = kWidth / (32 / bitwidth); - // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K if (dot.getOpIdx() == 0) { - si = llvm::SmallVector{0, 8, 4, 12, 1, 9, 5, 13, - 2, 10, 6, 14, 3, 11, 7, 15}; + // Original register layout: + // + // [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23, 23] + // [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31] + // + // Each element in the layout is a single bf16. + // + // To derive four independent MMA operations, a stride of 4 is applied to + // the original register layout: + // + // 1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]] + // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]] + // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]] + // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]] + if (kIters <= repK) { + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 4; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } + } else { + // Suppose kWidth=4 and type=fp32, so numElemsPerVec=1. + // Each tile of the dot operand layout has a size of 16x32. + // However, if the triton tensor size is 16x16, elements along the k + // dimension are duplicated. Within each tile, each register + // contains 2x8 elements arranged as follows: + // + // tile0/0 tile0/1 + // |<--kWidth=4-->| |<--kWidth-->| + // |<-mmaWidth=2->| + // [0, 1, 2, 3] [0, 1, 2, 3] + // [4, 5, 6, 7] [4, 5, 6, 7] + // + // tile0/1 replicates the elements in tile0/0 along the k dimension. + // For a tensor size of 32x32, the next tile on the m dimension is as + // follows: + // + // tile1/0 tile1/1 + // |<--kWidth-->| |<--kWidth-->| + // [8, 9, 10, 11], [8, 9, 10, 11] + // [12, 13, 14, 15], [12, 13, 14, 15] + // + // Within a single tile, we can perform two MMAs, and the + // resulting register layout for each MMA is as follows: + // + // 1st MMA: [0, 4, 1, 5] + // 2nd MMA: [2, 6, 3, 7] + // 3rd MMA: [8, 12, 9, 13] + // 4th MMA: [10, 14, 11, 15] + // + // Additionally, we should reorder the elements by moving the duplicated + // elements to the end. In the example above, we convert the order from + // tile0/0, tile0/1, tile1/0, tile1/1 to tile0/0, tile1/0, tile0/1, + // tile1/1, so that only the first two tiles will be used in the + // computation. + size_t elemsPerTile = 2 * 2 * kWidth; + size_t elemsPerMma = 2 * 2 * numElemsPerVec; + size_t mmaWidth = kWidth / numElemsPerVec / 2; + size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma); + for (size_t rep = 0; rep < repMma; ++rep) + for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile) + for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth) + for (size_t kTile = 0; kTile < 2; ++kTile) + for (size_t mTile = 0; mTile < 2; ++mTile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(rep * mmaWidth * elemsPerMma + + mmaKWidth * 2 * numElemsPerVec + + tile * elemsPerTile + mTile * kWidth + + kTile * numElemsPerVec + e); + } + } } else { - si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; + // Original register layout: + // + // [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T + // + // A stride of 4 is applied to derive four independent MMA operations: + // + // 1st MMA: [[0, 1], [8, 9]] + // 2nd MMA: [[2, 3], [10, 11]] + // 3rd MMA: [[4, 5], [12, 13]] + // 4th MMA: [[6, 7], [14, 15]] + if (kIters <= repK) { + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 2; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } + } else { + // Suppose kWidth=4 and type=fp32. + // Original register layout: + // + // tile0/0 tile0/1 + // [0, 1, 2, 3]^T, [0, 1, 2, 3]^T + // + // Similar to the opIdx=0 situation, we should reorder the elements by + // moving the duplicated elements to the end. + size_t elemsPerTile = 2 * kWidth; + size_t elemsPerMma = 2 * numElemsPerVec; + size_t mmaWidth = kWidth / numElemsPerVec / 2; + size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma); + for (size_t rep = 0; rep < repMma; ++rep) + for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile) + for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth) + for (size_t kTile = 0; kTile < 2; ++kTile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(rep * mmaWidth * elemsPerMma + + mmaKWidth * 2 * numElemsPerVec + + tile * elemsPerTile + kTile * numElemsPerVec + + e); + } + } } auto step = si.size(); @@ -89,34 +215,25 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( } std::copy(perm.begin(), perm.end(), elems.begin() + i * step); } - - if (dot.getOpIdx() == 1) { - // there are kWidth * 2 elems packed as bf16x2 - int elemsInTile = dot.getKWidth(); - // n0 and n1 are unrolled in the legacy path - // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no - // sense IMO - n0 *= 2; - n1 *= 2; - for (auto b = 0; b < batch; ++b) - for (auto j = 0; j < n1 / elemsInTile; ++j) - for (auto i = 0; i < n0; ++i) - for (auto k = 0; k < elemsInTile; ++k) { - vals[{b, i, elemsInTile * j + k}] = elems[offset++]; - } - return vals; - } } - for (auto b = 0; b < batch; ++b) - for (auto i = 0; i < n0; ++i) { - for (auto j = 0; j < n1; j++) { - vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; - } - } + if (dot.getOpIdx() == 0) { + for (auto b = 0; b < batch; ++b) + for (auto m = 0; m < repOuter; ++m) + for (auto k = 0; k < repK; ++k) { + packVec({b, 2 * m, 2 * k}); + packVec({b, 2 * m + 1, 2 * k}); + packVec({b, 2 * m, 2 * k + 1}); + packVec({b, 2 * m + 1, 2 * k + 1}); + } + } else { + for (auto b = 0; b < batch; ++b) + for (auto n = 0; n < repOuter; ++n) + for (auto k = 0; k < repK; ++k) { + packVec({b, n, 2 * k}); + packVec({b, n, 2 * k + 1}); + } + } return vals; } @@ -363,29 +480,34 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); - auto repA = cast(dotOpA.getParent()) - .getMMAv2RepForOperand(aShapePerCTA, bitwidth, - dotOpA.getKWidth(), dotOpA.getOpIdx()); + int kWidth = dotOpA.getKWidth(); + auto repA = + cast(dotOpA.getParent()) + .getRepForOperand(aShapePerCTA, bitwidth, kWidth, dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); - auto repB = cast(dotOpB.getParent()) - .getMMAv2RepForOperand(bShapePerCTA, bitwidth, - dotOpB.getKWidth(), dotOpB.getOpIdx()); + auto repB = + cast(dotOpB.getParent()) + .getRepForOperand(bShapePerCTA, bitwidth, kWidth, dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); int repM = repA[1], repN = repB[2], repK = repA[2]; int repBatch = repA[0]; + // We can reuse the same iteration order in + // getValuesFromDotOperandLayoutStruct as both a and b are K-major + assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), + aShapePerCTA.size(), + /*kMajor=*/true)); auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); - // FIXME [Dot LL] - // max(repN / 2, 1) is wrong for repN = 1! - // This is also wrong in - // NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand + assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), + bShapePerCTA.size(), + /*kMajor=*/true)); auto hb = getValuesFromDotOperandLayoutStruct( - typeConverter, loc, rewriter, loadedB, repBatch, std::max(repN / 2, 1), - repK, bTensorTy); + typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy); + auto fc = unpackLLElements(loc, loadedC, rewriter); auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8; int numCPackedElem = 4 / numMmaRets; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e046..85f7da2cb5b3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -30,6 +30,7 @@ using namespace mlir::triton; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; @@ -47,7 +48,7 @@ triton::nvgpu::WGMMAEltType getMmaRetType(Value d) { } triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { - auto aTy = cast(a.getType()).getElementType(); + auto aTy = cast(a.getType()).getElementType(); if (aTy.isF16()) { return triton::nvgpu::WGMMAEltType::f16; } else if (aTy.isBF16()) { @@ -197,7 +198,7 @@ DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, const NvidiaMmaEncodingAttr &mmaEncoding, Value tensor, Value smemObjBase, Value thread) { - auto aTy = cast(tensor.getType()); + auto aTy = cast(tensor.getType()); auto aSharedLayout = dyn_cast(aTy.getEncoding()); assert(aSharedLayout && "only support load dot operand from shared."); auto instrShape = mmaEncoding.getInstrShape(); @@ -264,6 +265,28 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. +// +// `elements` contains all loaded register values for operand A. +// This consists of operand A for possibly multiple wgmma instructions. +// For each wgmma, each warp in a warp group feeds a single "warp matrix" +// Each warp matrix consists of 2x2 "quads". +// Each thread holds several elements in each quad. Right before a wgmma, +// the sum of bitwidth of +// the elements in each quad should add up to 32. +// +// These values are stored unrolled in `elements`. +// The ordering of dimensions is as follows: +// batch (only 1 batch for Hopper currently) +// matM (m-index of the "warp matrix") +// matK (k-index of the "warp matrix") +// quadK (k-index of the "quad" in the core matrix) +// quadM (m-index of the "quad" in the core matrix) +// vecIdx (index of the element in the quad; this is always along the k-dim) +// +// This ordering is decided when a tensor in DotOpEnc is lowered into llvm. +// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. +// Thus, both lowerings must obey this above ordering for the below code to be +// correct. llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, @@ -356,8 +379,8 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, Value loadedC, bool allowTF32, bool needsPartialAccumulator, uint32_t maxNumImpreciseAcc, bool sync, Value thread) { - auto aTensorTy = cast(a.getType()); - auto bTensorTy = cast(b.getType()); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); auto aSharedLayout = dyn_cast(aTensorTy.getEncoding()); auto bSharedLayout = cast(bTensorTy.getEncoding()); @@ -442,6 +465,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, if (aSharedLayout) { a = aLoader.smemLoad(m, k, rewriter, loc); } else { + auto aDotOpEnc = + cast(aTensorTy.getEncoding()); + assert(aDotOpEnc.getKWidth() == + 32 / aTensorTy.getElementTypeBitWidth()); + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; llvm::SmallVector regA = loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 760ba75d9816..d2cef405ebdf 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -8,6 +8,7 @@ #include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; @@ -24,76 +25,57 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace { // Return the mask for the unique data accessed by given tensor type. -// Used to mask out the redundant data accessed by threads. -Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, - Location loc, const NVIDIA::TargetInfo &targetInfo) { +// NOTE: Redundant memory load is allowed in triton, but redundant memory store +// is not allowed. +// mask = true: thread can write +// mask = false: thread should not write +Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy, + ConversionPatternRewriter &rewriter, Location loc, + int regIdx, const NVIDIA::TargetInfo &targetInfo) { + auto ctx = moduleOp.getContext(); auto tensorTy = dyn_cast(valueTy); - Value mask = int_val(1, 1); + auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); auto tid = tid_val(); - auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); + auto mask = true_val(); + auto kReg = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); if (tensorTy) { - auto layout = tensorTy.getEncoding(); auto shape = tensorTy.getShape(); - unsigned rank = shape.size(); - auto sizePerThread = triton::gpu::getSizePerThread(layout); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); - auto warpOrder = triton::gpu::getWarpOrder(layout); - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); - Value warpSize = i32_val(32); - Value laneId = urem(tid, warpSize); - Value warpId = udiv(tid, warpSize); - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); - SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); - for (unsigned dim = 0; dim < rank; ++dim) { - // if there is no data replication across threads on this dimension - if (shape[dim] >= shapePerCTATile[dim]) - continue; - // Otherwise, we need to mask threads that will replicate data on this - // dimension. Calculate the thread index on this dimension for the CTA - Value threadDim = - add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])), - multiDimThreadId[dim]); - mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])), - i32_val(shape[dim]))); - } - // Do not write duplicated data when multicast is enabled - if (triton::gpu::getNumCTAs(layout) > 1) { - auto _0 = i32_val(0); - auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); - auto CTASplitNum = triton::gpu::getCTASplitNum(layout); - auto CTAOrder = triton::gpu::getCTAOrder(layout); - - auto multiDimClusterCTAId = - delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); - - for (unsigned dim = 0; dim < rank; ++dim) { - // Skip when multicast is not enabled in this dimension - if (CTAsPerCGA[dim] == CTASplitNum[dim]) - continue; - // This wrapping rule must be consistent with emitCTAOffsetForLayout - unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); - Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum)); - // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: - // CTA0 and CTA2 holds data of block0, - // CTA1 and CTA3 holds data of block1. - // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should - // be masked. We add the following mask: - // multiDimClusterCTAId[dim] / splitNum == 0 - // Actually in all existing cases of multicast, splitNum is always 1. - // The mask is equivalent to: - // multiDimClusterCTAId[dim] == 0 - mask = and_(mask, icmp_eq(repId, _0)); + auto layout = tensorTy.getEncoding(); + auto ll = triton::gpu::toLinearLayout(shape, layout); + assert(ll.has_value() && "Failed to convert layout to linear layout"); + auto freeVariableMasks = ll->getFreeVariableMasks(); + auto regMasks = freeVariableMasks[kReg]; + if (regMasks & regIdx) { + // Step 1: check register redundancy + mask = false_val(); + } else { + Value warpSize = + i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp)); + Value laneId = urem(tid, warpSize); + Value warpId = udiv(tid, warpSize); + // Step 2: check lane and warp redundancy + auto laneMasks = freeVariableMasks[kLane]; + auto warpMasks = freeVariableMasks[kWarp]; + mask = and_(mask, icmp_eq(and_(i32_val(laneMasks), laneId), i32_val(0))); + mask = and_(mask, icmp_eq(and_(i32_val(warpMasks), warpId), i32_val(0))); + if (numCTAs > 1) { + // Step 3: check block redundancy + auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); + auto ctaMasks = freeVariableMasks[kBlock]; + mask = and_(mask, icmp_eq(and_(i32_val(ctaMasks), ctaId), i32_val(0))); } } } else { - // If the tensor is not ranked, then it is a scalar and only thread 0 of - // CTA0 can write - mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0))); mask = and_(mask, icmp_eq(tid, i32_val(0))); + if (numCTAs > 1) { + auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); + // If the tensor is not ranked, then it is a scalar and only thread 0 of + // CTA0 within the cluster can write + mask = and_(mask, icmp_eq(ctaId, i32_val(0))); + } } return mask; } @@ -253,7 +235,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, PTXBuilder ptxBuilder; - Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + Value pred = mask ? maskElems[vecStart] : true_val(); const std::string readConstraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); @@ -426,7 +408,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, << mask << "\n"; } - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + auto moduleOp = op->getParentOfType(); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; @@ -474,6 +456,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, PTXBuilder ptxBuilder; auto *asmArgList = ptxBuilder.newListOperand(asmArgs); + Value mask = getRedundantDataMask(moduleOp, valueTy, rewriter, loc, + vecStart, targetInfo); Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask; auto *asmAddr = @@ -566,7 +550,6 @@ struct AtomicCASOpConversion << " origin vec = " << vecOrig << " elemsPerThread = " << elemsPerThread << "\n"; - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -596,6 +579,8 @@ struct AtomicCASOpConversion os << op.getSem(); auto scope = stringifyMemSyncScope(op.getScope()).str(); atom.global().o(semStr).o(scope).o("cas").o(sTy); + Value mask = + getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo); atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); if (tensorTy) { @@ -649,13 +634,11 @@ struct AtomicRMWOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} - bool supportsVectorized(Operation *moduleOp, RMWOp opType, - Type elementType) const { + bool supportsVectorized(RMWOp opType, Type elementType) const { // vectorized atomics are only supported on hopper, // and only for specific atomic ops (add, min, max). // Note that "packed types" like f16x2 are supported sm60+. - auto computeCapability = getNVIDIAComputeCapability(moduleOp); - if (computeCapability < 90) { + if (!targetInfo.supportVectorizedAtomics()) { return false; } @@ -707,8 +690,7 @@ struct AtomicRMWOpConversion vecOrig = vec; packed = 1; auto valTy = cast(val.getType()); - if (!supportsVectorized(moduleOp, atomicRmwAttr, - valTy.getElementType())) { + if (!supportsVectorized(atomicRmwAttr, valTy.getElementType())) { packed = std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); vec = 1; @@ -728,12 +710,12 @@ struct AtomicRMWOpConversion << " packed = " << packed << " origin vec = " << vecOrig << " numElems = " << numElems; - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - auto packedTy = vec_ty(valueElemTy, packed); SmallVector resultVals(elemsPerThread); for (size_t i = 0; i < elemsPerThread; i += vec * packed) { Value rmwPtr = ptrElements[i]; + Value mask = + getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo); Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; @@ -968,6 +950,7 @@ struct AsyncCopyGlobalToLocalOpConversion << vecBytes << " bytes"; } + auto moduleOp = op->getParentOfType(); for (int i = 0; i < shmemAddrs.size(); i++) { // It's possible that vecTy is larger than 128 bits, in which case we have // to use multiple cp.async instructions. @@ -995,24 +978,26 @@ struct AsyncCopyGlobalToLocalOpConversion // if there's any mask. cp.async will automatically fill the // remaining slots with 0 if cp-size > src-size. // XXX(Keren): Always assume other = 0 for now. + // When 'other != 0' is supported, we will need to fold the + // op.getMask() and redundantDataMask() into the same predicate, the + // way it is done for LoadOp. auto selectOp = select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); } - // When 'other != 0' is supported, we will need to fold the op.getMask() - // and redundantDataMask() into the same predicate, the way it is done - // for LoadOp. - Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo); - - // TODO: Masking does not work for CTA multicast with cp.async. This is - // a quick and dirty workaround to avoid the issue. bool skipMaskForMultiCTA = triton::gpu::getNumCTAs(srcLayout) > 1; - if (!skipMaskForMultiCTA) { - copyAsyncOp(dstOperand, srcOperand, copySize, srcSize) - .predicate(maskVal); - } else { + if (skipMaskForMultiCTA) { + // TODO: Masking does not work for CTA multicast with cp.async. + // XXX(@peterbell10): In the multi-CTA mode, the redundant data might + // be on different CTAs which don't share the same smem address space, + // so we might need to load the same data multiple times. copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); + } else { + Value mask = getRedundantDataMask(moduleOp, srcTy, rewriter, loc, + elemIdx, targetInfo); + copyAsyncOp(dstOperand, srcOperand, copySize, srcSize) + .predicate(mask); } ptxBuilder.launch(rewriter, loc, void_ty(getContext())); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp index c64ba1915ded..459a00c1a142 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -282,6 +282,37 @@ struct ExperimentalTensormapCreateOpConversion } }; +struct ReinterpretTensorDescOpConversion + : public ConvertOpToLLVMPattern { + + ReinterpretTensorDescOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ReinterpretTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getRawDesc()); + return success(); + } +}; + +struct TensorDescToTMAPtrOpConversion + : public ConvertOpToLLVMPattern { + + TensorDescToTMAPtrOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TensorDescToTMAPtrOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getDesc()); + return success(); + } +}; + } // namespace void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( @@ -289,6 +320,8 @@ void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add( - typeConverter, benefit); + patterns + .add( + typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index d6537ecb1117..7c4a9e5b92df 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -5,15 +5,12 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "llvm/Support/MathExtras.h" using namespace mlir; using mlir::LLVM::getWrappedMultiDimOffset; using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; namespace { // declare vprintf(i8*, i8*) as external function LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { @@ -93,20 +90,12 @@ static std::optional matchReduxKind(triton::ReduceOp op, int computeCapability) { if (computeCapability < 80) return std::nullopt; - if (op.getNumOperands() != 1 || op.getNumResults() != 1) - return std::nullopt; - Block *block = &(*op.getCombineOp().begin()); - Operation *yield = block->getTerminator(); - Operation *reduceOp = yield->getOperand(0).getDefiningOp(); - if (!reduceOp || reduceOp->getNumOperands() != 2 || - reduceOp->getNumResults() != 1) + Operation *reduceOp = op.getSingleCombiner(); + if (!reduceOp) return std::nullopt; auto intType = dyn_cast(reduceOp->getResultTypes()[0]); if (!intType || intType.getWidth() > 32) return std::nullopt; - if (reduceOp->getOperand(0) != block->getArgument(0) || - reduceOp->getOperand(1) != block->getArgument(1)) - return std::nullopt; if (isa(reduceOp)) return NVVM::ReduxKind::ADD; if (isa(reduceOp)) @@ -478,6 +467,46 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, return false; } +// TODO (Keren): Currently, we have more restrictions than necessary when using +// stmatrix. These restrictions are retained from legacy code, and we could +// relax some of them in the future. +// TODO (Lezcano): The proper way of doing this is to directly try to fit the +// relevant layout and return an std::optional. I'm keeping this +// split to keep the current PR smaller +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + if (computeCapability < 90) { + return false; + } + auto mmaLayout = + mlir::dyn_cast(tensorTy.getEncoding()); + if (!mmaLayout || !mmaLayout.isHopper()) + return false; + if (isa(tensorTy.getElementType())) + return false; + if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) + return false; + if (order[0] != 1) + return false; + + auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); + if (tensorShapePerCTA.size() != 2) + return false; + auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * + ceil(tensorShapePerCTA[0], repShape[0]); + if (numIterations > 1) + return false; + if (paddedRepShape[1] % 8 != 0) + return false; + if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && + swizzleByteSize != 128) + return false; + return true; +} + void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { auto vals = unpackLLVector(loc, val, rewriter); @@ -583,4 +612,8 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, int TargetInfo::getSharedAddressSpace() const { return 3; } +bool TargetInfo::supportVectorizedAtomics() const { + return computeCapability >= 90 && ptxVersion >= 81; +} + } // namespace mlir::triton::NVIDIA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 7a1b909cc49e..eedab90c98e3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -7,7 +7,8 @@ namespace mlir::triton::NVIDIA { class TargetInfo : public mlir::triton::TargetInfoBase { public: - TargetInfo(int computeCapability) : computeCapability(computeCapability) {} + TargetInfo(int computeCapability, int ptxVersion) + : computeCapability(computeCapability), ptxVersion(ptxVersion) {} bool supportMaximumMinimum() const override; @@ -22,6 +23,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Type elemTy, Value pred) const override; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const override; @@ -53,8 +60,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef file, StringRef func, int line) const override; int getSharedAddressSpace() const override; + bool supportVectorizedAtomics() const override; + private: int computeCapability; + int ptxVersion; }; } // namespace mlir::triton::NVIDIA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 21f5b706320d..d749d44bc498 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -7,7 +7,6 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Pass/Pass.h" @@ -45,7 +44,6 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { public: explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalOp(); @@ -79,13 +77,16 @@ struct ConvertTritonGPUToLLVM ConvertTritonGPUToLLVM(int32_t computeCapability) : ConvertTritonGPUToLLVMBase({computeCapability}) {} + ConvertTritonGPUToLLVM(int32_t computeCapability, int32_t ptxVersion) + : ConvertTritonGPUToLLVMBase({computeCapability, ptxVersion}) {} + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TargetInfo targetInfo(computeCapability); + TargetInfo targetInfo(computeCapability, ptxVersion); TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMConversionTarget convTarget(*context); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); @@ -142,6 +143,8 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateGatherOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); populateBarrierOpToLLVMPatterns(typeConverter, patterns, benefit); populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, benefit); populateClusterOpsToLLVMPatterns(typeConverter, patterns, benefit); @@ -227,6 +230,12 @@ std::unique_ptr> createConvertTritonGPUToLLVMPass(int32_t computeCapability) { return std::make_unique(computeCapability); } +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, + int32_t ptxVersion) { + return std::make_unique(computeCapability, + ptxVersion); +} bool NVIDIA::canSkipBarSync(Operation *before, Operation *after) { // Multiple init barriers on the same allocation would usually not happen but diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 9404bb4474d0..47c7fcc063f9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -1,10 +1,12 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "PatternTritonGPUOpToLLVM.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" @@ -12,13 +14,79 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include using namespace mlir; using namespace mlir::triton; using namespace mlir::triton::gpu; +// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed +// into 4 32bits regs. +static constexpr const char *ptxAsm = + "{\n" + ".reg .b32 a<14>;\n" + "and.b32 a0, $4, -2004318072;\n\t" + "shr.u32 a1, a0, 3;\n\t" + "and.b32 a2, $4, 2004318071;\n\t" + "shr.u32 a3, a2, 16;\n\t" + "shr.u32 a4, a0, 19;\n\t" + "prmt.b32 a5, -1065353216, -1065336832, a2;\n\t" + "prmt.b32 a6, -1065353216, -1065336832, a3;\n\t" + "prmt.b32 a7, 1061109504, 1077952576, a2;\n\t" + "prmt.b32 a8, 1061109504, 1077952576, a3;\n\t" + "prmt.b32 a9, 32768, 0, a1;\n\t" + "prmt.b32 a10, 32768, 0, a4;\n\t" + "or.b32 a11, a7, a9;\n\t" + "or.b32 a12, a8, a10;\n\t" + "prmt.b32 $0, a5, a11, 20800;\n\t" + "prmt.b32 $1, a5, a11, 29538;\n\t" + "prmt.b32 $2, a6, a12, 20800;\n\t" + "prmt.b32 $3, a6, a12, 29538;\n\t" + "}"; + +static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter, + Type retType, Value packedVec) { + PTXBuilder builder; + SmallVector operands; + for (int i = 0; i < 4; i++) { + operands.push_back(builder.newOperand("=r")); + } + operands.push_back(builder.newOperand(packedVec, "r")); + auto &ptxOp = *builder.create(ptxAsm); + ptxOp(operands, /*onlyAttachMLIRArgs=*/true); + Value result = builder.launch(rewriter, loc, retType, false); + return result; +} + +static SmallVector convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter, + Location loc, + ArrayRef values) { + SmallVector results; + MLIRContext *ctx = rewriter.getContext(); + assert(values.size() % 4 == 0); + for (int i = 0; i < values.size(); i += 4) { + Value v0 = values[i]; + Value v1 = values[i + 1]; + Value v2 = values[i + 2]; + Value v3 = values[i + 3]; + Value packedVec = undef(vec_ty(i8_ty, 4)); + packedVec = insert_element(packedVec, v0, i32_val(0)); + packedVec = insert_element(packedVec, v1, i32_val(1)); + packedVec = insert_element(packedVec, v2, i32_val(2)); + packedVec = insert_element(packedVec, v3, i32_val(3)); + SmallVector rets(4, i32_ty); + Type retType = struct_ty(rets); + Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec); + for (int i = 0; i < 4; i++) { + Value extractI32 = extract_val(ret, i); + Value vecbf16 = bitcast(extractI32, vec_ty(bf16_ty, 2)); + results.push_back(extract_element(vecbf16, i32_val(0))); + results.push_back(extract_element(vecbf16, i32_val(1))); + } + } + return results; +} + namespace { class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { private: @@ -30,73 +98,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {} - llvm::SmallVector - unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter, - const llvm::SmallVector &vals, Value laneId) const { - auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value { - auto em0 = and_(v, i8_val(0x70)); - auto em1 = and_(v, i8_val(0x7)); - Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), - shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); - Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), - shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); - - // Three cases: - // 1) x is normal and non-zero: Correct bias - v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), - add(v0, i16_val((127 - 1) << 7)), v0); - v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), - add(v1, i16_val((127 - 1) << 7)), v1); - - // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in - // bf16 - v0 = select(icmp_eq(em0, i8_val(0x10)), - or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); - v1 = select(icmp_eq(em1, i8_val(0x1)), - or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); - // 3) x is zero, nothing to do - - // Swap as they come packed in big endian - return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16))); - }; - - auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2]( - Value v) -> llvm::SmallVector { - llvm::SmallVector results(4); - for (int i = 0; i < 4; ++i) { - auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); - results[i] = fp4x2ToBf16x2(v_i); - } - return results; - }; - - // Split fp4x8 into 4 bf16x2 - llvm::SmallVector ret; - ret.reserve(vals.size() * 4); - for (int i = 0; i < vals.size(); ++i) { - auto vs = fp4x8ToBf16x2(vals[i]); - assert(vs.size() == 4); - for (auto v : vs) { - ret.push_back(v); - } - } - // FIXME [Dot LL] - // The DotOperandEncodingAttr without LLs encodes the - // layout as - // e0 e1 - // e2 e3 - // rather than transposed that, as the PTX docs say - // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2) - assert(ret.size() % 16 == 0); - for (int i = 0; i < ret.size() / 16; ++i) { - for (int j = 0; j < 4; ++j) { - std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]); - } - } - - return ret; - } - LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -116,46 +117,43 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == F8F6F4Type::E2M1) { - xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); - } + auto kWidth = + cast(op.getType().getEncoding()).getKWidth(); - auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { - // Split bf16x2 into 2 bf16, scale each of them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? - auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(s, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = - or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; - }; + if (fpType == ScaleDotElemType::E2M1) + xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals); // Each thread owns elements of 4 mxfp vectors so we need 4 scales - // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + - // 16, c + 17 + // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2 + // Then, we need elements c and c + 16 for the first two mxfp vectors + // and elements c + 1 and c + 17 for the last two mxfp vectors auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); - std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), + std::array ci = {c, add(c, i32_val(16)), add(c, i32_val(1)), add(c, i32_val(17))}; + // TODO Move this logic to using LinearLayouts + // Each scale in a warp has to be replicated to cover a tile of shape mxk = + // 16x64 This 16x64 tile is split into 4 subtiles of shape 8x32, each of + // which will have to gather a scale and multiply its relevant part of the + // mxfp vector This tile of 8x32 is split in to 8x4 vectors, leaving each + // vector with 1x8 mxfp elements as long as kWidth * 4 <= 32 + assert(kWidth <= 8 && + "NYI for larger kWidth (but we could do it with less shuffles!)"); for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { - // column major as per the DotOperandEncoding(opidx=0) layout - auto si = std::array{ - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]), - targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), - }; - - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + for (int mxfp = 0; mxfp < 2; ++mxfp) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 1])}; + for (int rep = 0; rep < 8 / kWidth; ++rep) { + for (int subTile = 0; subTile < 2; ++subTile) { + for (int k = 0; k < kWidth; ++k) { + auto idx = + 32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k; + xVals[idx] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile]); + } + } + } } } diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 1269dcda00aa..a4b84877be61 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "cublas_instance.h" @@ -18,9 +18,11 @@ void init_triton_nvidia_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; // TODO: it is weird to pass mlir::triton::NVVM here since the conversion is // nvidia-specificontext - m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability) { - pm.addPass(mlir::triton::createConvertTritonGPUToLLVMPass(capability)); - }); + m.def("add_to_llvmir", + [](mlir::PassManager &pm, int32_t capability, int32_t ptxVersion) { + pm.addPass(mlir::triton::createConvertTritonGPUToLLVMPass( + capability, ptxVersion)); + }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm) { pm.addPass(NVIDIA::createDecomposeUnsupportedConversionsPass()); }); diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index e2d9152c9626..e0fafb43a929 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -19,19 +19,8 @@ include_directories(${JSON_INCLUDE_DIR}) include_directories(${PROTON_SRC_DIR}/include) include_directories(${PROTON_EXTERN_DIR}) -if(PYTHON_INCLUDE_DIRS) - # We have PYTHON_INCLUDE_DIRS set--this is what we expect when building - # using pip install. - include_directories(${PYTHON_INCLUDE_DIRS}) - include_directories(${PYBIND11_INCLUDE_DIR}) -else() - # Otherwise, we might be building from top CMakeLists.txt directly. - # Try to find Python and pybind11 packages. - find_package(Python3 REQUIRED Interpreter Development) - find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") - include_directories(${Python3_INCLUDE_DIRS}) - include_directories(${pybind11_INCLUDE_DIR}) -endif() +find_package(Python3 REQUIRED Interpreter Development.Module) +find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") # Check if the platform is MacOS if(APPLE) @@ -49,4 +38,5 @@ include_directories(${CUPTI_INCLUDE_DIR}) include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR}) target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__) -target_link_libraries(proton PRIVATE ${Python_LIBRARIES} ${PROTON_PYTHON_LDFLAGS}) +target_link_libraries(proton PRIVATE Python3::Module pybind11::headers) +target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) diff --git a/third_party/proton/README.md b/third_party/proton/README.md index 674540b8afd9..9d16ec5dbe8e 100644 --- a/third_party/proton/README.md +++ b/third_party/proton/README.md @@ -128,6 +128,7 @@ The following examples demonstrate how to use Proton command-line. proton [options] script.py [script_args] [script_options] proton [options] pytest [pytest_args] [script_options] python -m triton.profiler.proton [options] script.py [script_args] [script_options] +proton --instrument=[instrumentation pass] script.py ``` When profiling in the command line mode, the `proton.start` and `proton.finalize` functions are automatically called before and after the script execution. Any `proton.start` and `proton.finalize` functions in the script are ignored. Also, in the command line mode, only a single *session* is supported. Therefore, `proton.deactivate(session_id=1)` is invalid, while `proton.deactivate(session_id=0)` is valid. @@ -143,12 +144,68 @@ proton-viewer -m time/s NOTE: `pip install hatchet` does not work because the API is slightly different. +### Visualizing sorted profile data + +In addition visualizing the profile data on terminal through Hatchet. A sorted list of the kernels by the first metric can be done using the --print-sorted flag with proton-viewer + +```bash +proton-viewer -m time/ns,time/% --print-sorted +``` + +prints the sorted kernels by the time/ns since it is the first listed. + More options can be found by running the following command. ```bash proton-viewer -h ``` +## Advanced features + +### State annotation + +In addition to `proton.scope`, we can also customize the call path of each GPU operation using `proton.state`. + +`state` is different from `scope` in several ways: + +1. State is not recursive; each operation can have only a single state. Inner most state will overwrite the outer most state. +2. A states is a suffix, meaning that the original call path will append a state above the name of each kernel. +3. State is compatible with both Python and shadow contexts. + +The following example demonstrates a basic use of state: + +```python +with proton.scope("test"): + with proton.state("state0"): + with proton.scope("test0"): + foo0[1,](x, y) + with proton.scope("test1"): + foo1[1,](x, y) +``` + +The call path of `foo1` will be `test->test1->state0`. + +### Instrumentation (experimental) + +In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis +and optimization information. This feature is under active development and the list of available passes is expected to grow. + +#### Available passes + +print-mem-spaces: this pass prints the load and store address spaces (e.g. global, flat, shared) chosen by the compiler and attributes back to Triton source information. + +Example usage with the Proton matmul tutorial: + +```bash +$ proton --instrument=print-mem-spaces matmul.py +0 matmul_kernel matmul.py:180:20 SHARED STORE +1 matmul_kernel matmul.py:181:20 SHARED STORE +2 matmul_kernel matmul.py:180:20 SHARED LOAD +3 matmul_kernel matmul.py:181:20 SHARED LOAD +``` + +Notes: The instrument functionality is currently only available from the command line. Additionally the instrument and profile command line arguments can not be use simulantously. + ### Instruction sampling (experimental) Proton supports instruction sampling on NVIDIA GPUs. @@ -209,3 +266,7 @@ If you encounter permission related problems when using instruction sampling, yo The overhead of instruction sampling on NVIDIA GPUs is about 20x using Proton because we haven't enabled continuous sampling yet. Continuous sampling can allow for more runtime optimizations, but it makes it more challenging to attribute performance data back to the GPU kernels because: (1) it enables profiling of concurrent kernels, (2) it doesn't allow profiling of time and instruction samples simultaneously, and (3) it works best if we have a separate thread dedicated to attributing instruction samples to the GPU kernels + +- Visible devices on AMD GPUs + +Environment variables such as `HIP_VISIBLE_DEVICES`, and `CUDA_VISIBLE_DEVICES` are not supported on AMD GPUs. Once it's set, we cannot find a valid mapping between the device ID returned by RocTracer and the physical device ID. Instead, `ROCR_VISIBLE_DEVICES` is recommended to be used. diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index 1a7f762591ab..7c1e07bf3d9b 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -1,5 +1,4 @@ #include "Proton.h" -#include "Driver/GPU/CudaApi.h" #include #include @@ -27,10 +26,16 @@ void initProton(pybind11::module &&m) { SessionManager::instance().activateSession(sessionId); }); + m.def("activate_all", + []() { SessionManager::instance().activateAllSessions(); }); + m.def("deactivate", [](size_t sessionId) { SessionManager::instance().deactivateSession(sessionId); }); + m.def("deactivate_all", + []() { SessionManager::instance().deactivateAllSessions(); }); + m.def("finalize", [](size_t sessionId, const std::string &outputFormat) { auto outputFormatEnum = parseOutputFormat(outputFormat); SessionManager::instance().finalizeSession(sessionId, outputFormatEnum); @@ -59,6 +64,13 @@ void initProton(pybind11::module &&m) { SessionManager::instance().exitOp(Scope(scopeId, name)); }); + m.def("enter_state", [](const std::string &state) { + SessionManager::instance().setState(state); + }); + + m.def("exit_state", + []() { SessionManager::instance().setState(std::nullopt); }); + m.def("add_metrics", [](size_t scopeId, const std::map &metrics) { diff --git a/third_party/proton/csrc/include/Context/Context.h b/third_party/proton/csrc/include/Context/Context.h index 9b1205f81bd4..4baa357d913c 100644 --- a/third_party/proton/csrc/include/Context/Context.h +++ b/third_party/proton/csrc/include/Context/Context.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -31,7 +32,20 @@ class ContextSource { public: ContextSource() = default; virtual ~ContextSource() = default; - virtual std::vector getContexts() = 0; + + std::vector getContexts() { + auto contexts = getContextsImpl(); + if (state.has_value()) { + contexts.push_back(state.value()); + } + return contexts; + } + + void setState(std::optional state) { ContextSource::state = state; } + +protected: + virtual std::vector getContextsImpl() = 0; + static thread_local std::optional state; }; /// A scope is a context with a unique identifier. diff --git a/third_party/proton/csrc/include/Context/Python.h b/third_party/proton/csrc/include/Context/Python.h index b7878da37d3b..9c34d0f6d1b1 100644 --- a/third_party/proton/csrc/include/Context/Python.h +++ b/third_party/proton/csrc/include/Context/Python.h @@ -8,7 +8,10 @@ namespace proton { /// Unwind the Python stack and early return a list of contexts. class PythonContextSource : public ContextSource { public: - std::vector getContexts() override; + PythonContextSource() = default; + +private: + std::vector getContextsImpl() override; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Context/Shadow.h b/third_party/proton/csrc/include/Context/Shadow.h index b912335280ef..3f7e2da5af12 100644 --- a/third_party/proton/csrc/include/Context/Shadow.h +++ b/third_party/proton/csrc/include/Context/Shadow.h @@ -12,13 +12,12 @@ class ShadowContextSource : public ContextSource, public ScopeInterface { public: ShadowContextSource() = default; - std::vector getContexts() override { return contextStack; } - void enterScope(const Scope &scope) override; void exitScope(const Scope &scope) override; private: + std::vector getContextsImpl() override { return contextStack; } std::vector contextStack; }; diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index d5033b06aa63..efbcab78f71b 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -58,13 +58,15 @@ class GPUProfiler : public Profiler, ThreadState(ConcreteProfilerT &profiler) : profiler(profiler) {} - void record(size_t scopeId) { + size_t record() { + auto scopeId = Scope::getNewScopeId(); if (profiler.isOpInProgress()) - return; + return scopeId; std::set dataSet = profiler.getDataSet(); for (auto data : dataSet) data->addScope(scopeId); profiler.correlation.apiExternIds.insert(scopeId); + return scopeId; } void enterOp(size_t scopeId) { diff --git a/third_party/proton/csrc/include/Profiler/Profiler.h b/third_party/proton/csrc/include/Profiler/Profiler.h index e87e8ccef88f..ed14fc1b685e 100644 --- a/third_party/proton/csrc/include/Profiler/Profiler.h +++ b/third_party/proton/csrc/include/Profiler/Profiler.h @@ -27,10 +27,9 @@ class Profiler { /// If the profiler is already started, this function does nothing. Profiler *start() { std::unique_lock lock(mutex); - if (this->isInitialized) - return this; - this->doStart(); - this->isInitialized = true; + if (this->initializedCount == 0) + this->doStart(); + this->initializedCount++; return this; } @@ -45,10 +44,11 @@ class Profiler { /// Stop the profiler. Profiler *stop() { std::unique_lock lock(mutex); - if (!this->isInitialized) + if (this->initializedCount == 0) return this; - this->doStop(); - this->isInitialized = false; + this->initializedCount--; + if (this->initializedCount == 0) + this->doStop(); return this; } @@ -80,7 +80,9 @@ class Profiler { mutable std::shared_mutex mutex; std::set dataSet; - bool isInitialized{false}; + +private: + int initializedCount{}; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h index f5a63d8ea8a4..b800d447da3f 100644 --- a/third_party/proton/csrc/include/Session/Session.h +++ b/third_party/proton/csrc/include/Session/Session.h @@ -75,10 +75,14 @@ class SessionManager : public Singleton { void finalizeAllSessions(OutputFormat outputFormat); - void activateSession(size_t sesssionId); + void activateSession(size_t sessionId); + + void activateAllSessions(); void deactivateSession(size_t sessionId); + void deactivateAllSessions(); + void enterScope(const Scope &scope); void exitScope(const Scope &scope); @@ -91,13 +95,15 @@ class SessionManager : public Singleton { const std::map &metrics, bool aggregable); + void setState(std::optional context); + private: std::unique_ptr makeSession(size_t id, const std::string &path, const std::string &profilerName, const std::string &contextSourceName, const std::string &dataName); - void activateSessionImpl(size_t sesssionId); + void activateSessionImpl(size_t sessionId); void deActivateSessionImpl(size_t sessionId); @@ -135,13 +141,15 @@ class SessionManager : public Singleton { // path -> session id std::map sessionPaths; // session id -> active - std::map activeSessions; + std::map sessionActive; // session id -> session std::map> sessions; // scope -> active count std::map scopeInterfaceCounts; // op -> active count std::map opInterfaceCounts; + // context source -> active count + std::map contextSourceCounts; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Utility/Errors.h b/third_party/proton/csrc/include/Utility/Errors.h index 094723d6f7e8..09c44025dc45 100644 --- a/third_party/proton/csrc/include/Utility/Errors.h +++ b/third_party/proton/csrc/include/Utility/Errors.h @@ -7,7 +7,7 @@ namespace proton { class NotImplemented : public std::logic_error { public: - NotImplemented() : std::logic_error("Not yet implemented"){}; + NotImplemented() : std::logic_error("Not yet implemented") {}; }; } // namespace proton diff --git a/third_party/proton/csrc/lib/Context/Context.cpp b/third_party/proton/csrc/lib/Context/Context.cpp index 676bdd8d6f2f..04e5170d0e10 100644 --- a/third_party/proton/csrc/lib/Context/Context.cpp +++ b/third_party/proton/csrc/lib/Context/Context.cpp @@ -2,6 +2,9 @@ namespace proton { +/*static*/ thread_local std::optional ContextSource::state = + std::nullopt; + std::atomic Scope::scopeIdCounter{1}; /*static*/ thread_local std::map diff --git a/third_party/proton/csrc/lib/Context/Python.cpp b/third_party/proton/csrc/lib/Context/Python.cpp index 1cc26eaf37e4..3dc3aeb64afa 100644 --- a/third_party/proton/csrc/lib/Context/Python.cpp +++ b/third_party/proton/csrc/lib/Context/Python.cpp @@ -71,7 +71,7 @@ std::string unpackPyobject(PyObject *pyObject) { } // namespace -std::vector PythonContextSource::getContexts() { +std::vector PythonContextSource::getContextsImpl() { pybind11::gil_scoped_acquire gil; PyFrameObject *frame = PyEval_GetFrame(); diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp index 19b50214b98c..45082a14ef25 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -15,10 +15,10 @@ namespace { uint64_t getCubinCrc(const char *cubin, size_t size) { CUpti_GetCubinCrcParams cubinCrcParams = { - .size = CUpti_GetCubinCrcParamsSize, - .cubinSize = size, - .cubin = cubin, - .cubinCrc = 0, + /*size=*/CUpti_GetCubinCrcParamsSize, + /*cubinSize=*/size, + /*cubin=*/cubin, + /*cubinCrc=*/0, }; cupti::getCubinCrc(&cubinCrcParams); return cubinCrcParams.cubinCrc; @@ -27,10 +27,10 @@ uint64_t getCubinCrc(const char *cubin, size_t size) { size_t getNumStallReasons(CUcontext context) { size_t numStallReasons = 0; CUpti_PCSamplingGetNumStallReasonsParams numStallReasonsParams = { - .size = CUpti_PCSamplingGetNumStallReasonsParamsSize, - .pPriv = NULL, - .ctx = context, - .numStallReasons = &numStallReasons}; + /*size=*/CUpti_PCSamplingGetNumStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/&numStallReasons}; cupti::pcSamplingGetNumStallReasons(&numStallReasonsParams); return numStallReasons; } @@ -39,14 +39,14 @@ std::tuple getSassToSourceCorrelation(const char *functionName, uint64_t pcOffset, const char *cubin, size_t cubinSize) { CUpti_GetSassToSourceCorrelationParams sassToSourceParams = { - .size = CUpti_GetSassToSourceCorrelationParamsSize, - .cubin = cubin, - .functionName = functionName, - .cubinSize = cubinSize, - .lineNumber = 0, - .pcOffset = pcOffset, - .fileName = NULL, - .dirName = NULL, + /*size=*/CUpti_GetSassToSourceCorrelationParamsSize, + /*cubin=*/cubin, + /*functionName=*/functionName, + /*cubinSize=*/cubinSize, + /*lineNumber=*/0, + /*pcOffset=*/pcOffset, + /*fileName=*/NULL, + /*dirName=*/NULL, }; // Get source can fail if the line mapping is not available in the cubin so we // don't check the return value @@ -77,12 +77,12 @@ getStallReasonNamesAndIndices(CUcontext context, size_t numStallReasons) { static_cast(std::calloc(numStallReasons, sizeof(uint32_t))); // Initialize the names with 128 characters to avoid buffer overflow CUpti_PCSamplingGetStallReasonsParams stallReasonsParams = { - .size = CUpti_PCSamplingGetStallReasonsParamsSize, - .pPriv = NULL, - .ctx = context, - .numStallReasons = numStallReasons, - .stallReasonIndex = stallReasonIndices, - .stallReasons = stallReasonNames, + /*size=*/CUpti_PCSamplingGetStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/numStallReasons, + /*stallReasonIndex=*/stallReasonIndices, + /*stallReasons=*/stallReasonNames, }; cupti::pcSamplingGetStallReasons(&stallReasonsParams); return std::make_pair(stallReasonNames, stallReasonIndices); @@ -143,9 +143,15 @@ CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) pcDataSize -= CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE; CUpti_PCSamplingData pcSamplingData{ - .size = pcDataSize, - .collectNumPcs = collectNumPCs, - .pPcData = static_cast( + /*size=*/pcDataSize, + /*collectNumPcs=*/collectNumPCs, + /*totalSamples=*/0, + /*droppedSamples=*/0, + /*totalNumPcs=*/0, + /*remainingNumPcs=*/0, + /*rangeId=*/0, + /*pPcData=*/ + static_cast( std::calloc(collectNumPCs, sizeof(CUpti_PCSamplingPCData)))}; for (size_t i = 0; i < collectNumPCs; ++i) { pcSamplingData.pPcData[i].stallReason = @@ -157,36 +163,36 @@ CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, void enablePCSampling(CUcontext context) { CUpti_PCSamplingEnableParams params = { - .size = CUpti_PCSamplingEnableParamsSize, - .pPriv = NULL, - .ctx = context, + /*size=*/CUpti_PCSamplingEnableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, }; cupti::pcSamplingEnable(¶ms); } void disablePCSampling(CUcontext context) { CUpti_PCSamplingDisableParams params = { - .size = CUpti_PCSamplingDisableParamsSize, - .pPriv = NULL, - .ctx = context, + /*size=*/CUpti_PCSamplingDisableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, }; cupti::pcSamplingDisable(¶ms); } void startPCSampling(CUcontext context) { CUpti_PCSamplingStartParams params = { - .size = CUpti_PCSamplingStartParamsSize, - .pPriv = NULL, - .ctx = context, + /*size=*/CUpti_PCSamplingStartParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, }; cupti::pcSamplingStart(¶ms); } void stopPCSampling(CUcontext context) { CUpti_PCSamplingStopParams params = { - .size = CUpti_PCSamplingStopParamsSize, - .pPriv = NULL, - .ctx = context, + /*size=*/CUpti_PCSamplingStopParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, }; cupti::pcSamplingStop(¶ms); } @@ -194,10 +200,10 @@ void stopPCSampling(CUcontext context) { void getPCSamplingData(CUcontext context, CUpti_PCSamplingData *pcSamplingData) { CUpti_PCSamplingGetDataParams params = { - .size = CUpti_PCSamplingGetDataParamsSize, - .pPriv = NULL, - .ctx = context, - .pcSamplingData = pcSamplingData, + /*size=*/CUpti_PCSamplingGetDataParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*pcSamplingData=*/pcSamplingData, }; cupti::pcSamplingGetData(¶ms); } @@ -206,11 +212,11 @@ void setConfigurationAttribute( CUcontext context, std::vector &configurationInfos) { CUpti_PCSamplingConfigurationInfoParams infoParams = { - .size = CUpti_PCSamplingConfigurationInfoParamsSize, - .pPriv = NULL, - .ctx = context, - .numAttributes = configurationInfos.size(), - .pPCSamplingConfigurationInfo = configurationInfos.data(), + /*size=*/CUpti_PCSamplingConfigurationInfoParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numAttributes=*/configurationInfos.size(), + /*pPCSamplingConfigurationInfo=*/configurationInfos.data(), }; cupti::pcSamplingSetConfigurationAttribute(&infoParams); } diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 9ddbd7a71547..fa0ad0bfda5a 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -323,8 +323,7 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, static_cast(cbData); auto *pImpl = dynamic_cast(profiler.pImpl.get()); if (callbackData->callbackSite == CUPTI_API_ENTER) { - auto scopeId = Scope::getNewScopeId(); - threadState.record(scopeId); + auto scopeId = threadState.record(); threadState.enterOp(scopeId); size_t numInstances = 1; if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 68f3f0beac9f..ca93678e1c82 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -74,6 +74,7 @@ std::shared_ptr convertActivityToMetric(const roctracer_record_t *activity) { std::shared_ptr metric; switch (activity->kind) { + case kHipVdiCommandTask: case kHipVdiCommandKernel: { if (activity->begin_ns < activity->end_ns) { metric = std::make_shared( @@ -135,7 +136,7 @@ void processActivity(RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, const roctracer_record_t *record, bool isAPI, bool isGraph) { switch (record->kind) { - case 0x11F1: // Task - kernel enqueued by graph launch + case kHipVdiCommandTask: case kHipVdiCommandKernel: { processActivityKernel(corrIdToExternId, externId, dataSet, record, isAPI, isGraph); @@ -169,6 +170,7 @@ std::pair matchKernelCbId(uint32_t cbId) { case HIP_API_ID_hipModuleLaunchCooperativeKernel: case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: case HIP_API_ID_hipGraphExecDestroy: + case HIP_API_ID_hipGraphInstantiateWithFlags: case HIP_API_ID_hipGraphInstantiate: { isRuntimeApi = true; break; @@ -231,8 +233,7 @@ void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( const hip_api_data_t *data = (const hip_api_data_t *)(callbackData); if (data->phase == ACTIVITY_API_PHASE_ENTER) { // Valid context and outermost level of the kernel launch - auto scopeId = Scope::getNewScopeId(); - threadState.record(scopeId); + auto scopeId = threadState.record(); threadState.enterOp(scopeId); size_t numInstances = 1; if (cid == HIP_API_ID_hipGraphLaunch) { @@ -301,6 +302,13 @@ void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( pImpl->StreamToCaptureCount[Stream]++; break; } + case HIP_API_ID_hipGraphInstantiateWithFlags: { + hipGraph_t Graph = data->args.hipGraphInstantiateWithFlags.graph; + hipGraphExec_t GraphExec = + *(data->args.hipGraphInstantiateWithFlags.pGraphExec); + pImpl->GraphExecToGraph[GraphExec] = Graph; + break; + } case HIP_API_ID_hipGraphInstantiate: { hipGraph_t Graph = data->args.hipGraphInstantiate.graph; hipGraphExec_t GraphExec = *(data->args.hipGraphInstantiate.pGraphExec); diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 5ff74f0fc68f..269eb46209c5 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -84,30 +84,46 @@ void SessionManager::activateSession(size_t sessionId) { activateSessionImpl(sessionId); } +void SessionManager::activateAllSessions() { + std::unique_lock lock(mutex); + for (auto iter : sessionActive) { + activateSessionImpl(iter.first); + } +} + void SessionManager::deactivateSession(size_t sessionId) { std::unique_lock lock(mutex); deActivateSessionImpl(sessionId); } +void SessionManager::deactivateAllSessions() { + std::unique_lock lock(mutex); + for (auto iter : sessionActive) { + deActivateSessionImpl(iter.first); + } +} + void SessionManager::activateSessionImpl(size_t sessionId) { throwIfSessionNotInitialized(sessions, sessionId); - if (activeSessions[sessionId]) + if (sessionActive[sessionId]) return; - activeSessions[sessionId] = true; + sessionActive[sessionId] = true; sessions[sessionId]->activate(); registerInterface(sessionId, scopeInterfaceCounts); registerInterface(sessionId, opInterfaceCounts); + registerInterface(sessionId, contextSourceCounts); } void SessionManager::deActivateSessionImpl(size_t sessionId) { throwIfSessionNotInitialized(sessions, sessionId); - if (!activeSessions[sessionId]) { + if (!sessionActive[sessionId]) { return; } - activeSessions[sessionId] = false; + sessionActive[sessionId] = false; sessions[sessionId]->deactivate(); unregisterInterface(sessionId, scopeInterfaceCounts); unregisterInterface(sessionId, opInterfaceCounts); + unregisterInterface(sessionId, contextSourceCounts); } void SessionManager::removeSession(size_t sessionId) { @@ -116,6 +132,7 @@ void SessionManager::removeSession(size_t sessionId) { } auto path = sessions[sessionId]->path; sessionPaths.erase(path); + sessionActive.erase(sessionId); sessions.erase(sessionId); } @@ -204,11 +221,21 @@ void SessionManager::addMetrics( size_t scopeId, const std::map &metrics, bool aggregable) { std::shared_lock lock(mutex); - for (auto [sessionId, active] : activeSessions) { + for (auto [sessionId, active] : sessionActive) { if (active) { sessions[sessionId]->data->addMetrics(scopeId, metrics, aggregable); } } } +void SessionManager::setState(std::optional context) { + std::shared_lock lock(mutex); + for (auto iter : contextSourceCounts) { + auto [contextSource, count] = iter; + if (count > 0) { + contextSource->setState(context); + } + } +} + } // namespace proton diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt new file mode 100644 index 000000000000..cfa5938873d9 --- /dev/null +++ b/third_party/proton/dialect/CMakeLists.txt @@ -0,0 +1,8 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc) + target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers) +endif() diff --git a/third_party/proton/dialect/include/CMakeLists.txt b/third_party/proton/dialect/include/CMakeLists.txt new file mode 100644 index 000000000000..0ca0f41c5af4 --- /dev/null +++ b/third_party/proton/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/include/Dialect/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000000..f18c30ba1a6d --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000000..4645b0ebcd5a --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 000000000000..680a205f08f1 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_ +#define TRITON_DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace proton {} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 000000000000..d469fbb35f6b --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,12 @@ +#ifndef PROTON_ATTRDEFS +#define PROTON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "ProtonDialect.td" + +class Proton_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif // PROTON_ATTRDEFS diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 000000000000..245f2e09a2ec --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,18 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + +include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 000000000000..d18a48d5d1a0 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,65 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "ProtonDialect.td" +include "ProtonAttrDefs.td" + +class TT_Proton_Op traits = []> : + Op { +} + +// Proton profiling metric. +def MetricAttr : I32EnumAttr< + "Metric", "", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +// Proton profiling granularity. +def GranularityAttr : I32EnumAttr< + "Granularity", "", + [ + I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">, + I32EnumAttrCase<"WARP", 1, "warp">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods]> { + let summary = "Record a GPU hardware event"; + + let description = [{ + The operator records GPU events from performance counters. + Currently only cycle counter is supported. + + Example: + + ```mlir + proton.record() {isStart = true, regionId = 4 : i32} + ... + proton.record() {isStart = false, regionId = 4 : i32} + ... + proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32} + ... + proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32} + ``` + }]; + let arguments = ( + ins BoolAttr: $isStart, + ConfinedAttr:$regionId, + DefaultValuedAttr:$metric, + DefaultValuedAttr:$granularity + ); + let assemblyFormat = " `(` operands `)` attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt new file mode 100644 index 000000000000..0ca0f41c5af4 --- /dev/null +++ b/third_party/proton/dialect/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/lib/Dialect/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000000..f18c30ba1a6d --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000000..5eea5cb3cf9e --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 000000000000..60c2852654db --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,25 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/Proton/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::proton; + +void mlir::triton::proton::ProtonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 index 000000000000..1a0799aea127 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,33 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { + +// -- RecordOp -- +void RecordOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc new file mode 100644 index 000000000000..8046539794e1 --- /dev/null +++ b/third_party/proton/dialect/triton_proton.cc @@ -0,0 +1,20 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_proton(py::module &&m) { + auto passes = m.def_submodule("passes"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); +} diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 0add689155c3..ded8b01142af 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -1,5 +1,6 @@ # flake8: noqa from .scope import scope, enter_scope, exit_scope +from .state import state, enter_state, exit_state from .profile import ( start, activate, diff --git a/third_party/proton/proton/hook.py b/third_party/proton/proton/hook.py index 94f94bee0b56..e40e1b38c012 100644 --- a/third_party/proton/proton/hook.py +++ b/third_party/proton/proton/hook.py @@ -1,3 +1,4 @@ +from .state import enter_state, exit_state from .scope import enter_scope, exit_scope from triton.compiler import CompiledKernel, LazyDict @@ -10,9 +11,9 @@ class TritonHook: @staticmethod def enter(lazy_dict: LazyDict) -> None: - enter_scope(COMPUTE_METADATA_SCOPE_NAME) + enter_state(COMPUTE_METADATA_SCOPE_NAME) metadata = lazy_dict.get() - exit_scope() + exit_state() fn_metrics = {k: metadata[k] for k in TritonHook.metrics if k in metadata} enter_scope(metadata["name"], triton_op=True, metrics=fn_metrics) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 2dd7a6f53ed8..575c85b0cac8 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -1,5 +1,6 @@ import functools import triton +import os from triton._C.libproton import proton as libproton from .hook import register_triton_hook, unregister_triton_hook @@ -19,6 +20,16 @@ def _select_backend() -> str: raise ValueError("No backend is available for the current target.") +def _check_env(backend: str) -> None: + if backend == "roctracer": + hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] + for env in hip_device_envs: + if os.getenv(env, None) is not None: + raise ValueError( + f"Proton does not work when the environment variable {env} is set on AMD GPUs. Please unset it and use `ROCR_VISIBLE_DEVICES` instead" + ) + + def start( name: Optional[str] = None, *, @@ -66,42 +77,50 @@ def start( if backend is None: backend = _select_backend() + _check_env(backend) + set_profiling_on() if hook and hook == "triton": register_triton_hook() return libproton.start(name, context, data, backend) -def activate(session: Optional[int] = 0) -> None: +def activate(session: Optional[int] = None) -> None: """ Activate the specified session. The profiling session will be active and data will be recorded. Args: - session (int): The session ID of the profiling session. Defaults to 0 (the first session started.) + session (int): The session ID of the profiling session. Defaults to None (all sessions) Returns: None """ if is_command_line() and session != 0: raise ValueError("Only one session can be activated when running from the command line.") - libproton.activate(session) + if session is None: + libproton.activate_all() + else: + libproton.activate(session) -def deactivate(session: Optional[int] = 0) -> None: +def deactivate(session: Optional[int] = None) -> None: """ Stop the specified session. The profiling session's data will still be in the memory, but no more data will be recorded. Args: - session (int): The session ID of the profiling session. Defaults to 0 (the first session started.) + session (int): The session ID of the profiling session. Defaults to None (all sessions) Returns: None """ if is_command_line() and session != 0: raise ValueError("Only one session can be deactivated when running from the command line.") - libproton.deactivate(session) + if session is None: + libproton.deactivate_all() + else: + libproton.deactivate(session) def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None: diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index cbb7a0b6f90d..0eacc850ed66 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -1,8 +1,10 @@ import argparse import sys import os +import pathlib from .profile import start, finalize, _select_backend from .flags import set_command_line +import triton def parse_arguments(): @@ -19,6 +21,8 @@ def parse_arguments(): choices=["shadow", "python"]) parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"]) parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) + parser.add_argument("-i", "--instrument", type=str, help="Instrumentation analysis type", default=None, + choices=[None, "print-mem-spaces"]) parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') args = parser.parse_args() return args, args.target_args @@ -28,7 +32,7 @@ def is_pytest(script): return os.path.basename(script) == 'pytest' -def execute_as_main(script, args): +def execute_as_main(script, args, instrumentation_pass=None): script_path = os.path.abspath(script) # Prepare a clean global environment clean_globals = { @@ -42,6 +46,14 @@ def execute_as_main(script, args): sys.argv = [script] + args # Append the script's directory in case the script uses relative imports sys.path.append(os.path.dirname(script_path)) + top_level_triton_path = os.path.dirname(triton.__file__) + + if instrumentation_pass == "print-mem-spaces": + instrumentation_pass_path = str( + next(pathlib.Path(top_level_triton_path).rglob("libPrintLoadStoreMemSpaces.so"), None)) + os.environ['TRITON_ALWAYS_COMPILE'] = "1" + os.environ['TRITON_DISABLE_LINE_INFO'] = "0" + os.environ['LLVM_PASS_PLUGIN_PATH'] = instrumentation_pass_path # Execute in the isolated environment try: @@ -54,11 +66,7 @@ def execute_as_main(script, args): sys.argv = original_argv -def run_profiling(args, target_args): - backend = args.backend if args.backend else _select_backend() - - start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook) - +def do_setup_and_execute(target_args, instrumentation_pass=None): # Set the command line mode to avoid any `start` calls in the script. set_command_line() @@ -68,13 +76,28 @@ def run_profiling(args, target_args): import pytest pytest.main(script_args) else: - execute_as_main(script, script_args) + execute_as_main(script, script_args, instrumentation_pass) + + +def run_profiling(args, target_args): + backend = args.backend if args.backend else _select_backend() + + start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook) + + do_setup_and_execute(target_args) finalize() +def run_instrumentation(args, target_args): + do_setup_and_execute(target_args, args.instrument) + + def main(): args, target_args = parse_arguments() + if args.instrument: + run_instrumentation(args, target_args) + return run_profiling(args, target_args) diff --git a/third_party/proton/proton/state.py b/third_party/proton/proton/state.py new file mode 100644 index 000000000000..dd1e47801fdb --- /dev/null +++ b/third_party/proton/proton/state.py @@ -0,0 +1,61 @@ +from triton._C.libproton import proton as libproton +from .flags import get_profiling_on +from functools import wraps + + +class state: + """ + A context manager and decorator for entering and exiting a state. + + Usage: + context manager: + ```python + with proton.state("test0"): + foo[1,](x, y) + ``` + + decorator: + ```python + @proton.state("test0") + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the state. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + if not get_profiling_on(): + return self + libproton.enter_state(self.name) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if not get_profiling_on(): + return + libproton.exit_state() + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + if get_profiling_on(): + libproton.enter_state(self.name) + ret = func(*args, **kwargs) + if get_profiling_on(): + libproton.exit_state() + return ret + + return wrapper + + +def enter_state(name: str) -> None: + libproton.enter_state(name) + + +def exit_state() -> None: + libproton.exit_state() diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index fe7c98807c57..6b4fe8d91d4c 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -28,8 +28,37 @@ def match_available_metrics(metrics, raw_metrics): return ret +def remove_metadata(database: json): + # Find all frames with the name COMPUTE_METADATA_SCOPE_NAME, remove them and their children + # Then go up from the metadata node and remove the parent if all its children were + # metadata nodes + def remove_metadata_helper(node): + if "frame" not in node: + return node + if node["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME: + return None + children = node.get("children", []) + new_children = [] + for child in children: + new_child = remove_metadata_helper(child) + if new_child is not None: + new_children.append(new_child) + if len(new_children) > 0 or len(children) == 0: + node["children"] = new_children + return node + return None + + new_database = [] + for node in database: + new_node = remove_metadata_helper(node) + if new_node is not None: + new_database.append(new_node) + return new_database + + def get_raw_metrics(file): database = json.load(file) + database = remove_metadata(database) device_info = database.pop(1) gf = ht.GraphFrame.from_literal(database) return gf, gf.show_metric_columns(), device_info @@ -92,7 +121,7 @@ def get_min_time_bytes(df, device_info): } # FLOPS have a specific width to their metric -default_flop_factor_dict = {f"flop/s": 1, f"gflop/s": 1e9, f"tflop/s": 1e12} +default_flop_factor_dict = {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12} derivable_metrics.update( {key: FactorDict("flops", default_flop_factor_dict) for key in default_flop_factor_dict.keys()}) @@ -180,16 +209,13 @@ def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None): """ query = NegationQuery(inclusion_query) gf = gf.filter(query, squash=True) - # filter out metadata computation - query = [{"name": f"^(?!{COMPUTE_METADATA_SCOPE_NAME}).*"}] - gf = gf.filter(query, squash=True) if threshold: query = ["*", {metric: f">= {threshold}"}] gf = gf.filter(query, squash=True) return gf -def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None): +def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None, print_sorted=False): with open(filename, "r") as f: gf, raw_metrics, device_info = get_raw_metrics(f) gf = format_frames(gf, format) @@ -199,6 +225,15 @@ def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=1 # TODO: generalize to support multiple metrics, not just the first one gf = filter_frames(gf, include, exclude, threshold, metrics[0]) print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + if print_sorted: + print("Sorted kernels by metric " + metrics[0].strip("(inc)")) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + for row in range(1, len(sorted_df)): + if len(sorted_df.iloc[row]['name']) > 100: + kernel_name = sorted_df.iloc[row]['name'][:100] + "..." + else: + kernel_name = sorted_df.iloc[row]['name'] + print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]])) emit_warnings(gf, metrics) @@ -269,7 +304,7 @@ def main(): type=str, default=None, help="""Exclude frames that match the given regular expression and their children. -For example, the following command will exclude all paths that contain frames that contains "test": +For example, the following command will exclude all paths starting from frames that contains "test": ``` proton-viewer -e ".*test.*" path/to/file.json ``` @@ -298,6 +333,12 @@ def main(): - function_line: include the function name and line number. - file_function: include the file name and function name. """) + argparser.add_argument( + "--print-sorted", + action='store_true', + default=False, + help="Sort output by metric value instead of chronologically", + ) args, target_args = argparser.parse_known_args() assert len(target_args) == 1, "Must specify a file to read" @@ -309,12 +350,13 @@ def main(): threshold = args.threshold depth = args.depth format = args.format + print_sorted = args.print_sorted if include and exclude: raise ValueError("Cannot specify both include and exclude") if args.list: show_metrics(file_name) elif metrics: - parse(metrics, file_name, include, exclude, threshold, depth, format) + parse(metrics, file_name, include, exclude, threshold, depth, format, print_sorted) if __name__ == "__main__": diff --git a/third_party/proton/test/example_cuda.json b/third_party/proton/test/examples/cuda.json similarity index 99% rename from third_party/proton/test/example_cuda.json rename to third_party/proton/test/examples/cuda.json index 445f0e224c65..5c742267acab 100644 --- a/third_party/proton/test/example_cuda.json +++ b/third_party/proton/test/examples/cuda.json @@ -1,4 +1,4 @@ - [ +[ { "children": [ { diff --git a/third_party/proton/test/example_frame.json b/third_party/proton/test/examples/frame.json similarity index 100% rename from third_party/proton/test/example_frame.json rename to third_party/proton/test/examples/frame.json diff --git a/third_party/proton/test/example_hip.json b/third_party/proton/test/examples/hip.json similarity index 99% rename from third_party/proton/test/example_hip.json rename to third_party/proton/test/examples/hip.json index 68538706cfe9..70eaf325d35b 100644 --- a/third_party/proton/test/example_hip.json +++ b/third_party/proton/test/examples/hip.json @@ -1,4 +1,4 @@ - [ +[ { "children": [ { diff --git a/third_party/proton/test/examples/leaf_nodes.json b/third_party/proton/test/examples/leaf_nodes.json new file mode 100644 index 000000000000..5930664dd244 --- /dev/null +++ b/third_party/proton/test/examples/leaf_nodes.json @@ -0,0 +1,168 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_2_2", + "type": "function" + }, + "metrics": { + "count": 402, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 78190414 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_3_1", + "type": "function" + }, + "metrics": { + "count": 502, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 24125138 + } + } + ], + "frame": { + "name": "kernel_1_2_1", + "type": "function" + }, + "metrics": { + "bytes": 3997237248, + "flops": 1534939103232 + } + } + ], + "frame": { + "name": "kernel_1_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_2_2", + "type": "function" + }, + "metrics": { + "count": 120, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 23174888 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_3_1", + "type": "function" + }, + "metrics": { + "count": 149, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 1040322 + } + } + ], + "frame": { + "name": "kernel_2_2_1", + "type": "function" + }, + "metrics": { + "bytes": 58589184, + "flops": 4999610368 + } + } + ], + "frame": { + "name": "kernel_2_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_2", + "type": "function" + }, + "metrics": { + "count": 480, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 93036508 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "count": 599, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 6306402 + } + } + ], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "bytes": 529956864, + "flops": 67834478592 + } + } + ], + "frame": { + "name": "kernel_3_1_1", + "type": "function" + }, + "metrics": {} + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "flops": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/proton/test/examples/triton.json b/third_party/proton/test/examples/triton.json new file mode 100644 index 000000000000..b870bd70f0c3 --- /dev/null +++ b/third_party/proton/test/examples/triton.json @@ -0,0 +1,71 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "cuda_kernel", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 4064 + } + } + ], + "frame": { + "name": "__proton_launch_metadata", + "type": "function" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "triton_kernel", + "type": "function" + }, + "metrics": { + "bytes": 2.0, + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 1664 + } + } + ], + "frame": { + "name": "scope", + "type": "function" + }, + "metrics": {} + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "time (ns)": 0 + } + }, + { + "CUDA": { + "0": { + "arch": "86", + "bus_width": 128, + "clock_rate": 1140000, + "memory_clock_rate": 5501000, + "num_sms": 16 + } + } + } +] diff --git a/third_party/proton/test/instrument.py b/third_party/proton/test/instrument.py new file mode 100644 index 000000000000..59ebe86a115f --- /dev/null +++ b/third_party/proton/test/instrument.py @@ -0,0 +1,68 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b, activation=""): + # Check constraints. + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + # 1D launch kernel where each block gets its own program. + def grid(): + return (1, ) + + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + 128, 256, 64, 8) + return c + + +a = torch.randn((32, 32), device="cuda", dtype=torch.float16) +b = torch.randn((32, 32), device="cuda", dtype=torch.float16) +matmul(a, b) diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 713572c4fcac..4ced1e35c599 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -1,23 +1,24 @@ import json import triton.profiler as proton -import tempfile import pathlib -def test_profile(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0]) - proton.activate() - proton.deactivate() - proton.finalize() - assert session_id0 == 0 +def test_profile_single_session(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert session_id0 == 0 + assert temp_file0.exists() - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id1 = proton.start(f.name.split(".")[0]) - proton.activate(session_id1) - proton.deactivate(session_id1) - proton.finalize(session_id1) - assert session_id1 == session_id0 + 1 + temp_file1 = tmp_path / "test_profile1.hatchet" + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + proton.activate(session_id1) + proton.deactivate(session_id1) + proton.finalize(session_id1) + assert session_id1 == session_id0 + 1 + assert temp_file1.exists() session_id2 = proton.start("test") proton.activate(session_id2) @@ -28,19 +29,38 @@ def test_profile(): pathlib.Path("test.hatchet").unlink() -def test_profile_decorator(): - f = tempfile.NamedTemporaryFile(delete=True) - name = f.name.split(".")[0] +def test_profile_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + proton.start(str(temp_file0.with_suffix(""))) + temp_file1 = tmp_path / "test_profile1.hatchet" + proton.start(str(temp_file1.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert temp_file0.exists() + assert temp_file1.exists() + + temp_file2 = tmp_path / "test_profile2.hatchet" + session_id2 = proton.start(str(temp_file2.with_suffix(""))) + temp_file3 = tmp_path / "test_profile3.hatchet" + session_id3 = proton.start(str(temp_file3.with_suffix(""))) + proton.deactivate(session_id2) + proton.deactivate(session_id3) + proton.finalize() + assert temp_file2.exists() + assert temp_file3.exists() + + +def test_profile_decorator(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_profile_decorator.hatchet" - @proton.profile(name=name) + @proton.profile(name=str(temp_file.with_suffix(""))) def foo0(a, b): return a + b foo0(1, 2) proton.finalize() - assert pathlib.Path(f.name).exists() - - f.close() + assert temp_file.exists() @proton.profile def foo1(a, b): @@ -48,126 +68,156 @@ def foo1(a, b): foo1(1, 2) proton.finalize() - assert pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet").exists() + default_file = pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet") + assert default_file.exists() + default_file.unlink() -def test_scope(): +def test_scope(tmp_path: pathlib.Path): # Scope can be annotated even when profiling is off with proton.scope("test"): pass - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test"): - pass + temp_file = tmp_path / "test_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test"): + pass - @proton.scope("test") - def foo(): - pass + @proton.scope("test") + def foo(): + pass - foo() + foo() - proton.enter_scope("test") - proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.enter_scope("test") + proton.exit_scope() + proton.finalize() + assert temp_file.exists() -def test_hook(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id0 = proton.start(f.name.split(".")[0], hook="triton") - proton.activate(session_id0) - proton.deactivate(session_id0) - proton.finalize(None) - assert pathlib.Path(f.name).exists() +def test_hook(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_hook.hatchet" + session_id0 = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.activate(session_id0) + proton.deactivate(session_id0) + proton.finalize(None) + assert temp_file.exists() -def test_scope_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - # Test different scope creation methods - with proton.scope("test0", {"a": 1.0}): - pass +def test_scope_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + with proton.scope("test0", {"a": 1.0}): + pass - @proton.scope("test1", {"a": 1.0}) - def foo(): - pass + @proton.scope("test1", {"a": 1.0}) + def foo(): + pass - foo() + foo() - # After deactivation, the metrics should be ignored - proton.deactivate(session_id) - proton.enter_scope("test2", metrics={"a": 1.0}) - proton.exit_scope() + # After deactivation, the metrics should be ignored + proton.deactivate(session_id) + proton.enter_scope("test2", metrics={"a": 1.0}) + proton.exit_scope() - # Metrics should be recorded again after reactivation - proton.activate(session_id) - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + # Metrics should be recorded again after reactivation + proton.activate(session_id) + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test3", metrics={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 3 - for child in data[0]["children"]: - if child["frame"]["name"] == "test3": - assert child["metrics"]["a"] == 2.0 - - -def test_scope_properties(): - with open("test.hatchet", "w+") as f: - proton.start(f.name.split(".")[0]) - # Test different scope creation methods - # Different from metrics, properties could be str - with proton.scope("test0", properties={"a": "1"}): - pass + assert len(data[0]["children"]) == 3 + for child in data[0]["children"]: + if child["frame"]["name"] == "test3": + assert child["metrics"]["a"] == 2.0 + + +def test_scope_properties(tmp_path: pathlib.Path): + temp_file = tmp_path / "test.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + # Different from metrics, properties could be str + with proton.scope("test0", properties={"a": "1"}): + pass - @proton.scope("test1", properties={"a": "1"}) - def foo(): - pass + @proton.scope("test1", properties={"a": "1"}) + def foo(): + pass - foo() + foo() - # Properties do not aggregate - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + # Properties do not aggregate + proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.enter_scope("test2", properties={"a": 1.0}) - proton.exit_scope() + proton.enter_scope("test2", properties={"a": 1.0}) + proton.exit_scope() - proton.finalize() - assert pathlib.Path(f.name).exists() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: data = json.load(f) - for child in data[0]["children"]: - if child["frame"]["name"] == "test2": - assert child["metrics"]["a"] == 1.0 - elif child["frame"]["name"] == "test0": - assert child["metrics"]["a"] == "1" - - -def test_throw(): + for child in data[0]["children"]: + if child["frame"]["name"] == "test2": + assert child["metrics"]["a"] == 1.0 + elif child["frame"]["name"] == "test0": + assert child["metrics"]["a"] == "1" + + +def test_state(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_state.hatchet" + proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.enter_state("state") + proton.enter_scope("test1", metrics={"a": 1.0}) + proton.exit_scope() + proton.exit_state() + proton.exit_scope() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + # test0->test1->state + assert len(data[0]["children"]) == 1 + child = data[0]["children"][0] + assert child["frame"]["name"] == "test0" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "test1" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "state" + assert child["metrics"]["a"] == 1.0 + + +def test_throw(tmp_path: pathlib.Path): # Catch an exception thrown by c++ session_id = 100 - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - activate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.activate(session_id + 1) - except Exception as e: - activate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in activate_error - - deactivate_error = "" - try: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id + 1) - except Exception as e: - deactivate_error = str(e) - finally: - proton.finalize() - assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error + temp_file = tmp_path / "test_throw.hatchet" + activate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index fa3331c02405..620dcd569115 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -1,7 +1,8 @@ +import triton import pytest import subprocess -import tempfile import json +import pathlib def test_help(): @@ -10,22 +11,55 @@ def test_help(): assert ret == 0 +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + @pytest.mark.parametrize("mode", ["script", "python", "pytest"]) -def test_exec(mode): +def test_exec(mode, tmp_path: pathlib.Path): file_path = __file__ helper_file = file_path.replace("test_cmd.py", "helper.py") - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - name = f.name.split(".")[0] - if mode == "script": - ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) - elif mode == "python": - ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], - stdout=subprocess.DEVNULL) - elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], - stdout=subprocess.DEVNULL) - assert ret == 0 + temp_file = tmp_path / "test_exec.hatchet" + name = str(temp_file.with_suffix("")) + if mode == "script": + ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + elif mode == "python": + ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) + elif mode == "pytest": + ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) + assert ret == 0 + with temp_file.open() as f: data = json.load(f, ) - kernels = data[0]["children"] - assert len(kernels) == 2 - assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" + kernels = data[0]["children"] + assert len(kernels) == 2 + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" + + +def test_instrument_exec(): + + try: + out = subprocess.Popen(["proton", "--instrument=print-mem-spaces", "instrument.py"], stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + except Exception as e: + print(f"An error occurred while executing proton: {e}") + + result = [] + for line in str(out.stderr.read().decode()).split("\n"): + if line: + result.append(line.split()) + + if is_hip(): + assert [row[0] for row in result] == ['0', '1', '2', '3'] + assert [row[1] for row in result] == ['matmul_kernel', 'matmul_kernel', 'matmul_kernel', 'matmul_kernel'] + assert [row[2] for row in result + ] == ['instrument.py:32:20', 'instrument.py:33:20', 'instrument.py:32:20', 'instrument.py:33:20'] + assert [row[3] for row in result] == ['SHARED', 'SHARED', 'SHARED', 'SHARED'] + assert [row[4] for row in result] == ['STORE', 'STORE', 'LOAD', 'LOAD'] + else: + assert [row[0] for row in result] == ['0'] + assert [row[1] for row in result] == ['matmul_kernel'] + assert [row[2] for row in result] == ['instrument.py:42:21'] + assert [row[3] for row in result] == ['SHARED'] + assert [row[4] for row in result] == ['LOAD'] diff --git a/third_party/proton/test/test_lib.py b/third_party/proton/test/test_lib.py index 0380268c0454..4a8313660b3f 100644 --- a/third_party/proton/test/test_lib.py +++ b/third_party/proton/test/test_lib.py @@ -1,6 +1,6 @@ -import triton._C.libproton.proton as libproton -import tempfile import pathlib + +import triton._C.libproton.proton as libproton from triton.profiler.profile import _select_backend @@ -10,6 +10,11 @@ def test_record(): assert id1 == id0 + 1 +def test_state(): + libproton.enter_state("zero") + libproton.exit_state() + + def test_scope(): id0 = libproton.record_scope() libproton.enter_scope(id0, "zero") @@ -25,22 +30,22 @@ def test_op(): libproton.exit_op(id0, "zero") -def test_session(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - libproton.deactivate(session_id) - libproton.activate(session_id) - libproton.finalize(session_id, "hatchet") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() - - -def test_add_metrics(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - libproton.start(f.name.split(".")[0], "shadow", "tree", _select_backend()) - id1 = libproton.record_scope() - libproton.enter_scope(id1, "one") - libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) - libproton.exit_scope(id1, "one") - libproton.finalize_all("hatchet") - assert pathlib.Path(f.name).exists() +def test_session(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_session.hatchet" + session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + libproton.deactivate(session_id) + libproton.activate(session_id) + libproton.finalize(session_id, "hatchet") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def test_add_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_add_metrics.hatchet" + libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) + libproton.exit_scope(id1, "one") + libproton.finalize_all("hatchet") + assert temp_file.exists() diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 13cb9bd99cbe..5ed5cfce4145 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -1,12 +1,13 @@ import torch import triton import triton.profiler as proton -import tempfile import json import pytest from typing import NamedTuple +import pathlib import triton.language as tl +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME def is_hip(): @@ -14,30 +15,31 @@ def is_hip(): @pytest.mark.parametrize("context", ["shadow", "python"]) -def test_torch(context): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], context=context) - proton.enter_scope("test") - torch.ones((2, 2), device="cuda") - proton.exit_scope() - proton.finalize() +def test_torch(context, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_torch.hatchet" + proton.start(str(temp_file.with_suffix("")), context=context) + proton.enter_scope("test") + torch.ones((2, 2), device="cuda") + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: data = json.load(f) - if context == "shadow": - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test" - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 - elif context == "python": - assert len(data[0]["children"]) == 1 - # The last frame is the torch kernel - prev_frame = data - curr_frame = data[0]["children"] - while len(curr_frame) > 0: - prev_frame = curr_frame - curr_frame = curr_frame[0]["children"] - assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] - - -def test_triton(): + if context == "shadow": + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + elif context == "python": + assert len(data[0]["children"]) == 1 + # The last frame is the torch kernel + prev_frame = data + curr_frame = data[0]["children"] + while len(curr_frame) > 0: + prev_frame = curr_frame + curr_frame = curr_frame[0]["children"] + assert "elementwise_kernel" in prev_frame[0]["frame"]["name"] + + +def test_triton(tmp_path: pathlib.Path): @triton.jit def foo(x, y): @@ -45,23 +47,24 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test0"): - with proton.scope("test1"): - foo[(1, )](x, y) - with proton.scope("test2"): + temp_file = tmp_path / "test_triton.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0"): + with proton.scope("test1"): foo[(1, )](x, y) - proton.finalize() + with proton.scope("test2"): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 2 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert len(data[0]["children"][0]["children"]) == 1 - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" - assert data[0]["children"][1]["frame"]["name"] == "test2" + assert len(data[0]["children"]) == 2 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert len(data[0]["children"][0]["children"]) == 1 + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" + assert data[0]["children"][1]["frame"]["name"] == "test2" -def test_cudagraph(): +def test_cudagraph(tmp_path: pathlib.Path): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -75,65 +78,47 @@ def fn(): c = a + b foo[(1, )](a, b, c) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], context="shadow") + temp_file = tmp_path / "test_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow") - # warmup - # four kernels - fn() + # warmup + # four kernels + fn() - # no kernels - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(10): - fn() + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() - proton.enter_scope("test") - g.replay() - g.reset() - torch.cuda.synchronize() - proton.exit_scope() - proton.finalize() + proton.enter_scope("test") + g.replay() + g.reset() + torch.cuda.synchronize() + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: data = json.load(f) - # CUDA/HIP graph may also invoke additional kernels to reset outputs - # {torch.ones, add, foo, test} - assert len(data[0]["children"]) >= 4 - # find the test frame - test_frame = None - for child in data[0]["children"]: - if child["frame"]["name"] == "test": - test_frame = child - break - assert test_frame is not None - # {torch.ones, add, foo} - if is_hip(): - assert len(test_frame["children"]) >= 2 - else: - assert len(test_frame["children"]) >= 3 - assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 - - -def test_metrics(): - - @triton.jit - def foo(x, y): - tl.store(y, tl.load(x)) - - x = torch.tensor([2], device="cuda") - y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.finalize() - data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test_frame["children"]) >= 2 + else: + assert len(test_frame["children"]) >= 3 + assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 -def test_metrics_ignore(): +def test_metrics(tmp_path: pathlib.Path): @triton.jit def foo(x, y): @@ -141,36 +126,38 @@ def foo(x, y): x = torch.tensor([2], device="cuda") y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0]) - proton.deactivate(session_id) - with proton.scope("test0", {"foo": 1.0}): - foo[(1, )](x, y) - proton.activate(session_id) - proton.finalize() + temp_file = tmp_path / "test_metrics.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 0 - - -def test_scope_backward(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0]) - with proton.scope("ones1"): - a = torch.ones((100, 100), device="cuda", requires_grad=True) - with proton.scope("plus"): - a2 = a * a * a - with proton.scope("ones2"): - loss = torch.ones_like(a2) - - # Backward triggers two kernels in a single scope - with proton.scope("backward"): - a2.backward(loss) - proton.finalize() + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + + +def test_scope_backward(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_backward.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("ones1"): + a = torch.ones((100, 100), device="cuda", requires_grad=True) + with proton.scope("plus"): + a2 = a * a * a + with proton.scope("ones2"): + loss = torch.ones_like(a2) + + # Backward triggers two kernels in a single scope + with proton.scope("backward"): + a2.backward(loss) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 4 + assert len(data[0]["children"]) == 4 -def test_hook(): +def test_hook(tmp_path: pathlib.Path): def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): # get arg's element size @@ -187,20 +174,55 @@ def foo(x, size: tl.constexpr, y): x = torch.tensor([2], device="cuda", dtype=torch.float32) y = torch.zeros_like(x) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton") - with proton.scope("test0"): - foo[(1, )](x, 1, y, num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - assert len(data[0]["children"]) == 1 - assert data[0]["children"][0]["frame"]["name"] == "test0" - assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" - assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 - assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" + assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 -def test_pcsampling(): +@pytest.mark.parametrize("context", ["shadow", "python"]) +def test_hook_gpu_kernel(tmp_path: pathlib.Path, context: str): + tmp_path = pathlib.Path("./") + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + # A gpu kernel, but it should be under the metadata state + return {"name": "foo_test", "bytes": x.sum().item()} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", context=context) + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + # bfs search until find the reduce kernel and then check its parent + queue = [data[0]] + while len(queue) > 0: + parent_frame = queue.pop(0) + for child in parent_frame["children"]: + if "reduce" in child["frame"]["name"]: + assert parent_frame["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME + return + queue.append(child) + + +def test_pcsampling(tmp_path: pathlib.Path): if is_hip(): pytest.skip("HIP backend does not support pc sampling") @@ -214,37 +236,59 @@ def foo(x, y, size: tl.constexpr): for _ in range(1000): tl.store(y + offs, tl.load(x + offs)) - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - proton.start(f.name.split(".")[0], hook="triton", backend="cupti_pcsampling") - with proton.scope("init"): - x = torch.ones((1024, ), device="cuda", dtype=torch.float32) - y = torch.zeros_like(x) - with proton.scope("test"): - foo[(1, )](x, y, x.size()[0], num_warps=4) - proton.finalize() + temp_file = tmp_path / "test_pcsampling.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti_pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - init_frame = data[0]["children"][0] - test_frame = data[0]["children"][1] - # With line mapping - assert "foo" in test_frame["children"][0]["frame"]["name"] - assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 - assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] - # Without line mapping - assert "elementwise" in init_frame["children"][0]["frame"]["name"] - assert init_frame["children"][0]["metrics"]["num_samples"] > 0 - - -def test_deactivate(): - with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: - session_id = proton.start(f.name.split(".")[0], hook="triton") - proton.deactivate(session_id) - torch.randn((10, 10), device="cuda") - proton.activate(session_id) - torch.zeros((10, 10), device="cuda") - proton.deactivate(session_id) - proton.finalize() + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_deactivate.hatchet" + session_id = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + with temp_file.open() as f: data = json.load(f) - # Root shouldn't have device id - assert "device_id" not in data[0]["metrics"] - assert len(data[0]["children"]) == 1 - assert "device_id" in data[0]["children"][0]["metrics"] + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"] + + +def test_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_multiple_sessions0.hatchet" + temp_file1 = tmp_path / "test_multiple_sessions1.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + torch.randn((10, 10), device="cuda") + torch.randn((10, 10), device="cuda") + proton.deactivate(session_id0) + proton.finalize(session_id0) + torch.randn((10, 10), device="cuda") + proton.finalize(session_id1) + # kernel has been invokved twice in session 0 and three times in session 1 + with temp_file0.open() as f: + data = json.load(f) + assert int(data[0]["children"][0]["metrics"]["count"]) == 2 + with temp_file1.open() as f: + data = json.load(f) + assert int(data[0]["children"][0]["metrics"]["count"]) == 3 diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py index b2d4d39f9b37..13ea1d39b4f2 100644 --- a/third_party/proton/test/test_viewer.py +++ b/third_party/proton/test/test_viewer.py @@ -1,12 +1,15 @@ import pytest import subprocess from triton.profiler.viewer import get_min_time_flops, get_min_time_bytes, get_raw_metrics, format_frames, derive_metrics, filter_frames +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME import numpy as np file_path = __file__ -cuda_example_file = file_path.replace("test_viewer.py", "example_cuda.json") -hip_example_file = file_path.replace("test_viewer.py", "example_hip.json") -frame_example_file = file_path.replace("test_viewer.py", "example_frame.json") +triton_example_file = file_path.replace("test_viewer.py", "examples/triton.json") +cuda_example_file = file_path.replace("test_viewer.py", "examples/cuda.json") +hip_example_file = file_path.replace("test_viewer.py", "examples/hip.json") +frame_example_file = file_path.replace("test_viewer.py", "examples/frame.json") +leaf_example_file = file_path.replace("test_viewer.py", "examples/leaf_nodes.json") def test_help(): @@ -15,6 +18,21 @@ def test_help(): assert ret == 0 +def test_sort(): + with open(leaf_example_file, "r") as f: + gf, raw_metrics, device_info = get_raw_metrics(f) + gf = format_frames(gf, None) + gf.update_inclusive_columns() + metrics = ["time/s", "time/ms", "time/us", "time/ns"] + metrics = derive_metrics(gf, metrics, raw_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:5]['name'].values + expected = ['ROOT', 'kernel_1_1_1', 'kernel_3_1_1', 'kernel_3_2_2', 'kernel_1_2_2'] + assert len(actual) == len(expected) + assert all(a == b for a, b in zip(actual, expected)) + + @pytest.mark.parametrize("option", ["full", "file_function_line", "function_line", "file_function"]) def test_format_frames(option): with open(frame_example_file, "r") as f: @@ -48,6 +66,15 @@ def test_filter_frames(option): assert idx.sum() == 1 +def test_filter_metadata(): + with open(triton_example_file, "r") as f: + gf, _, _ = get_raw_metrics(f) + assert COMPUTE_METADATA_SCOPE_NAME not in gf.dataframe["name"].tolist() + assert "cuda_kernel" not in gf.dataframe["name"].tolist() + assert "scope" in gf.dataframe["name"].tolist() + assert "triton_kernel" in gf.dataframe["name"].tolist() + + def test_min_time_flops(): with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) diff --git a/third_party/sleef b/third_party/sleef new file mode 160000 index 000000000000..93f04d869471 --- /dev/null +++ b/third_party/sleef @@ -0,0 +1 @@ +Subproject commit 93f04d869471ce4d007abaebb8c6a7bc62749f61 diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index c27c63335e2b..7a34955a9638 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -43,7 +43,7 @@ std::string strReplace(std::string s, const std::string &from, // We use some abbreviations when spelling out MLIR types. std::string expandTyStr(std::string s) { s = strReplace(s, "T<", "tensor<"); - s = strReplace(s, "#B", "#triton_gpu.blocked"); + s = strReplace(s, "#B", "#ttg.blocked"); s = strReplace(s, "spt", "sizePerThread"); s = strReplace(s, "tpw", "threadsPerWarp"); s = strReplace(s, "wpc", "warpsPerCTA"); @@ -620,7 +620,135 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); } -} // anonymous namespace +class LinearEncodingTest : public ::testing::Test { +public: + LinearEncodingTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { + // Define a tensor shape + auto rank = 2; + SmallVector> shapes = {{64, 128}, {256, 1024}}; + SmallVector> orders = {{0, 1}, {1, 0}}; + SmallVector ctaLayouts = { + triton::gpu::CTALayoutAttr::getDefault(&ctx, rank), + triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}), + }; + SmallVector distributedEncodings; + + // Create BlockedEncodingAttr and SliceEncodingAttr + { + SmallVector sizePerThread = {4, 4}; + SmallVector threadsPerWarp = {4, 8}; + SmallVector warpsPerCTA = {2, 2}; + + for (auto ctaLayout : ctaLayouts) { + for (const auto &order : orders) { + auto blockedEncoding = triton::gpu::BlockedEncodingAttr::get( + &ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + distributedEncodings.push_back(blockedEncoding); + distributedEncodings.push_back( + triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); + } + } + } + + // Create an MMAv2 and DotOperandEncodingAttr (MMAv3 doesn't support linear + // layouts yet) + { + unsigned versionMajor = 2; + unsigned versionMinor = 0; + SmallVector warpsPerCTA{4, 2}; + SmallVector instrShape{16, 8}; // Instruction shape (M, N) + auto mma = triton::gpu::NvidiaMmaEncodingAttr::get( + &ctx, versionMajor, versionMinor, warpsPerCTA, ctaLayouts[0], + instrShape); + distributedEncodings.push_back(mma); + // Create an opIdx=0 and opIdx=1 encoding + for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { + distributedEncodings.push_back( + triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, mma, 2)); + } + } + + for (const auto &distributedEncoding : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = + dyn_cast(distributedEncoding)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + + // Create LinearEncodingAttr from the LinearLayout + auto linearLayout = *distributedEncoding.toLinearLayout(shape); + auto linearEncoding = + triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout); + + // Test that the canonical form of the LinearLayout is indeed canonical + // by expanding it to the original shape + auto expandedLL = linearEncoding.toLinearLayout(shape); + ASSERT_EQ(linearLayout, expandedLL); + + // Test that methods of DistributedEncoding return the same values + Type eltTy = FloatType::getF32(&ctx); + + ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); + ASSERT_EQ(cast(distributedEncoding) + .getTotalElemsPerThread(shape, eltTy), + linearEncoding.getTotalElemsPerThread(shape, eltTy)); + ASSERT_EQ(cast(distributedEncoding) + .getElemsPerThread(shape, eltTy), + linearEncoding.getElemsPerThread(shape, eltTy)); + ASSERT_EQ(distributedEncoding.getRepOrder(), + linearEncoding.getRepOrder()); + ASSERT_EQ(distributedEncoding.getContigPerThread(), + linearEncoding.getContigPerThread()); + // DotOperandEncodingAttr::getWarpOrder() is not defined + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getWarpOrder(), + linearEncoding.getWarpOrder()); + } + ASSERT_EQ(distributedEncoding.getThreadOrder(), + linearEncoding.getThreadOrder()); + // For slice these do not equal the total number of lines / warps + // See [Note. Divergence of methods wrt. legacy layouts] + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getWarpsPerCTA(), + linearEncoding.getWarpsPerCTA()); + ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), + linearEncoding.getThreadsPerWarp()); + } + // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes + // the second repetition along K as the second tile. + if (!isa(distributedEncoding)) { + // FIXME: This happens to be correct for SliceLayout because of the hack + // in SliceEncodingAttr::toLinearLayout(). We should remove the hack + // and the skips in the getWarpsPerCTA() and getThreadsPerWarp() + ASSERT_EQ(distributedEncoding.getSizePerThread(), + linearEncoding.getSizePerThread()); + } + + // block level + // SliceEncoding is not well-defined for CGAs + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getCTASplitNum(), + linearEncoding.getCTASplitNum()); + ASSERT_EQ(distributedEncoding.getCTAsPerCGA(), + linearEncoding.getCTAsPerCGA()); + // If we are not using CGAs, the order is meaningless + auto useCGA = distributedEncoding.getCTAsPerCGA() != + SmallVector(rank, 1); + if (useCGA) { + ASSERT_EQ(distributedEncoding.getCTAOrder(), + linearEncoding.getCTAOrder()); + } + } + } + } +} +} // namespace } // namespace mlir::triton::gpu int main(int argc, char *argv[]) { diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index fd65233e5c6b..af6242b59662 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,10 +41,16 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, - ArrayRef order) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); - return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); + NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin, + ArrayRef instrShape, + ArrayRef numWarps) { + auto ctaLayout = CTALayoutAttr::getDefault(&ctx, numWarps.size()); + return NvidiaMmaEncodingAttr::get(&ctx, versionMaj, versionMin, numWarps, + std::move(ctaLayout), instrShape); + } + + DotOperandEncodingAttr dot(Attribute parent, int idx, int kWidth) { + return DotOperandEncodingAttr::get(&ctx, idx, parent, /*kWidth=*/kWidth); } AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, @@ -301,6 +307,19 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { + EXPECT_EQ(toLinearLayout({16, 16}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -378,8 +397,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) { } TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) { - SmallVector, 4> instrShapes = { - {16, 16, 8}, {16, 16, 8}, {16, 8, 8}}; + SmallVector, 2> instrShapes = {{16, 16, 8}, {16, 8, 8}}; for (auto instrShape : instrShapes) { SCOPED_TRACE(triton::join(instrShape, ",")); EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1}, @@ -502,7 +520,8 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + auto parent = mma(2, 0, {16, 8}, {1, 1}); + EXPECT_EQ(toLinearLayout({16, 64}, dot(parent, 0, 8)), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -511,7 +530,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 8}, dot(parent, 1, 8)), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -523,8 +542,9 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {4, 1}); EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + toLinearLayout({128, 128}, dot(parent, 0, 8)), LinearLayout( { {S("register"), @@ -534,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({128, 64}, dot(parent, 1, 8)), LinearLayout( { {S("register"), @@ -542,19 +562,19 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {2, 0}, {4, 0}, {32, 0}, + {64, 0}, {0, 8}, {0, 16}, - {0, 32}, - {64, 0}}}, + {0, 32}}}, {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, { S("warp"), - {}, + {{0, 0}, {0, 0}}, }, {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 1, 8)), LinearLayout( { {S("register"), @@ -569,13 +589,146 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, { S("warp"), - {}, + {{0, 0}, {0, 0}}, }, {S("block"), {}}, }, {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_3D) { + // We implement one that exercises all the paths + auto parent = mma(2, 0, {1, 16, 8}, {2, 4, 2}); + EXPECT_EQ(toLinearLayout({16, 128, 128}, dot(parent, 0, 8)), + LinearLayout( + { + {S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 8, 0}, + {0, 0, 32}, + {0, 0, 64}, + {0, 64, 0}, + {2, 0, 0}, + {4, 0, 0}, + {8, 0, 0}}}, + {S("lane"), + {{0, 0, 8}, {0, 0, 16}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({8, 128, 64}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 0, 16}, + {0, 0, 32}, + {2, 0, 0}, + {4, 0, 0}}}, + {S("lane"), + {{0, 8, 0}, {0, 16, 0}, {0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + { + S("warp"), + {{0, 0, 8}, {0, 0, 0}, {0, 0, 0}, {1, 0, 0}}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_warp4_kwidth2) { + auto parent = mma(3, 0, {16, 16, 8}, {4, 1}); + auto dotOp = dot(parent, 0, 2); + + EXPECT_EQ(toLinearLayout({64, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 32}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_mixed_warp_kwidth4) { + // Testing dot with MMAv3 encoding for opIdx = 0 and kWidth = 4 + auto parent = mma(3, 0, {16, 16, 8}, {4, 2}); + auto dotOp = dot(parent, 0, 4); + + EXPECT_EQ(toLinearLayout({128, 64}, dotOp), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {8, 0}, {0, 16}, {0, 32}, {64, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {2, 2}); + EXPECT_EQ( + toLinearLayout({32, 64}, dot(parent, 0, 8)), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 16}, dot(parent, 1, 8)), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 0, 8)), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({128, 32}, dot(parent, 1, 8)), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, /*isTransposed=*/false); diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 8be680562cae..d6b94e83f012 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -410,26 +410,6 @@ TEST_F(LinearLayoutTest, InvertAndCompose_NonInjective) { EXPECT_EQ(composition.compose(l2), l1); } -TEST_F(LinearLayoutTest, InvertAndCompose_SmallerResult) { - // The domain of l2 is [0,16), but the codomain of the result is only [0,8), - // because there's no value v in the codomain of l1 such that l2^-1(v) >= 8. - LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); - LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {8}}}}, {S("out")}); - // Pseudo-inverse of l2 is - // - // out(1) = 2 - // out(2) = 4 - // out(4) = 1 - // out(8) = 8 - // - // Composing with l1 gives back l2^-1 without the out(8) entry. - LinearLayout composition = l1.invertAndCompose(l2); - EXPECT_EQ(composition, - LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, - /*requireSurjective=*/false)); - EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); -} - TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); @@ -514,8 +494,10 @@ TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims) { LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); LinearLayout c = l1.invertAndCompose(l2); - EXPECT_EQ(c, LinearLayout::identity1D(8, S("in1"), S("in3")) * - LinearLayout::identity1D(2, S("in2"), S("in4"))); + EXPECT_EQ(c, LinearLayout( + {{S("in1"), {{1, 0}, {2, 0}, {4, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 8}, {S("in4"), 2}}, + /*requireSurjective=*/false)); EXPECT_EQ(c.compose(l2), l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); } @@ -525,8 +507,9 @@ TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims2) { LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); LinearLayout c = a.invertAndCompose(b); EXPECT_EQ(c, - LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 1}}}}, - {S("in3"), S("in4")})); + LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 4}, {S("in4"), 2}}, + /*requireSurjective=*/false)); EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); } @@ -561,104 +544,6 @@ TEST_F(LinearLayoutTest, NumConsecutiveInOut) { .getNumConsecutiveInOut()); } -TEST_F(LinearLayoutTest, DivideRight_Simple) { - EXPECT_EQ(LinearLayout::identity1D(8, S("in"), S("out")) - .divideRight(LinearLayout::identity1D(4, S("in"), S("out"))), - LinearLayout::identity1D(2, S("in"), S("out"))); - - EXPECT_EQ(LinearLayout::identity1D(8, S("in"), S("out")) - .divideRight(LinearLayout::identity1D(8, S("in"), S("out"))), - LinearLayout::empty()); -} - -TEST_F(LinearLayoutTest, DivideRight_2D) { - LinearLayout l1( - { - {S("in1"), {{1, 1}, {2, 2}, {0, 8}, {0, 4}}}, - {S("in2"), {{0, 2}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{2}, {1}}}}, {S("out2")}); - LinearLayout l3( - { - {S("in1"), {{1, 1}, {2, 2}}}, - {S("in2"), {{0, 2}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - ASSERT_EQ(l1.divideRight(l2), l3); - EXPECT_EQ(l1.divideRight(l2).value() * l2, l1); -} - -TEST_F(LinearLayoutTest, DivideRight_EliminateInDim) { - LinearLayout l1( - { - {S("in2"), {{0, 1}, {1, 0}}}, - {S("in1"), {{2, 0}, {0, 2}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{1, 0}, {0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l3({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")}); - ASSERT_EQ(l3 * l2, l1); - EXPECT_EQ(l1.divideRight(l2), l3); - - LinearLayout l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}}, - {S("out1"), S("out2")}); - LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")}); - LinearLayout l6({{S("in2"), {}}}, {S("out1"), S("out2")}); - ASSERT_EQ(l5 * l6, l4); - EXPECT_EQ(l4.divideRight(l6), l5); - - LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}}, - {S("out1"), S("out2")}); - LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}}, {}); - ASSERT_EQ(l9 * l8, l7); - EXPECT_EQ(l7.divideRight(l8), l9); -} - -TEST_F(LinearLayoutTest, DivideRight_EliminateOutDim) { - LinearLayout l1( - { - {S("in2"), {{1, 0}, {1, 0}}}, - {S("in1"), {{2, 0}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{1, 0}, {0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l3({{S("in2"), {{1}, {1}}}}, {S("out1")}); - ASSERT_EQ(l3 * l2, l1); - EXPECT_EQ(l1.divideRight(l2), l3); - - LinearLayout l4( - { - {S("in1"), {{0, 1}, {0, 2}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l5({{S("in1"), {{1}, {2}}}}, {S("out2")}); - using BasesArray = - ArrayRef>>>; - LinearLayout l6(BasesArray{}, {S("out1")}); - ASSERT_EQ(l6 * l5, l4); - EXPECT_EQ(l4.divideRight(l5), l6); -} - -TEST_F(LinearLayoutTest, DivideRight_Assertion) { - LinearLayout l1({{S("register"), - {{0, 1, 0, 0}, {0, 2, 0, 0}, {0, 0, 2, 0}, {1, 0, 0, 0}}}, - {S("lane"), - {{0, 4, 0, 0}, - {0, 8, 0, 0}, - {0, 16, 0, 0}, - {0, 0, 1, 0}, - {2, 0, 0, 0}}}, - {S("warp"), {{4, 0, 0, 0}, {8, 0, 0, 0}}}, - {S("block"), {}}}, - {S("register"), S("lane"), S("warp"), S("block")}); - LinearLayout l2 = LinearLayout::identity1D(32, S("lane"), S("lane")) * - LinearLayout::identity1D(4, S("warp"), S("warp")) * - LinearLayout::identity1D(1, S("block"), S("block")); - EXPECT_EQ(l1.divideRight(l2), std::nullopt); -} - TEST_F(LinearLayoutTest, EqualsChecksOutDimSizes) { EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) == LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, @@ -710,52 +595,33 @@ TEST_F(LinearLayoutTest, SublayoutIsZero) { EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out2")})); } -TEST_F(LinearLayoutTest, SublayoutIsIdentity) { - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({S("in")}, {S("out")})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({}, {S("out")})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({S("in")}, {})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({}, {})); +TEST_F(LinearLayoutTest, SquareSublayoutIsIdentity) { + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({S("in")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({})); LinearLayout l1( {{S("in1"), {{1, 1}, {2, 2}, {4, 4}}}, {S("in2"), {{2, 1}, {1, 2}}}}, - {{S("out1"), 8}, {S("out2"), 8}}, /*requireSurjective=*/false); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out1"), S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out2"), S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in2")}, {S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in2")}, {S("out2")})); - - LinearLayout l2 = - LinearLayout::identity1D(4, S("in1"), S("out1")) * - LinearLayout::identity1D(8, S("in2"), S("out2")) * - LinearLayout({{S("in3"), {{1, 1, 1}}}}, - {{S("out1"), 2}, {S("out2"), 2}, {S("out3"), 2}}, - /*requireSurjective=*/false); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in2")}, {S("out2")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in3")}, {S("out3")})); - EXPECT_FALSE( - l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")})); - EXPECT_FALSE(l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1"), S("in3")}, {S("out1")})); - - LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("out1")) * - LinearLayout::identity1D(8, S("in2"), S("out2")); - EXPECT_TRUE(l3.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l3.sublayoutIsIdentity({S("in2")}, {S("out2")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1")}, {S("out2")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in2")}, {S("out1")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")})); - EXPECT_TRUE( - l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")})); + {{S("in1"), 8}, {S("in2"), 8}}, /*requireSurjective=*/false); + EXPECT_TRUE(l1.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l1.squareSublayoutIsIdentity({S("in2")})); + + LinearLayout l2 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")) * + LinearLayout({{S("in3"), {{1, 1, 1}}}}, + {{S("in1"), 2}, {S("in2"), 2}, {S("in3"), 2}}, + /*requireSurjective=*/false); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l2.squareSublayoutIsIdentity({S("in3")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1"), S("in2")})); + + LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1"), S("in2")})); } TEST_F(LinearLayoutTest, FreeVariableMasks) { @@ -788,6 +654,81 @@ TEST_F(LinearLayoutTest, FreeVariableMasks) { AR({{S("in1"), 0b100}, {S("in2"), 0b10}})); } +TEST_F(LinearLayoutTest, QuotientOneDimension) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}}}, + {S("dim2"), {{0, 0}}}, + }, + {{S("dim1"), 2}, {S("dim2"), 1}}, /*requireSurjective=*/false); + + // Quotient over dim1, which is trivial + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_TRUE(quotientLayout.has_value()); + EXPECT_EQ(*quotientLayout, LinearLayout::zeros1D(2, S("dim2"), S("dim2"))); + // dim2 is zero, not the identity + ASSERT_FALSE(quotientLayout->quotient({S("dim2")}).has_value()); +} + +TEST_F(LinearLayoutTest, QuotientSeveralDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("dim2"), {{0, 1}, {0, 2}}}, + }, + {S("dim1"), S("dim2")}); + + auto quotientLayout = layout.quotient({S("dim1"), S("dim2")}); + EXPECT_TRUE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientMultipleTrivialDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 0, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // Quotient over dim2 is trivial, even if there's some funny business + // going on in the other dimensions + auto quotientLayout = layout.quotient({S("dim2")}); + ASSERT_TRUE(quotientLayout.has_value()); + + layout = LinearLayout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 1, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // As soon as one maps into the dimension being quotiented or out of it + // (in this case dim3 depends on dim2), we cannot quotient + quotientLayout = layout.quotient({S("dim2")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientEmptyLayout) { + LinearLayout layout = LinearLayout::empty(); + + // Quotienting over a dimension that doesn't exist is invalid + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { + // Test quotient on identity layout with multiple dimensions + LinearLayout layout = LinearLayout::identity1D(8, S("dim1"), S("dim1")) * + LinearLayout::identity1D(2, S("dim2"), S("dim2")) * + LinearLayout::identity1D(4, S("dim3"), S("dim3")); + + // We can quotient over all dimensions in any order + auto quotientLayout = layout.quotient({S("dim1"), S("dim3")}); + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); +} } // anonymous namespace } // namespace mlir::triton diff --git a/utils/generate-test-checks.py b/utils/generate-test-checks.py new file mode 100755 index 000000000000..3597d9150151 --- /dev/null +++ b/utils/generate-test-checks.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +=============================================================== +A script to generate FileCheck statements for mlir unit tests. +=============================================================== + +This script is a utility to add FileCheck patterns to an mlir file. + +NOTE: The input ``.mlir`` is expected to be the output from the parser, not a +stripped down variant. + +Example usage: + +.. code-block:: shell + + $ generate-test-checks.py foo.mlir + $ mlir-opt foo.mlir -transformation | generate-test-checks.py + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @' + +The script will heuristically generate CHECK/CHECK-LABEL commands for each line +within the file. By default this script will also try to insert string +substitution blocks for all SSA value names. If ``--source file`` is specified, the +script will attempt to insert the generated CHECKs to the source file by looking +for line positions matched by ``--source_delim_regex``. + +The script is designed to make adding checks to a test case fast, it is *not* +designed to be authoritative about what constitutes a good test! +""" + +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import os # Used to advertise this file's name ("autogenerated_note"). +import re +import sys +from typing import Optional + +ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " +ADVERT_END = """ +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. +""" + +# Regex command to match an SSA identifier. +SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" +SSA_RE = re.compile(SSA_RE_STR) + +# Regex matching the left-hand side of an assignment +SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*=' +SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR) + +# Regex matching attributes +ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)' +ATTR_RE = re.compile(ATTR_RE_STR) + +# Regex matching the left-hand side of an attribute definition +ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*=' +ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR) + + +# Class used to generate and manage string substitution blocks for SSA value +# names. +class VariableNamer: + + def __init__(self, variable_names): + self.scopes = [] + self.name_counter = 0 + + # Number of variable names to still generate in parent scope + self.generate_in_parent_scope_left = 0 + + # Parse variable names + self.variable_names = [name.upper() for name in variable_names.split(',')] + self.used_variable_names = set() + + # Generate the following 'n' variable names in the parent scope. + def generate_in_parent_scope(self, n): + self.generate_in_parent_scope_left = n + + # Generate a substitution name for the given ssa value name. + def generate_name(self, source_variable_name): + + # Compute variable name + variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else '' + if variable_name == '': + variable_name = "VAL_" + str(self.name_counter) + self.name_counter += 1 + + # Scope where variable name is saved + scope = len(self.scopes) - 1 + if self.generate_in_parent_scope_left > 0: + self.generate_in_parent_scope_left -= 1 + scope = len(self.scopes) - 2 + assert (scope >= 0) + + # Save variable + if variable_name in self.used_variable_names: + raise RuntimeError(variable_name + ': duplicate variable name') + self.scopes[scope][source_variable_name] = variable_name + self.used_variable_names.add(variable_name) + + return variable_name + + # Push a new variable name scope. + def push_name_scope(self): + self.scopes.append({}) + + # Pop the last variable name scope. + def pop_name_scope(self): + self.scopes.pop() + + # Return the level of nesting (number of pushed scopes). + def num_scopes(self): + return len(self.scopes) + + # Reset the counter and used variable names. + def clear_names(self): + self.name_counter = 0 + self.used_variable_names = set() + + +class AttributeNamer: + + def __init__(self, attribute_names): + self.name_counter = 0 + self.attribute_names = [name.upper() for name in attribute_names.split(',')] + self.map = {} + self.used_attribute_names = set() + + # Generate a substitution name for the given attribute name. + def generate_name(self, source_attribute_name): + + # Compute FileCheck name + attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else '' + if attribute_name == '': + attribute_name = "ATTR_" + str(self.name_counter) + self.name_counter += 1 + + # Prepend global symbol + attribute_name = '$' + attribute_name + + # Save attribute + if attribute_name in self.used_attribute_names: + raise RuntimeError(attribute_name + ': duplicate attribute name') + self.map[source_attribute_name] = attribute_name + self.used_attribute_names.add(attribute_name) + return attribute_name + + # Get the saved substitution name for the given attribute name, if it exists. + def get_name(self, source_attribute_name) -> Optional[str]: + return self.map.get(source_attribute_name) + + +# Return the number of SSA results in a line of type +# %0, %1, ... = ... +# The function returns 0 if there are no results. +def get_num_ssa_results(input_line): + m = SSA_RESULTS_RE.match(input_line) + return m.group().count('%') if m else 0 + + +# Process a line of input that has been split at each SSA identifier '%'. +def process_line(line_chunks, variable_namer): + output_line = "" + + # Process the rest that contained an SSA value name. + for chunk in line_chunks: + m = SSA_RE.match(chunk) + ssa_name = m.group(0) if m is not None else '' + + # Check if an existing variable exists for this name. + variable = None + for scope in variable_namer.scopes: + variable = scope.get(ssa_name) + if variable is not None: + break + + # If one exists, then output the existing name. + if variable is not None: + output_line += "%[[" + variable + "]]" + else: + # Otherwise, generate a new variable. + variable = variable_namer.generate_name(ssa_name) + output_line += "%[[" + variable + ":.*]]" + + # Append the non named group. + output_line += chunk[len(ssa_name):] + + return output_line.rstrip() + "\n" + + +# Process the source file lines. The source file doesn't have to be .mlir. +def process_source_lines(source_lines, note, args): + source_split_re = re.compile(args.source_delim_regex) + + source_segments = [[]] + for line in source_lines: + # Remove previous note. + if line == note: + continue + # Remove previous CHECK lines. + if line.find(args.check_prefix) != -1: + continue + # Segment the file based on --source_delim_regex. + if source_split_re.search(line): + source_segments.append([]) + + source_segments[-1].append(line + "\n") + return source_segments + + +def process_attribute_definition(line, attribute_namer, output): + m = ATTR_DEF_RE.match(line) + if m: + attribute_name = attribute_namer.generate_name(m.group(1)) + line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n' + output.append(line) + + +def process_attribute_references(line, attribute_namer): + + output_line = '' + components = ATTR_RE.split(line) + for component in components: + m = ATTR_RE.match(component) + name = attribute_namer.get_name(m.group(1)) if m else None + if name is None: + output_line += component + else: + output_line += '#[[' + name + ']]' + output_line += component[len(m.group()):] + return output_line + + +# Pre-process a line of input to remove any character sequences that will be +# problematic with FileCheck. +def preprocess_line(line): + # Replace any double brackets, '[[' with escaped replacements. '[[' + # corresponds to variable names in FileCheck. + output_line = line.replace("[[", "{{\\[\\[}}") + + # Replace any single brackets that are followed by an SSA identifier, the + # identifier will be replace by a variable; Creating the same situation as + # above. + output_line = output_line.replace("[%", "{{\\[}}%") + + return output_line + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("--check-prefix", default="CHECK", help="Prefix to use from check file.") + parser.add_argument("-o", "--output", nargs="?", type=argparse.FileType("w"), default=None) + parser.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + parser.add_argument( + "--source", + type=str, + help="Print each CHECK chunk before each delimeter line in the source" + "file, respectively. The delimeter lines are identified by " + "--source_delim_regex.", + ) + parser.add_argument("--source_delim_regex", type=str, default="func @") + parser.add_argument( + "--starts_from_scope", + type=int, + default=1, + help="Omit the top specified level of content. For example, by default " + 'it omits "module {"', + ) + parser.add_argument("-i", "--inplace", action="store_true", default=False) + parser.add_argument( + "--variable_names", type=str, default='', + help="Names to be used in FileCheck regular expression to represent SSA " + "variables in the order they are encountered. Separate names with commas, " + "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')") + parser.add_argument( + "--attribute_names", type=str, default='', help="Names to be used in FileCheck regular expression to represent " + "attributes in the order they are defined. Separate names with commas," + "commas, and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')") + + args = parser.parse_args() + + # Open the given input file. + input_lines = [l.rstrip() for l in args.input] + args.input.close() + + # Generate a note used for the generated check file. + script_name = os.path.basename(__file__) + autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END + + source_segments = None + if args.source: + source_segments = process_source_lines([l.rstrip() for l in open(args.source, "r")], autogenerated_note, args) + + if args.inplace: + assert args.output is None + output = open(args.source, "w") + elif args.output is None: + output = sys.stdout + else: + output = args.output + + output_segments = [[]] + + # Namers + variable_namer = VariableNamer(args.variable_names) + attribute_namer = AttributeNamer(args.attribute_names) + + # Process lines + for input_line in input_lines: + if not input_line: + continue + + # Check if this is an attribute definition and process it + process_attribute_definition(input_line, attribute_namer, output_segments[-1]) + + # Lines with blocks begin with a ^. These lines have a trailing comment + # that needs to be stripped. + lstripped_input_line = input_line.lstrip() + is_block = lstripped_input_line[0] == "^" + if is_block: + input_line = input_line.rsplit("//", 1)[0].rstrip() + + cur_level = variable_namer.num_scopes() + + # If the line starts with a '}', pop the last name scope. + if lstripped_input_line[0] == "}": + variable_namer.pop_name_scope() + cur_level = variable_namer.num_scopes() + + # If the line ends with a '{', push a new name scope. + if input_line[-1] == "{": + variable_namer.push_name_scope() + if cur_level == args.starts_from_scope: + output_segments.append([]) + + # Result SSA values must still be pushed to parent scope + num_ssa_results = get_num_ssa_results(input_line) + variable_namer.generate_in_parent_scope(num_ssa_results) + + # Omit lines at the near top level e.g. "module {". + if cur_level < args.starts_from_scope: + continue + + if len(output_segments[-1]) == 0: + variable_namer.clear_names() + + # Preprocess the input to remove any sequences that may be problematic with + # FileCheck. + input_line = preprocess_line(input_line) + + # Process uses of attributes in this line + input_line = process_attribute_references(input_line, attribute_namer) + + # Split the line at the each SSA value name. + ssa_split = input_line.split("%") + + # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'. + if len(output_segments[-1]) != 0 or not ssa_split[0]: + output_line = "// " + args.check_prefix + ": " + # Pad to align with the 'LABEL' statements. + output_line += " " * len("-LABEL") + + # Output the first line chunk that does not contain an SSA name. + output_line += ssa_split[0] + + # Process the rest of the input line. + output_line += process_line(ssa_split[1:], variable_namer) + + else: + # Output the first line chunk that does not contain an SSA name for the + # label. + output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n" + + # Process the rest of the input line on separate check lines. + output_line += "// " + args.check_prefix + "-SAME: " + output_line += process_line(ssa_split[1:], variable_namer) + + # Append the output line. + output_segments[-1].append(output_line) + + output.write(autogenerated_note + "\n") + + # Write the output. + if source_segments: + assert len(output_segments) == len(source_segments), (len(output_segments), len(source_segments)) + for check_segment, source_segment in zip(output_segments, source_segments): + for line in check_segment: + output.write(line) + for line in source_segment: + output.write(line) + else: + for segment in output_segments: + output.write("\n") + for output_line in segment: + output.write(output_line) + output.write("\n") + output.close() + + +if __name__ == "__main__": + main() diff --git a/zen5.patch b/zen5.patch new file mode 100644 index 000000000000..36c16fedc708 --- /dev/null +++ b/zen5.patch @@ -0,0 +1,31 @@ +diff --git a/python/src/llvm.cc b/python/src/llvm.cc +index 70e36449..a786a96a 100644 +--- a/python/src/llvm.cc ++++ b/python/src/llvm.cc +@@ -410,7 +410,7 @@ void init_triton_llvm(py::module &&m) { + auto target = + llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + std::unique_ptr machine{target->createTargetMachine( +- mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, ++ mod->getTargetTriple(), "znver4", "", {}, + llvm::Reloc::PIC_)}; + mod->setDataLayout(machine->createDataLayout()); + }); +@@ -437,7 +437,7 @@ void init_triton_llvm(py::module &&m) { + } + res = + translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(), +- llvm::sys::getHostCPUName().str(), "", {}, ++ "znver4", "", {}, + enable_fp_fusion, false, enable_fast_math); + } + return py::str(res); +@@ -553,7 +553,7 @@ void init_triton_llvm(py::module &&m) { + + m.def("get_cpu_tripple", []() { return llvm::sys::getProcessTriple(); }); + +- m.def("get_cpu_name", []() { return llvm::sys::getHostCPUName().str(); }); ++ m.def("get_cpu_name", []() { return "znver4"; }); + + m.def("get_cpu_features", []() { + auto features = llvm::sys::getHostCPUFeatures();