-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
193 lines (161 loc) · 8.43 KB
/
eval.py
File metadata and controls
193 lines (161 loc) · 8.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
from sklearn.metrics import roc_curve, roc_auc_score, balanced_accuracy_score, classification_report, confusion_matrix
import os
import numpy as np
def det_auc(model, testloader, device='cpu', batch_classifier=None):  # WSI-level AUC
    """Compute the WSI-level ROC AUC of a patch classifier.

    Each element of ``testloader`` is assumed to hold exactly one WSI:
    ``data`` shaped (1, n_patch, ...) and a one-element ``wsi_label``.
    A WSI's score is the fraction of its patches whose class-1 output
    exceeds 0.5.

    Args:
        model: callable mapping patch features to an (n_patch, 2) output.
        testloader: iterable of (data, wsi_label) pairs with batch size 1.
        device: device the patch data is moved to.
        batch_classifier: unused (the old classifier-based scoring path was
            removed); kept for backward compatibility.

    Returns:
        float: ROC AUC over all WSIs.
    """
    all_scores = []
    all_labels = []
    for data, wsi_label in testloader:
        # One WSI per batch: drop the leading batch dimension.
        data = data.to(device).squeeze(0)
        wsi_label = wsi_label.item()
        with torch.no_grad():
            logit = model(data)  # (n_patch, 2)
        if batch_classifier is None:  # fixed: identity comparison with None
            # NOTE(review): thresholding the raw class-1 output at 0.5 assumes the
            # model emits probabilities rather than logits -- confirm upstream.
            patch_pred = (logit[:, 1] > 0.5).float()
            tumor_frac = patch_pred.sum().item() / len(patch_pred)
            all_scores.append(tumor_frac)
        all_labels.append(wsi_label)
    auc_roc = roc_auc_score(all_labels, all_scores)
    return auc_roc
def dice_auc(logits, coords, mask_img, patch_mask, patch_size=224):  # patch-level AUC
    """Compute one WSI's segmentation Dice score and patch-level ROC AUC.

    The binarization threshold is picked on the evaluation labels via
    Youden's J statistic (max of tpr - fpr). Patches scoring above the
    threshold are painted into a prediction mask at 1/16 resolution and
    compared against the ground-truth mask.

    Args:
        logits: (n_patch, 2) per-patch class scores.
        coords: (n_patch, 2) top-left (x, y) patch coordinates at full resolution.
        mask_img: ground-truth mask at 1/16 of full resolution (nonzero = tumor).
        patch_mask: (n_patch,) binary per-patch ground-truth labels.
        patch_size: patch side length in full-resolution pixels.

    Returns:
        [dice, auc]: Dice coefficient (closer to 1 is better) and ROC AUC.
    """
    assert len(logits) == len(coords), "Number of logits must match number of coordinates"
    fpr, tpr, thds = roc_curve(patch_mask, logits[:, 1])
    # Youden's J: the threshold maximizing tpr - fpr on this WSI's patches.
    best_thd_idx = np.argmax(tpr - fpr)
    thd = thds[best_thd_idx]
    auc = roc_auc_score(patch_mask, logits[:, 1])
    # Areas are scaled by 256 (= 16 * 16) so that counts of 1/16-resolution
    # pixels are expressed in full-resolution units.
    mask_sum = np.count_nonzero(mask_img) * 256
    pred_mask = np.zeros_like(mask_img)
    for i in range(len(logits)):
        logit = logits[i]
        coord = coords[i]
        if logit[1] > thd:
            # Paint the predicted-positive patch into the 1/16-resolution mask.
            r0 = int(coord[1] / 16)
            r1 = int(coord[1] / 16 + patch_size / 16)
            c0 = int(coord[0] / 16)
            c1 = int(coord[0] / 16 + patch_size / 16)
            pred_mask[r0:r1, c0:c1] = 255
    # Dead in-loop accumulators removed: intersection and prediction areas are
    # computed once from the finished masks.
    inter_mask = mask_img * pred_mask
    intersection_sum = np.count_nonzero(inter_mask) * 256
    pred_sum = np.count_nonzero(pred_mask) * 256
    # NOTE(review): divides by zero when both masks are empty -- assumes the
    # ground-truth mask is never empty; verify against the dataset.
    dice = 2 * intersection_sum / (mask_sum + pred_sum)
    return [dice, auc]
def seg_dice_auc(model, test_loader, device, patch_size=224, batch_classifier=None):
    """Evaluate segmentation quality over a test set of WSIs.

    Runs the model on every WSI and averages the per-WSI Dice and
    patch-level AUC produced by ``dice_auc``. Leftover debug output
    (top-10 patch inspection print) has been removed.

    Args:
        model: callable taking (data, ensemble=False) and returning
            (n_patch, n_cls) scores.
        test_loader: iterable of (data, coords, mask_img, patch_mask)
            with batch size 1 (one WSI per batch).
        device: device the patch data is moved to.
        patch_size: patch side length forwarded to ``dice_auc``.
        batch_classifier: unused; kept for interface compatibility.

    Returns:
        [mean_dice, mean_auc] averaged across all WSIs.
    """
    dice_scores = []
    auc_scores = []
    for data, coords, mask_img, patch_mask in test_loader:
        # One WSI per batch: drop the leading batch dimension everywhere.
        data = data.to(device).squeeze(0)
        coords = coords.squeeze(0)
        mask_img = mask_img.squeeze(0)
        patch_mask = patch_mask.squeeze(0)
        with torch.no_grad():
            logits = model(data, ensemble=False)
        logits = logits.cpu().numpy()
        dice_score, auc_score = dice_auc(logits, coords, mask_img, patch_mask, patch_size)
        dice_scores.append(dice_score)
        auc_scores.append(auc_score)
    return [np.mean(dice_scores), np.mean(auc_scores)]
def sub_bacc_wf1(model, test_loader, device, normal_ext=False, batch_classifier=None):
    """Compute WSI-level balanced accuracy and weighted F1 for subtyping.

    Each batch holds a single WSI. Patch predictions are aggregated by a
    majority vote over the tumor classes (class 0 = normal is excluded);
    if every patch votes normal, or the tumor classes tie, the vote is
    redone with the normal column dropped so a tumor class always wins.

    Args:
        model: callable taking (data, ensemble=False) and returning
            (n_patch, n_cls) scores, where column 0 is the normal class.
        test_loader: iterable of (data, wsi_label) pairs with batch size 1;
            labels in the test set start at 1.
        device: device the patch data is moved to.
        normal_ext: unused; kept for interface compatibility.
        batch_classifier: only switched to eval mode when provided; the old
            classifier-based aggregation path was removed.

    Returns:
        [balanced_accuracy, weighted_f1] over all WSIs.
    """
    model.eval()
    if batch_classifier:
        batch_classifier.eval()
    gt_labels = []
    pred_labels = []
    for data, wsi_label in test_loader:
        # One WSI per batch: drop the leading batch dimension.
        data = data.to(device).squeeze(0)
        wsi_label = wsi_label.item()
        with torch.no_grad():
            logit = model(data, ensemble=False)  # (n_patch, n_cls)
        wsi_label -= 1  # test-set labels start at 1; shift to 0-based tumor classes
        gt_labels.append(wsi_label)
        if batch_classifier is None:  # fixed: identity comparison with None
            patch_pred = torch.argmax(logit, dim=1)
            class_counts = torch.bincount(patch_pred, minlength=logit.shape[1])
            # Vote over tumor classes only (index 0 = normal).
            all_normal_flag = class_counts[1:].sum() == 0
            max_val = class_counts[1:].max()
            mask = (class_counts[1:] == max_val)
            tumor_equal_flag = mask.sum() > 1
            if all_normal_flag or tumor_equal_flag:
                # Fallback: re-vote with the normal column dropped so the
                # prediction is always one of the tumor classes.
                patch_pred = torch.argmax(logit[:, 1:], dim=1)
                class_counts = torch.bincount(patch_pred, minlength=logit.shape[1] - 1)
                patch_prob = class_counts / class_counts.sum()
            else:
                patch_prob = class_counts[1:] / class_counts[1:].sum()
            wsi_pred = torch.argmax(patch_prob).item()
            pred_labels.append(wsi_pred)
    # Report the confusion matrix alongside balanced accuracy and weighted F1.
    print(confusion_matrix(gt_labels, pred_labels))
    balanced_acc = balanced_accuracy_score(gt_labels, pred_labels)
    report = classification_report(gt_labels, pred_labels, output_dict=True, zero_division=0)
    weighted_f1 = report['weighted avg']['f1-score']
    return [balanced_acc, weighted_f1]