From d16800d135f0545f09a278f1316311db7db157c9 Mon Sep 17 00:00:00 2001 From: janEbert Date: Fri, 20 Feb 2026 20:24:01 +0100 Subject: [PATCH] Remove Triton cache manager Even the older LTS container uses a recent enough Triton version these days. Fix #3239. --- examples/mamba/run_text_gen_server_8b.sh | 1 - examples/mamba/train.sh | 1 - examples/multimodal/pretrain_mistral_clip.sh | 4 +- examples/multimodal/sft_mistral_clip.sh | 2 - megatron/core/ssm/triton_cache_manager.py | 81 -------------------- 5 files changed, 1 insertion(+), 88 deletions(-) delete mode 100644 megatron/core/ssm/triton_cache_manager.py diff --git a/examples/mamba/run_text_gen_server_8b.sh b/examples/mamba/run_text_gen_server_8b.sh index 8d3137f2442..b968e6b9bf0 100755 --- a/examples/mamba/run_text_gen_server_8b.sh +++ b/examples/mamba/run_text_gen_server_8b.sh @@ -18,7 +18,6 @@ export NCCL_IB_TIMEOUT=19 export NCCL_IB_QPS_PER_CONNECTION=4 export TRITON_CACHE_DIR="./triton-cache/" -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ --tensor-model-parallel-size 1 \ diff --git a/examples/mamba/train.sh b/examples/mamba/train.sh index 3952a997d47..725bd297226 100755 --- a/examples/mamba/train.sh +++ b/examples/mamba/train.sh @@ -42,7 +42,6 @@ mkdir -p ${DATACACHE_DIR} mkdir -p ${TENSORBOARD_DIR} export TRITON_CACHE_DIR="./triton-cache/" -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" SEQ_LEN=4096 TRAIN_SAMPLES=73242188 # 300B tokens / 4096 diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh index 4afcc0f2dad..01bfe1ddb3c 100755 --- a/examples/multimodal/pretrain_mistral_clip.sh +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -20,8 +20,6 @@ LOGS_DIR="${OUTPUT}/logs" TENSORBOARD_DIR="${OUTPUT}/tensorboard" export TRITON_CACHE_DIR="${WORKSPACE}/triton-cache/" -# The following patch to the Triton cache manager is needed for Triton version <= 3.1 -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" if [[ -z $LOAD_NAME ]]; then echo "Please set LOAD_NAME for input model name." @@ -129,4 +127,4 @@ OPTIONS=" \ export NVTE_APPLY_QK_LAYER_SCALING=0 export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} -torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} \ No newline at end of file +torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh index 6aac6821662..43bc4856eda 100755 --- a/examples/multimodal/sft_mistral_clip.sh +++ b/examples/multimodal/sft_mistral_clip.sh @@ -20,8 +20,6 @@ LOGS_DIR="${OUTPUT}/logs" TENSORBOARD_DIR="${OUTPUT}/tensorboard" export TRITON_CACHE_DIR="${WORKSPACE}/triton-cache/" -# The following patch to the Triton cache manager is needed for Triton version <= 3.1 -export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" if [[ -z $LOAD_NAME ]]; then echo "Please set LOAD_NAME for input model name." diff --git a/megatron/core/ssm/triton_cache_manager.py b/megatron/core/ssm/triton_cache_manager.py deleted file mode 100644 index 8c921dacbd8..00000000000 --- a/megatron/core/ssm/triton_cache_manager.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# Copyright 2018-2020 Philippe Tillet -# Copyright 2020-2022 OpenAI - -# Some of this code was adopted from https://github.com/triton-lang/triton -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import os -import uuid -from pathlib import Path - -try: - from triton import __version__ as triton_version - from triton.runtime.cache import FileCacheManager -except ImportError: - raise ImportError("triton is required by the Mamba model but cannot be imported") - - -def _version_no_greater_than(version, version_limit): - major, minor, _ = map(int, version.split('.')) - limit_major, limit_minor = map(int, version_limit.split('.')) - return major < limit_major or (major == limit_major and minor <= limit_minor) - - -def default_cache_dir(): - """Provides a default path for the Triton cache directory.""" - return os.path.join(Path.home(), ".triton", "cache") - - -class ParallelFileCacheManager(FileCacheManager): - """ - This patched version of ParallelFileCacheManager prevents errors related - to the builing of the Triton compiler cache when the number of model - parallel ranks is greater than one, including when certain types of file - system are used (such as Lustre). - - Usage: - export TRITON_CACHE_DIR= - export TRITON_CACHE_MANAGER=megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager - - This patch implements the changes in the following two Triton project pull - requests: - 1. https://github.com/triton-lang/triton/pull/3544 - 2. https://github.com/triton-lang/triton/pull/4295 - - The above changes will probably be included in Triton release version 3.2, - making this patch no longer necessary. - """ - - def put(self, data, filename, binary=True) -> str: - """A patched version of put, implementing PR 3544 and PR 4295.""" - patch_limit = '3.1' - assert _version_no_greater_than(triton_version, patch_limit), ( - "Assertion failed: ParallelFileCacheManager patch should not be " - f"used beyond Triton version {patch_limit}." - ) - if not self.cache_dir: - raise RuntimeError("Could not create or locate cache dir") - binary = isinstance(data, bytes) - if not binary: - data = str(data) - assert self.lock_path is not None - filepath = self._make_path(filename) - # Random ID to avoid any collisions - rnd_id = str(uuid.uuid4()) - # we use the PID in case a bunch of these around so we can see what PID made it - pid = os.getpid() - # use temp dir to be robust against program interruptions - temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") - os.makedirs(temp_dir, exist_ok=True) - temp_path = os.path.join(temp_dir, filename) - - mode = "wb" if binary else "w" - with open(temp_path, mode) as f: - f.write(data) - # Replace is guaranteed to be atomic on POSIX systems if it succeeds - # so filepath cannot see a partial write - os.replace(temp_path, filepath) - os.removedirs(temp_dir) - return filepath