From 3d63815545a10860f7b6abc535d0da2afda39135 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Fri, 3 Feb 2023 17:27:57 +0000 Subject: [PATCH] Avoid 'UNetMidBlock2DCrossAttn' problem --- tuneavideo/models/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuneavideo/models/unet.py b/tuneavideo/models/unet.py index 6155eeb..41401e4 100644 --- a/tuneavideo/models/unet.py +++ b/tuneavideo/models/unet.py @@ -139,7 +139,7 @@ def __init__( self.down_blocks.append(down_block) # mid - if mid_block_type == "UNetMidBlock3DCrossAttn": + if mid_block_type in ["UNetMidBlock3DCrossAttn","UNetMidBlock2DCrossAttn"]: self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim,