diff --git a/tuneavideo/models/unet.py b/tuneavideo/models/unet.py index 6155eeb..41401e4 100644 --- a/tuneavideo/models/unet.py +++ b/tuneavideo/models/unet.py @@ -139,7 +139,7 @@ def __init__( self.down_blocks.append(down_block) # mid - if mid_block_type == "UNetMidBlock3DCrossAttn": + if mid_block_type in ["UNetMidBlock3DCrossAttn","UNetMidBlock2DCrossAttn"]: self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim,