During training of the detection task, the loss value becomes NaN. #4

@jpark0331

Thank you for your excellent work.
I have confirmed that training works with the InternViT-6B backbone and MMSegmentation.

However, I have run into a problem when training with the InternViT-6B backbone and MMDetection.
During training, every loss value becomes NaN, starting from the first logged iterations.

The training log looks like this:

/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

2024-04-17 08:37:31,145 - mmdet - INFO - Epoch [1][10/39089] lr: 3.751e-08, eta: 16 days, 22:58:00, time: 3.123, data_time: 0.366, memory: 24685, loss_rpn_cls: nan, loss_rpn_bbox: nan, loss_cls: nan, acc: 27.3672, loss_bbox: nan, loss: nan, grad_norm: nan
INFO:mmdet:Epoch [1][10/39089] lr: 3.751e-08, eta: 16 days, 22:58:00, time: 3.123, data_time: 0.366, memory: 24685, loss_rpn_cls: nan, loss_rpn_bbox: nan, loss_cls: nan, acc: 27.3672, loss_bbox: nan, loss: nan, grad_norm: nan

2024-04-17 08:37:55,822 - mmdet - INFO - Epoch [1][20/39089] lr: 7.918e-08, eta: 15 days, 4:14:40, time: 2.468, data_time: 0.030, memory: 24685, loss_rpn_cls: nan, loss_rpn_bbox: nan, loss_cls: nan, acc: 25.6029, loss_bbox: nan, loss: nan, grad_norm: nan
INFO:mmdet:Epoch [1][20/39089] lr: 7.918e-08, eta: 15 days, 4:14:40, time: 2.468, data_time: 0.030, memory: 24685, loss_rpn_cls: nan, loss_rpn_bbox: nan, loss_cls: nan, acc: 25.6029, loss_bbox: nan, loss: nan, grad_norm: nan
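
For reference, the `lr` values logged above can be reproduced from the schedule in the config (base `lr=1.25e-05`, `warmup='linear'`, `warmup_iters=3000`, `warmup_ratio=1e-06`). The sketch below assumes the MMCV 1.x linear warmup rule, `lr * (1 - k)` with `k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)`, a zero-based iteration counter, and a poly decay factor of 1 during the first epoch. `linear_warmup_lr` is a hypothetical helper written for illustration, not an MMCV function. Under these assumptions, the learning rate is still on the order of 1e-08 when the NaN values first appear.

```python
def linear_warmup_lr(cur_iter: int,
                     base_lr: float = 1.25e-05,
                     warmup_iters: int = 3000,
                     warmup_ratio: float = 1e-06) -> float:
    """Learning rate during linear warmup (assumed MMCV 1.x rule)."""
    # k shrinks from (1 - warmup_ratio) to 0 as cur_iter approaches warmup_iters.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return base_lr * (1 - k)


# Logged iterations 10 and 20 correspond to zero-based iterations 9 and 19
# (an assumption about how the counter is reported).
for logged_iter in (10, 20):
    print(logged_iter, f"{linear_warmup_lr(logged_iter - 1):.3e}")
# 10 3.751e-08
# 20 7.918e-08
```
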

Additionally, after tracing the code, I found that the feature values produced by the ViT backbone are computed correctly.
However, after the parameter update in the first iteration,
the weights of the up1, up2, up3, and up4 layers in the neck (FPN) become INF.
As a result, the loss values turn out to be NaN.
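
A check of this kind can be sketched as a scan over named parameter values after each optimizer step. The snippet below is a minimal, framework-free illustration: `find_nonfinite` is a hypothetical helper (not part of MMDetection), and the parameter names and values in the snapshot are made up to mirror the layers mentioned above. With PyTorch, the values could be taken from `model.named_parameters()`, for example via `p.detach().float().flatten().tolist()`.

```python
import math
from typing import Dict, Iterable, Tuple


def find_nonfinite(
    named_values: Iterable[Tuple[str, Iterable[float]]]
) -> Dict[str, Tuple[int, int]]:
    """Return {name: (num_inf, num_nan)} for entries with non-finite values."""
    report = {}
    for name, values in named_values:
        num_inf = num_nan = 0
        for v in values:
            if math.isnan(v):
                num_nan += 1
            elif math.isinf(v):
                num_inf += 1
        # Only report parameters that contain at least one inf or nan.
        if num_inf or num_nan:
            report[name] = (num_inf, num_nan)
    return report


# Hypothetical snapshot of a few weights after the first update.
snapshot = [
    ("up1.weight", [0.01, float("inf"), -0.2]),
    ("up2.weight", [float("nan"), 0.3]),
    ("fpn_convs.0.weight", [0.5, -0.1]),
]
report = find_nonfinite(snapshot)
for name, (num_inf, num_nan) in report.items():
    print(f"{name}: {num_inf} inf, {num_nan} nan")
```

Running the scan both before and after the optimizer step helps separate non-finite gradients from non-finite weights that were already present in the loaded checkpoint.
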

I followed the MMDetection guide on solving the "Loss goes Nan" issue, but the problem still occurs.
(https://mmdetection.readthedocs.io/en/v2.16.0/faq.html)

I would appreciate any suggestions. Thank you.

The environment and config I used are as follows:

2024-04-17 08:34:17,594 - mmdet - INFO - Environment info:

sys.platform: linux
Python: 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0]
CUDA available: True
GPU 0: NVIDIA A40
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.7, V11.7.99
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 1.13.1+cu117
PyTorch compiling details: PyTorch built with:

  • GCC 9.3
  • C++ Version: 201402
  • Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 11.7
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  • CuDNN 8.5
  • Magma 2.6.1
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.14.1+cu117
OpenCV: 4.9.0
MMCV: 1.7.0
MMCV Compiler: GCC 9.4
MMCV CUDA Compiler: 11.7
MMDetection: 2.25.3+7df6b87

2024-04-17 08:34:20,795 - mmdet - INFO - Distributed training: False
2024-04-17 08:34:23,783 - mmdet - INFO - Config:
model = dict(
type='FasterRCNN',
backbone=dict(
type='InternViT6B',
pretrain_size=224,
img_size=256,
patch_size=16,
embed_dim=3200,
depth=48,
num_heads=25,
mlp_ratio=4.0,
qkv_bias=False,
drop_path_rate=0.0,
init_values=0.1,
with_cp=True,
use_flash_attn=True,
qk_normalization=True,
layerscale_force_fp32=False,
with_fpn=True,
freeze_vit=True,
out_indices=[47],
window_attn=[
True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True
],
window_size=[
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16
],
output_dtype='float32',
pretrained='./pretrained/intern_vit_6b_224px.pth'),
neck=dict(
type='FPN',
in_channels=[3200, 3200, 3200, 3200],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)))
dataset_type = 'CocoDataset'
data_root = '/DATA_17/DATASET/coco2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=3,
workers_per_gpu=2,
train=dict(
type='CocoDataset',
ann_file=
'/DATA_17/DATASET/coco2017/annotations/instances_train2017.json',
img_prefix='/DATA_17/DATASET/coco2017/train2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]),
val=dict(
type='CocoDataset',
ann_file='/DATA_17/DATASET/coco2017/annotations/instances_val2017.json',
img_prefix='/DATA_17/DATASET/coco2017/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]),
test=dict(
type='CocoDataset',
ann_file='/DATA_17/DATASET/coco2017/annotations/instances_val2017.json',
img_prefix='/DATA_17/DATASET/coco2017/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]))
evaluation = dict(metric=['bbox'], interval=1, save_best='auto')
optimizer = dict(
type='AdamW',
lr=1.25e-05,
betas=(0.9, 0.999),
weight_decay=0.05,
constructor='CustomLayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=48, layer_decay_rate=1.0))
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
policy='poly',
warmup='linear',
warmup_iters=3000,
warmup_ratio=1e-06,
power=1.0,
min_lr=0.0)
runner = dict(type='EpochBasedRunner', max_epochs=12)
checkpoint_config = dict(interval=1, max_keep_ckpts=2)
log_config = dict(interval=10, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='ToFloat16Hook', priority=49)]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
deepspeed = False
deepspeed_config = 'zero_configs/adam_zero1_fp16.json'
pretrained = './pretrained/intern_vit_6b_224px.pth'
work_dir = './work/'
auto_resume = False
gpu_ids = [0]
