diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index ce65bc4305..6d40f44b08 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR" pip install pytest==8.2.1 || error_exit "Failed to install pytest" -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eed28367..454536358c 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,32 +24,32 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..12439422c4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,16 +22,16 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" # debug tests @@ -42,9 +42,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e129..0e84a5ca5e 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,4 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index 8c3fdc8cdb..a5fd33cda9 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -9,7 +9,7 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. Lightning tests are skipped on systems diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..b29d1289f8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2751,7 +2751,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini new file mode 100644 index 0000000000..e90989721b --- /dev/null +++ b/tests/pytorch/pytest.ini @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[pytest] +filterwarnings= + error::RuntimeWarning + diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..9a1942f30d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -5,6 +5,7 @@ import math import os from typing import Dict, List, Tuple, Optional +import warnings import pytest import random @@ -1296,14 +1297,15 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ ).eval() # Share params - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if bias: - te_linear_ref.bias = Parameter(te_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if bias: + te_linear_ref.bias = Parameter(te_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1359,12 +1361,13 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): ).eval() # Share params - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe) te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe) @@ -1601,17 +1604,18 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ).eval() # Share params - with torch.no_grad(): - ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) - if normalization != "RMSNorm": - ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) - ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) - if bias: - ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(ln_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - ln_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) + ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) + if bias: + ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(ln_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + ln_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1739,19 +1743,24 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ).eval() # Share params - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) - if fuse_wgrad_accumulation: - ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32) - ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() - ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32) - ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + if fuse_wgrad_accumulation: + ln_mlp.fc1_weight.main_grad = torch.rand_like( + ln_mlp.fc1_weight, dtype=torch.float32 + ) + ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() + ln_mlp.fc2_weight.main_grad = torch.rand_like( + ln_mlp.fc2_weight, dtype=torch.float32 + ) + ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1796,14 +1805,15 @@ def test_layernorm_mlp_accuracy_checkpoint( ).eval() # Share params - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False) te_outputs_ref = _test_granular_accuracy( @@ -1952,9 +1962,13 @@ def test_grouped_linear_accuracy( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].module_setattr( + "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].module_setattr( + "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2096,9 +2110,13 @@ def test_grouped_linear_accuracy_save_original_input( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].module_setattr( + "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].module_setattr( + "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2298,8 +2316,7 @@ def test_padding_grouped_linear_accuracy( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - setattr( - ref_grouped_linear, + ref_grouped_linear.module_setattr( f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) @@ -2375,8 +2392,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - setattr( - ref_grouped_linear, + ref_grouped_linear.module_setattr( f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 370d9723cf..11c8af92f0 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -450,9 +450,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -560,9 +560,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6e5a12a103..ad7dac108a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,6 +482,8 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + self._initialized = True + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): @@ -676,9 +678,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # assume attention uses the same fp8_group as GEMMs fp8_group = FP8GlobalStateManager.get_fp8_group() - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled()) + self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration()) fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration if self.fp8_parameters or fp8_enabled: @@ -703,7 +705,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return if self.fp8_parameters and not self.fp8_initialized: @@ -721,7 +723,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(fp8_recipes) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) self.fp8_meta["recipe"] = fp8_recipe_dpa if fp8_recipe != fp8_recipe_dpa: @@ -1000,7 +1002,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, @@ -1145,10 +1147,11 @@ def forward( if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type + if self.attention_type != "cross": + self.fast_setattr("attention_type", "cross") + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0a..143252640b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -335,6 +335,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + TransformerEngineBaseModule._validate_name(self) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -739,9 +740,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5497ee7967..b38725e8be 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -729,8 +729,8 @@ def checkpoint( if isinstance(function, TransformerEngineBaseModule): # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + function.fast_setattr("fsdp_wrapped", False) + function.fast_setattr("fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing @@ -2046,7 +2046,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ) root_state = _get_module_fsdp_state(fsdp_root) assert root_state is not None, "Root module does not have a valid _FSDPState." - setattr(fsdp_root.module, "fsdp_group", root_state.process_group) + fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root) @@ -2057,7 +2057,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " "Please initialize your model without the te.quantized_model_init(...) context." ) - setattr(fsdp_module.module, "fsdp_group", state.process_group) + fsdp_module.module.fast_setattr("fsdp_group", state.process_group) class FullyShardedDataParallel(FSDP): diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f587ca9946..322ad52723 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -935,7 +935,8 @@ def new_fwd(*user_args, **user_kwargs): forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules) if _order is None: - func.forward = forward + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + func.forward = forward ret.append(func) else: ret.append(forward) @@ -943,8 +944,9 @@ def new_fwd(*user_args, **user_kwargs): ret.append(graphed) backward_dw_func, reset_func = make_graphed_attribute_functions(i) - setattr(ret[-1], "backward_dw", backward_dw_func) - setattr(ret[-1], "reset", reset_func) + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + setattr(ret[-1], "backward_dw", backward_dw_func) + setattr(ret[-1], "reset", reset_func) if just_one_callable: return ret[0] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..eb83722512 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,7 +10,8 @@ import warnings from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing_extensions import Self from contextlib import contextmanager import logging from types import MethodType @@ -50,6 +51,8 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, + nvtx_range_push, + nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -602,10 +605,10 @@ def fill_userbuffers_buffer_for_all_gather( class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.name = None + self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False @@ -630,26 +633,33 @@ def __init__(self) -> None: if not TEDebugState.debug_enabled: TEDebugState.initialize() + self._validate_name() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } + def fast_setattr(self, name: str, value: Any) -> None: + """ + Fast version of the Module's set attribute function. + Should be used for regular attributes, but not properties nor parameters/buffers. + """ + self.__dict__[name] = value + + def module_setattr(self, name: str, value: Any) -> None: + """ + Regular version of the Module's set attribute function. + Should be used only when the fast version cannot be used - for the properties, + parameters and buffers. + """ + super().__setattr__(name, value) def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.__dict__[name] = value - else: - # Default case - super().__setattr__(name, value) + if "_initialized" in self.__dict__ and self._initialized: + warnings.warn( + """The default implementation of torch.nn.Module introduces significant CPU overhead + when setting attributes and is therefore not recommended. Please use the explicit + calls (fast_setattr for setting regular values and module_setattr for setting + parameters, children modules and buffers).""", + RuntimeWarning, + ) + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -770,7 +780,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) - self.fp8_meta_tensors_initialized = True + self.fast_setattr("fp8_meta_tensors_initialized", True) def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" @@ -927,7 +937,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -942,7 +952,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.activation_dtype = dtype + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -971,48 +981,51 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) - - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + meta = self.fp8_meta + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled + + _original_recipe = None + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - self.fp8_initialized = True + self.init_fp8_meta_tensors(meta["recipe"]) + self.fast_setattr("fp8_initialized", True) - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1025,22 +1038,18 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepares for FWD execution.""" + self.fast_setattr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_setattr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1071,13 +1080,38 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + nvtx_range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp + def end_forward(self): + """ + Required to be called at the end of the forward function to properly handle + DelayedScaling metadata handling and the NVTX ranges. + """ + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + nvtx_range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + """Checks and prepares for FWD execution.""" + yield self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + self.end_forward() + + def train(self, mode: bool = True) -> Self: + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + return super().train(mode) def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled @@ -1312,9 +1346,9 @@ def clear(self): # Update the parameter based on its type if not is_dtensor: - setattr(self, name, param) + self.module_setattr(name, param) else: - setattr(self, name, dtensor_param) + self.module_setattr(name, dtensor_param) @abstractmethod def forward(self): @@ -1513,7 +1547,6 @@ def is_debug_iter(self) -> bool: debug = TEDebugState.debug_enabled if not debug: return False - self._validate_name() # If layer is run first time in new iteration, # we need to check if the debug should be enabled for this layer - @@ -1527,14 +1560,14 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_this_iteration = debug + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + self.fast_setattr("debug_enabled_in_this_iteration", debug) else: # If this is the same iteration as previous invocation of the module, # we use the debug value from the first invocation in the iteration. debug = self.debug_enabled_in_this_iteration - self.debug_last_iteration = TEDebugState.get_iteration() + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) if self.wgrad_store is not None: if debug and self.wgrad_store.delay_wgrad_compute(): @@ -1550,7 +1583,9 @@ def no_debug_features_active(self, quantizers): # Sometimes features inform that they will not be enabled for particular layer # for multiple next iterations. - self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + self.fast_setattr( + "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers) + ) if not run_current: return True @@ -1562,22 +1597,13 @@ def no_debug_features_active(self, quantizers): def _validate_name(self): """ Validate name passed to the module. - This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. - If no name is assigned, it creates a default name with layer count as the variable. + It creates a default name with layer count as the variable + which may be changed by the user of the module. """ if self.name is not None: return - assert TEDebugState.debug_enabled - import nvdlfw_inspect.api as debug_api - - if self.name is None: - debug_api.log_message( - "Names are not provided to debug modules. ", - "Creating and using generic names. Pass names to debug modules for better" - " insight. ", - level=logging.WARNING, - ) - self.name = f"Layer_{TEDebugState.get_layer_count()}" + + self.name = f"Layer_{TEDebugState.get_layer_count()}" def _check_weight_tensor_recipe_correspondence(self) -> None: """ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1e6f0b00ab..ec1da1e02e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -614,7 +614,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms @@ -633,7 +633,6 @@ def __init__( ), "GroupedLinear doesn't support Userbuffer overlap." self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) @@ -716,6 +715,8 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True + self._initialized = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -793,60 +794,62 @@ def forward( is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: - weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + weight_tensors = self._get_weight_tensors() + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] - quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() - if debug: - if self.no_debug_features_active(list(chain(*quantizers))): - debug = False - quantizers = self._get_quantizers() + if debug: + if self.no_debug_features_active(list(chain(*quantizers))): + debug = False + quantizers = self._get_quantizers() - if isinstance(weight_tensors, QuantizedTensorStorage): - raise RuntimeError("FP8 weights are not supported in debug mode.") + if isinstance(weight_tensors, QuantizedTensorStorage): + raise RuntimeError("FP8 weights are not supported in debug mode.") - ( - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - ) = quantizers + ( + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + ) = quantizers - if is_grad_enabled: - linear_fn = _GroupedLinear.apply - autograd_ctx = [] - else: - linear_fn = _GroupedLinear.forward - autograd_ctx = [None] - - non_tensor_args = ( - m_splits, - self.apply_bias, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.sequence_parallel, - self.activation_dtype, - is_grad_enabled, - self, - None, # skip_fp8_weight_update - self.save_original_input, - debug, - ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + if is_grad_enabled: + linear_fn = _GroupedLinear.apply + autograd_ctx = [] + else: + linear_fn = _GroupedLinear.forward + autograd_ctx = [None] + + non_tensor_args = ( + m_splits, + self.apply_bias, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.sequence_parallel, + self.activation_dtype, + is_grad_enabled, + self, + None, # skip_fp8_weight_update + self.save_original_input, + debug, + ) + out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + + self.end_forward() if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 13b94f2327..2b821e38a5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1161,9 +1161,9 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - name: str = None, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1182,7 +1182,6 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - self.name = name if tp_group is None: self.tp_size = tp_size @@ -1405,6 +1404,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self._initialized = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1514,87 +1515,89 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) - # Get concatenated weight and bias tensors - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + # Get concatenated weight and bias tensors + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - fwd_fn = _LayerNormLinear.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormLinear.forward - autograd_ctx = [None] - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_name, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - weight_tensor, - bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + fwd_fn = _LayerNormLinear.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormLinear.forward + autograd_ctx = [None] + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_name, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + weight_tensor, + bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4256028c8b..0311092449 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1789,7 +1789,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, - name: str = None, + name: Optional[str] = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1798,7 +1798,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1829,7 +1829,6 @@ def __init__( for use_fp8 in [False, True] ) ) - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1960,6 +1959,8 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + self._initialized = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -2052,115 +2053,117 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + inp = self.prepare_forward(inp, num_gemms=2) - quantizers = ( - self._get_quantizers(fp8_output, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, is_grad_enabled) - # Get quantizers - ( - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - ) = quantizers - - # Get weight tensors - fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = self.fc2_bias if self.use_bias else None - if not self.fp8: - if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.dequantize() - if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.dequantize() - - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): - self.bias_gelu_nvfusion = False + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers - if is_grad_enabled: - fwd_fn = _LayerNormMLP.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormMLP.forward - autograd_ctx = [None] - - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8 and not debug, - self.set_parallel_mode, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.activation, - self.activation_params, - self.normalization, - self.ub_overlap_ag, - self.ub_overlap_rs, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.gemm_gelu_fusion and not debug, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.checkpoint, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - fc1_weight, - fc1_bias, - fc2_weight, - fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + # Get weight tensors + fc1_weight, fc2_weight = self._get_weight_tensors() + fc1_bias = self.fc1_bias if self.use_bias else None + fc2_bias = self.fc2_bias if self.use_bias else None + if not self.fp8: + if isinstance(fc1_weight, Float8Tensor): + fc1_weight = fc1_weight.dequantize() + if isinstance(fc2_weight, Float8Tensor): + fc2_weight = fc2_weight.dequantize() + + # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): + self.fast_setattr("bias_gelu_nvfusion", False) + + if is_grad_enabled: + fwd_fn = _LayerNormMLP.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormMLP.forward + autograd_ctx = [None] + + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + self.bias_gelu_nvfusion and not self.fp8 and not debug, + self.set_parallel_mode, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.activation, + self.activation_params, + self.normalization, + self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.gemm_gelu_fusion and not debug, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.checkpoint, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + fc1_weight, + fc1_bias, + fc2_weight, + fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b8349f84a0..68c9758bbc 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,8 +428,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -1098,7 +1098,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1111,7 +1111,6 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1309,6 +1308,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self._initialized = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1398,81 +1399,79 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] - - non_tensor_args = ( - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - ) - out = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] + + non_tensor_args = ( + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, + fp8_output, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, + debug, + ) + out = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, + non_tensor_args, + ) + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 9b9ccc5185..1f856d386f 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -398,6 +398,7 @@ def __init__( self.softmax_type = softmax_type self.name = name + TransformerEngineBaseModule._validate_name(self) attention_args = ( hidden_size, @@ -446,7 +447,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".self_attention" if name is not None else None, + name=self.name + ".self_attention" if self.name is not None else None, ) if layer_type == "decoder": @@ -463,7 +464,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".inter_attention" if name is not None else None, + name=self.name + ".inter_attention" if self.name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -499,7 +500,7 @@ def __init__( activation_params=activation_params, normalization=normalization, device=device, - name=name + ".layernorm_mlp" if name is not None else None, + name=self.name + ".layernorm_mlp" if self.name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -768,9 +769,6 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())