From fa26805878c0dd855271599a4617b886926af361 Mon Sep 17 00:00:00 2001
From: fyu531 <2265141525@qq.com>
Date: Sat, 14 Mar 2026 10:00:26 +0800
Subject: [PATCH] fix: resolve CUDA memory assert error

Decode the latent one frame at a time (instead of 3 + 2-frame chunks)
and cast the slice to bfloat16 before handing it to the VAE, so peak
GPU memory during first-stage decoding stays bounded.

NOTE(review): the original revision of this patch also deleted the
`recons.append(recon)` line without a replacement, which left `recons`
empty and made `torch.cat(recons, dim=2)` raise on an empty list.
That line is now kept as unchanged context; the hunk header and
diffstat are updated accordingly.
---
 sat/sample_video.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/sat/sample_video.py b/sat/sample_video.py
index 6e60723..c31c966 100644
--- a/sat/sample_video.py
+++ b/sat/sample_video.py
@@ -351,21 +351,16 @@ def sampling_main(args, model_cls):
 
             # Decode latent serial to save GPU memory
             recons = []
-            loop_num = (T - 1) // 2
+            loop_num = T
             for i in range(loop_num):
-                if i == 0:
-                    start_frame, end_frame = 0, 3
-                else:
-                    start_frame, end_frame = i * 2 + 1, i * 2 + 3
-                if i == loop_num - 1:
-                    clear_fake_cp_cache = True
-                else:
-                    clear_fake_cp_cache = False
+                start_frame, end_frame = i, i+1
+                clear_fake_cp_cache = (i == loop_num - 1)
                 with torch.no_grad():
                     recon = first_stage_model.decode(
-                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
+                        latent[:, :, start_frame:end_frame].contiguous().to(torch.bfloat16),
+                        clear_fake_cp_cache=clear_fake_cp_cache
                     )
 
                 recons.append(recon)
 
             recon = torch.cat(recons, dim=2).to(torch.float32)
-- 
2.34.1