Hello,
I'm trying to understand how tensors flow through the model. During warmup, the tensor fed to the unet has shape (1,4,8,64,64), even though the docstring of unet.forward() specifies a required shape of (batch, channel, 1, height, width), which would be (8,4,1,64,64) here. I suspect the docstring is wrong: further downstream, the tensor is passed to an InflatedCon3D->forward() method that processes it correctly, and that method expects shape (b,c,f,h,w), where f is the number of frames (which is indeed 8).
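
To make the shape question concrete, here is a small shape-only sketch in plain Python. It assumes (I have not confirmed this in the code) that the inflated convolution folds the frame axis into the batch axis before applying a 2D convolution and unfolds it afterwards. The helper names `fold_frames` and `unfold_frames` are hypothetical and only track shapes; they are not functions from the repository, and the sketch ignores any change in channels or spatial size made by the convolution itself.

```python
from math import prod


def fold_frames(shape):
    """Hypothetical helper: (b, c, f, h, w) -> (b*f, c, h, w).

    Assumption: frames are folded into the batch axis so a 2D conv can run.
    """
    b, c, f, h, w = shape
    return (b * f, c, h, w)


def unfold_frames(shape, f):
    """Hypothetical helper: (b*f, c, h, w) -> (b, c, f, h, w), inverse of fold_frames."""
    bf, c, h, w = shape
    if bf % f != 0:
        raise ValueError(f"leading dim {bf} is not divisible by f={f}")
    return (bf // f, c, f, h, w)


observed = (1, 4, 8, 64, 64)    # shape seen during warmup: (b, c, f, h, w)
documented = (8, 4, 1, 64, 64)  # shape implied by the unet.forward() docstring

folded = fold_frames(observed)
print(folded)                        # (8, 4, 64, 64)
print(unfold_frames(folded, f=8))    # (1, 4, 8, 64, 64)

# Both layouts hold the same number of elements and fold to the same 4D shape,
# but they disagree on which axis carries the 8 frames.
print(fold_frames(documented))             # (8, 4, 64, 64)
print(prod(observed) == prod(documented))  # True
```

If that assumption holds, both layouts would reach the 2D convolution as (8,4,64,64), which could explain why the docstring mismatch does not cause an error even though only (b,c,f,h,w) with f=8 treats the 8 entries as frames of a single sample.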