Hi, thanks to your great work.I'm having some issues running your code. Can you give me some help?@YzzLiu
in function get_tp_fp_fn_tn()
‘’‘
def get_tp_fp_fn_tn(net_output, gt, disMap = None, axes=None, mask=None, square=False ,current_epoch = None):
smooth_trans = True
Tauo_st = 0
st_epoch = 1000
if smooth_trans == False:
if disMap != None:
disMap = disMap / torch.mean(disMap)
temp_disMap_value = disMap
if smooth_trans == True:
if disMap != None:
disMap = disMap / torch.mean(disMap)
if current_epoch <Tauo_st:
temp_disMap_value = torch.ones_like(disMap)
elif Tauo_st <= current_epoch < Tauo_st + st_epoch :
warm_start_matrix = torch.ones_like(disMap)
warm_para = float(Tauo_st + st_epoch - current_epoch) / st_epoch
temp_disMap_value = warm_para * warm_start_matrix + (1 - warm_para) * disMap
elif current_epoch >= Tauo_st + st_epoch:
temp_disMap_value = disMap
disMap = temp_disMap_value
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
shp_disMap = disMap.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
disMap = disMap.view((shp_disMap[0], 1, *shp_disMap[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
disMap2onehot = disMap.repeat_interleave(shp_x[1],dim = 1)
disMap2onehot = disMap2onehot.cuda(net_output.device.index)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
tn = (1 - net_output) * (1 - y_onehot)
tp = torch.mul(tp, disMap2onehot)
fp = torch.mul(fp, disMap2onehot)
fn = torch.mul(fn, disMap2onehot)
tn = torch.mul(tn, disMap2onehot)
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tn = tn ** 2
if len(axes) > 0:
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
tn = sum_tensor(tn, axes, keepdim=False)
return tp, fp, fn, tn
Hi, thanks to your great work.I'm having some issues running your code. Can you give me some help?@YzzLiu
in function get_tp_fp_fn_tn()
‘’‘
def get_tp_fp_fn_tn(net_output, gt, disMap = None, axes=None, mask=None, square=False ,current_epoch = None):
smooth_trans = True
Tauo_st = 0
st_epoch = 1000
‘’‘
tp = torch.mul(tp, disMap2onehot) there is a bug:The size of tensor a (128) must match the size of tensor b (218) at non-singleton dimension 4.
i add ' disMap2onehot = torch.nn.functional.interpolate(disMap2onehot, size=tp.shape[2:])'
I don't know if it's the right thing to do.