From 658de5f1dde17d25db54fb07adf49370cc32d7c3 Mon Sep 17 00:00:00 2001 From: Marco Cipriano Date: Mon, 28 Jun 2021 17:27:25 -0700 Subject: [PATCH] removed hard-coded params --- models/TransBTS/PositionalEncoding.py | 2 +- models/TransBTS/TransBTS_downsample8x_skipconnection.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/TransBTS/PositionalEncoding.py b/models/TransBTS/PositionalEncoding.py index ae83d4d..ca4b483 100644 --- a/models/TransBTS/PositionalEncoding.py +++ b/models/TransBTS/PositionalEncoding.py @@ -25,7 +25,7 @@ class LearnedPositionalEncoding(nn.Module): def __init__(self, max_position_embeddings, embedding_dim, seq_length): super(LearnedPositionalEncoding, self).__init__() - self.position_embeddings = nn.Parameter(torch.zeros(1, 4096, 512)) #8x + self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, 512)) #8x def forward(self, x, position_ids=None): diff --git a/models/TransBTS/TransBTS_downsample8x_skipconnection.py b/models/TransBTS/TransBTS_downsample8x_skipconnection.py index 1a3389d..e7895e1 100644 --- a/models/TransBTS/TransBTS_downsample8x_skipconnection.py +++ b/models/TransBTS/TransBTS_downsample8x_skipconnection.py @@ -71,7 +71,7 @@ def __init__( padding=1 ) - self.Unet = Unet(in_channels=4, base_channels=16, num_classes=4) + self.Unet = Unet(in_channels=self.num_channels, base_channels=16, num_classes=4) self.bn = nn.BatchNorm3d(128) self.relu = nn.ReLU(inplace=True) @@ -198,7 +198,7 @@ def __init__( self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32) self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32) - self.endconv = nn.Conv3d(self.embedding_dim // 32, 4, kernel_size=1) + self.endconv = nn.Conv3d(self.embedding_dim // 32, self.num_classes, kernel_size=1) def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):