import torch
import torch.nn.functional as F


# For classification: 0 = normal, 1 = cancer.
# For subtyping: 0 = type1, 1 = type2, 2 = type3 (not implemented yet).
def smooth_func(tensor, tau=0.1):
    return torch.log(1 + torch.exp(tensor / tau))


def soft_top1(tensor, dim=1, tau=0.1):  # e.g. log(e^5 + e^3 + e^-1) ≈ 5
    return torch.log(torch.sum(torch.exp(tensor / tau), dim=dim)) * tau
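

# A minimal, hypothetical sanity check (not part of the original training
# code) illustrating the two helpers above: soft_top1 is a temperature-scaled
# log-sum-exp that approximates the maximum along a dimension, tightening as
# tau shrinks, and smooth_func is a softplus with temperature.
def _demo_soft_top1():
    x = torch.tensor([[5.0, 3.0, -1.0]])
    # With tau=0.1 the result is ~5.0, matching the hard max almost exactly.
    print(soft_top1(x, dim=1, tau=0.1))
    # smooth_func is ~0 for negative inputs and grows like x / tau for
    # positive ones, so it acts as a smooth hinge penalty.
    print(smooth_func(torch.tensor([-1.0, 0.0, 1.0]), tau=0.1))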


# abandoned
class BinaryLoss(torch.nn.Module):
    def __init__(self, tau=0.1, init_mu=0.1):
        super(BinaryLoss, self).__init__()
        self.tau = tau
        self.mu = init_mu

    def forward(self, sims):
        sim0 = sims[0]
        sim1 = sims[1]
        ### patch loss ###
        patch_loss0 = torch.mean(smooth_func(sim0[:, 1] - sim0[:, 0], tau=self.tau))
        patch_label1 = torch.argmax(sim1, dim=1)
        mask1 = patch_label1 != 0
        cancer_ratio = torch.sum(mask1) / mask1.numel()
        _, mask2 = torch.topk(sim1[:, 1], k=int(sim1.shape[0] * self.mu))
        cancer_sim1 = sim1[mask1]
        ### normal loss ###
        # # Cond. 1: always penalize sim[:, 0] of the top-mu patches
        # if cancer_ratio == 0:
        #     cancer_sim1 = sim1[mask2]
        # normal_loss = torch.mean(smooth_func(sim1[mask2, 0], tau=self.tau))
        # Cond. 2: penalize sim[:, 0] only when r == 0
        if cancer_ratio == 0:
            cancer_sim1 = sim1[mask2]
            normal_loss = torch.mean(smooth_func(sim1[mask2, 0], tau=self.tau))
        else:
            normal_loss = torch.zeros(1, device=sim1.device).squeeze()
        # # Cond. 3: penalize sim[:, 0] only when r < mu
        # if cancer_ratio == 0:
        #     cancer_sim1 = sim1[mask2]
        # if cancer_ratio < self.mu:
        #     normal_loss = torch.mean(smooth_func(sim1[mask2, 0], tau=self.tau))
        # else:
        #     normal_loss = torch.zeros(1, device=sim1.device).squeeze()
        # # Cond. 4: penalize sim[:, 0] when r < mu; penalize sim[:, 1] of the
        # # complement of the top-mu patches when r > mu
        # if cancer_ratio == 0:
        #     cancer_sim1 = sim1[mask2]
        # if cancer_ratio < self.mu:
        #     normal_loss = torch.mean(smooth_func(sim1[mask2, 0], tau=self.tau))
        # else:
        #     normal_loss = torch.mean(smooth_func(sim1[~mask2, 1], tau=self.tau))
        ### patch loss ###
        patch_loss1 = torch.mean(smooth_func(cancer_sim1[:, 0] - cancer_sim1[:, 1], tau=self.tau))
        ### contrast loss ###
        patch_label0 = torch.argmax(sim0, dim=1)
        mask0 = patch_label0 != 0
        cancer_sim0 = sim0[mask0]
        if cancer_sim0.numel() == 0 and cancer_sim1.numel() == 0:
            contrast_loss = torch.zeros(1, device=sim1.device).squeeze()
        else:
            contrast_loss = smooth_func(
                (cancer_sim0[:, 1].sum() - cancer_sim1[:, 1].sum())
                / max(len(cancer_sim0), len(cancer_sim1)),
                tau=self.tau)
        ### update mu ###
        self.mu += 0.001 * (cancer_ratio - self.mu)
        # clamp mu
        self.mu = torch.clamp(self.mu, min=0.001, max=0.999)
        return patch_loss0 + patch_loss1 + contrast_loss + normal_loss, torch.stack(
            [patch_loss0, patch_loss1, contrast_loss, normal_loss, cancer_ratio, self.mu]).squeeze()


# abandoned
class BinaryLossTopMu(torch.nn.Module):
    def __init__(self, tau=0.1, init_mu=0.1):
        super(BinaryLossTopMu, self).__init__()
        self.tau = tau
        self.mu = init_mu

    def forward(self, sims, update_mu=True):
        sim0 = sims[0]
        sim1 = sims[1]
        assert sim0.shape == sim1.shape
        self.mu = self.mu if torch.is_tensor(self.mu) else torch.tensor(self.mu, device=sim0.device)
        k = int(torch.ceil(sim0.shape[0] * self.mu).item())
        ### patch loss ###
        patch_loss0 = torch.mean(smooth_func(sim0[:, 1] - sim0[:, 0], tau=self.tau))
        patch_label1 = torch.argmax(sim1, dim=1)
        mask1 = patch_label1 != 0
        cancer_ratio = torch.sum(mask1) / mask1.numel()
        # build a boolean mask of the top-k patches; topk returns indices,
        # which cannot be complemented with ~ directly
        _, topk_idx1 = torch.topk(sim1[:, 1], k)
        mask2 = torch.zeros(sim1.shape[0], dtype=torch.bool, device=sim1.device)
        mask2[topk_idx1] = True
        mask3 = ~mask2
        patch_loss1_cancer = smooth_func(sim1[mask2, 0] - sim1[mask2, 1], tau=self.tau)
        patch_loss1_normal = smooth_func(sim1[mask3, 1] - sim1[mask3, 0], tau=self.tau)
        # cat (not stack): the two groups generally have different sizes
        patch_loss1 = torch.mean(torch.cat([patch_loss1_cancer, patch_loss1_normal]))
        ### contrast loss ###
        _, mask4 = torch.topk(sim0[:, 1], k)
        contrast_loss = smooth_func(sim0[mask4, 1].mean() - sim1[mask2, 1].mean(), tau=self.tau)
        ### update mu ###
        if update_mu:
            self.mu += 0.001 * (cancer_ratio - self.mu)
            # clamp mu
            self.mu = torch.clamp(self.mu, min=0, max=1)
        return patch_loss0 + patch_loss1 + contrast_loss, torch.stack(
            [patch_loss0, patch_loss1, contrast_loss, cancer_ratio, self.mu]).squeeze()


class BinaryLossSoftTop1(torch.nn.Module):
    def __init__(self, tau=0.1):
        super(BinaryLossSoftTop1, self).__init__()
        self.tau = tau

    def forward(self, sims):
        sim0 = sims[0]
        sim1 = sims[1]
        soft_top1_normal = soft_top1(sim0[:, 1], dim=0, tau=self.tau)
        soft_top1_cancer = soft_top1(sim1[:, 1], dim=0, tau=self.tau)
        contrast_loss = smooth_func(soft_top1_normal - soft_top1_cancer, tau=self.tau)
        return contrast_loss
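

# A minimal, hypothetical usage sketch (not part of the original pipeline):
# each element of sims is assumed to be a (num_patches, 2) similarity matrix,
# with sims[0] from a normal WSI and sims[1] from a cancer WSI. The loss pushes
# the best cancer score of the cancer WSI above that of the normal WSI.
def _demo_binary_soft_top1():
    sims = [torch.rand(16, 2), torch.rand(16, 2)]
    loss = BinaryLossSoftTop1()
    print(loss(sims))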


def delete_normal_patch(matrix, label, mu=0.1):
    max_indices = torch.argmax(matrix, dim=1)  # column index of each row's maximum
    # keep the rows whose maximum is not in the first (normal) column
    mask = max_indices != 0  # boolean mask
    cancer_ratio = torch.sum(mask) / mask.numel()
    # apply the mask to keep the qualifying rows
    filtered_matrix = matrix[mask]
    _, max_indices = torch.topk(matrix[:, label + 1], k=int(matrix.shape[0] * mu))
    topk_matrix = matrix[max_indices]
    if filtered_matrix.numel() == 0:
        filtered_matrix = topk_matrix
    return filtered_matrix, topk_matrix, cancer_ratio
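

# A minimal, hypothetical example (not part of the original pipeline) of the
# helper above: rows classified as normal are dropped, and the top-mu rows for
# the requested tumor type serve as a fallback when everything looks normal.
def _demo_delete_normal_patch():
    matrix = torch.rand(20, 4)  # 20 patches, 4 classes (normal + 3 subtypes)
    filtered, topk, ratio = delete_normal_patch(matrix, label=0, mu=0.1)
    print(filtered.shape, topk.shape, ratio)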


# abandoned
class MultiLoss(torch.nn.Module):
    def __init__(self, tau=0.1, init_mu=0.3):
        super(MultiLoss, self).__init__()
        self.tau = tau
        self.mu = init_mu

    def forward(self, sims):  # (n_cls, batch, n_cls)
        batch_num, cls_num = sims[0].shape
        tumor_num = cls_num - 1
        device = sims[0].device
        tumor_patch_num = []
        ratio_lst = []
        ### patch loss ###
        # Since each WSI is assumed to contain patches of only one subtype,
        # scheme 1 of the multi patch loss never takes effect here.
        patch_losses = torch.zeros(tumor_num).to(device)
        batch_sum_sim = torch.zeros(tumor_num, tumor_num).to(device)
        normal_loss = torch.zeros(1).to(device).squeeze()  # only takes effect when the model labels every patch as normal
        for i in range(tumor_num):
            patches_sim, topk_sim, cancer_ratio = delete_normal_patch(sims[i], i, mu=self.mu)
            ratio_lst.append(cancer_ratio)
            ### normal loss ###
            if cancer_ratio == 0:
                normal_loss += torch.mean(smooth_func(topk_sim[:, 0], tau=self.tau))
            tumor_patches_sim = patches_sim[:, 1:]
            n_column = tumor_patches_sim[:, i].unsqueeze(1)  # extract column i, adding a dim for broadcasting; shape [n, 1]
            # subtract column i from every column
            sim_diff = tumor_patches_sim - n_column
            sim_diff[:, i] = -2  # exp(-2 / tau) is approximately 0, so the diagonal contributes nothing
            ## loss for each patch
            patch_loss = smooth_func(sim_diff, tau=self.tau).sum() / (sim_diff.shape[0] * (sim_diff.shape[1] - 1))
            patch_losses[i] = patch_loss
            ## prepare the summed sim for each WSI
            batch_sum_sim[i, :] = torch.sum(tumor_patches_sim, dim=0)
            tumor_patch_num.append(tumor_patches_sim.shape[0])
        ### contrast loss ###
        ## normalize batch_sum_sim by the largest patch_num
        batch_sum_sim = batch_sum_sim / max(tumor_patch_num)
        diag_elements = torch.diag(batch_sum_sim)  # diagonal elements, shape [n]
        # expand the diagonal into a row vector for broadcasting
        diag_column = diag_elements.unsqueeze(0)  # shape [1, n]
        # replace all off-diagonal elements
        pos_ext = diag_column.repeat(batch_sum_sim.size(0), 1)  # broadcast over the whole matrix
        # [[a,x,x],[x,b,x],[x,x,c]] -> [[a,b,c],[a,b,c],[a,b,c]]
        batch_sim_diff = batch_sum_sim - pos_ext
        batch_sim_diff[batch_sim_diff == 0] = -2
        contrast_loss = smooth_func(batch_sim_diff, tau=self.tau).sum() / (batch_sim_diff.shape[0] * (batch_sim_diff.shape[1] - 1))
        loss = contrast_loss + torch.mean(patch_losses) + normal_loss
        # print(f'patch_losses: {patch_losses}, contrast_loss: {contrast_loss}')
        # assert not torch.isnan(patch_losses).any()
        ### update mu ###
        cancer_ratio = sum(ratio_lst) / len(ratio_lst)
        self.mu += 0.1 * (cancer_ratio - self.mu)
        self.mu = torch.clamp(self.mu, min=0.3, max=0.7)
        # return a stats vector with the same layout as the binary losses
        # (a zero fills the unused slot)
        return loss, torch.stack([torch.mean(patch_losses), contrast_loss, normal_loss,
                                  torch.zeros(1, device=device).squeeze(), cancer_ratio, self.mu]).squeeze()


class MultiLossSoftTop1(torch.nn.Module):  # TODO: consider a soft top1 instead of the elementwise subtraction over the soft-top1 matrix
    def __init__(self, tau=0.1):
        super(MultiLossSoftTop1, self).__init__()
        self.tau = tau

    def forward(self, sims):  # (n_cls, batch, n_cls), normal class excluded
        soft_top1_sims = soft_top1(sims, dim=1, tau=self.tau)  # (n_cls, n_cls)
        diag_elements = torch.diag(soft_top1_sims)  # diagonal elements, shape (n_cls,)
        diag_column = diag_elements.unsqueeze(1)  # shape (n_cls, 1)
        pos_ext_column = diag_column.repeat(1, soft_top1_sims.shape[1])  # broadcast over the whole matrix
        # [[a,x,x],[x,b,x],[x,x,c]] -> [[a,a,a],[b,b,b],[c,c,c]]
        soft_top1_diff_column = soft_top1_sims - pos_ext_column  # (n_cls, n_cls)
        soft_top1_diff_column[soft_top1_diff_column == 0] = -2
        diag_row = diag_elements.unsqueeze(0)  # shape (1, n_cls)
        pos_ext_row = diag_row.repeat(soft_top1_sims.shape[1], 1)  # broadcast over the whole matrix
        # [[a,x,x],[x,b,x],[x,x,c]] -> [[a,b,c],[a,b,c],[a,b,c]]
        soft_top1_diff_row = soft_top1_sims - pos_ext_row  # (n_cls, n_cls)
        soft_top1_diff_row[soft_top1_diff_row == 0] = -2
        contrast_loss = smooth_func(soft_top1_diff_column, tau=self.tau).sum() + smooth_func(soft_top1_diff_row, tau=self.tau).sum()
        return contrast_loss
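

# A minimal, hypothetical smoke test (assumed shapes only, not part of the
# original pipeline): sims is a single (n_cls, batch, n_cls) tensor of
# similarities with the normal class removed, as the forward signature expects.
def _demo_multi_soft_top1():
    sims = torch.rand(3, 16, 3)
    loss = MultiLossSoftTop1()
    print(loss(sims))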


class GroupConLoss(torch.nn.Module):
    def __init__(self, tau=0.1):
        super(GroupConLoss, self).__init__()
        self.tau = tau

    def forward(self, sims, pseudo_labels=None):  # (n_groups, batch, n_cls); groups exclude normal, classes include it at index 0
        ## pseudo loss
        sim_logits = F.softmax(sims, dim=2)
        # group_masks = torch.tensor([
        #     [1, 1, 0, 0],  # legal classes of group A
        #     [1, 0, 1, 0],  # legal classes of group B
        #     [1, 0, 0, 1],  # legal classes of group C
        # ]).to(sims.device)  # legal-class mask for each group
        group_masks = torch.zeros(sim_logits.shape[2] - 1, sim_logits.shape[2], device=sims.device)
        group_masks[:, 0] = 1  # the first class (normal) is legal for every group
        for i in range(sim_logits.shape[2] - 1):
            group_masks[i, i + 1] = 1  # the second legal class of group i
        num_groups, num_images, num_classes = sim_logits.size()  # (3, 1024, 4)
        # flatten the logits, merging all groups: shape (3 * 1024, 4)
        logits = sim_logits.permute(1, 0, 2)
        logits = logits.reshape(-1, num_classes)
        # compute softmax probabilities of the legal classes
        # probs = F.softmax(logits, dim=1)  # (3 * 1024, 4)
        group_masks = group_masks.unsqueeze(1)
        group_masks = group_masks.expand(-1, num_images, -1)  # (3, 1024, 4)
        group_masks = group_masks.permute(1, 0, 2)  # (1024, 3, 4)
        group_masks = group_masks.reshape(-1, num_classes)  # (1024 * 3, 4)
        # without pseudo labels, penalize the probabilities of the illegal classes
        labeled_probs = logits * group_masks  # probabilities of the legal classes
        # if not pseudo_labels:
        labels = torch.argmax(labeled_probs[:, 1:], dim=1).detach()
        grouped_probs = labeled_probs.view(-1, num_classes - 1, num_classes)
        first_cols = grouped_probs[..., 0]  # [batch, n_groups]
        sub_matrix = grouped_probs[..., 1:]  # [batch, n_groups, n_groups]
        first_cols_expanded = first_cols.unsqueeze(2)  # shape [batch, n_groups, 1]
        first_cols_expanded = first_cols_expanded.expand(-1, -1, num_classes - 1)  # shape [batch, n_groups, n_groups]
        diag_mask = torch.eye(num_classes - 1, dtype=torch.bool, device=labeled_probs.device)
        # diagonal operation (only the diagonal elements are modified)
        modified_logits = sub_matrix.clone()
        modified_logits[:, diag_mask] = (sub_matrix[:, diag_mask] + first_cols_expanded[:, diag_mask]) / 2
        # modified_logits[:, diag_mask] = torch.minimum(
        #     sub_matrix[:, diag_mask],
        #     first_cols_expanded[:, diag_mask]
        # )
        modified_logits = modified_logits.view(-1, num_classes - 1)
        false_logits = logits * (1 - group_masks)
        final_logits = false_logits[:, 1:] + modified_logits
        # torch.sum(labeled_probs, dim=1)
        loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1, reduction='sum')
        loss_weak = loss_fn(final_logits, labels) / num_groups / num_images
        # loss_strong = abs(false_logits).sum()
        cls_num, patch_num, cols = sim_logits.shape
        mask = torch.zeros(cls_num, patch_num, cols).to(sim_logits.device)
        mask[:, :, 0] = 1
        for batch_idx in range(cls_num):
            mask[batch_idx, :, batch_idx + 1] = 1
        logits_pos = sim_logits * mask
        loss_pos = 0
        lambda_coeff = 1.0  # weight of the constraint loss
        for i in range(cls_num):
            col0 = logits_pos[i, :, 0]
            colX = logits_pos[i, :, i + 1]
            # per-sample difference
            diff = colX - col0
            # largest difference
            max_diff = torch.max(diff)
            # constraint loss (clamp avoids creating a tensor on the wrong device)
            constraint_loss = torch.clamp(-max_diff, min=0.0)
            # constraint_loss = torch.log(1 + torch.exp((-max_diff) / self.tau))
            # total loss (assuming the main loss is main_loss)
            loss_pos += lambda_coeff * constraint_loss
        loss_pos = loss_pos / cls_num
        loss = loss_pos + loss_weak
        # true_logits = probs * group_masks
        # true_logits[true_logits == 0] = -2
        # loss = loss_weak
        # illegal_loss = torch.log(1 + illegal_probs).mean()  # mean probability of the illegal classes as the penalty
        return loss


if __name__ == '__main__':
    # # Generate two test similarity arrays with shape [16, 2]
    # similarity0 = torch.rand([16, 2])
    # similarity1 = torch.rand([16, 2])
    # WSI_label = 0
    # patch_loss_value = patch_loss(similarity0, WSI_label)
    # contrast_loss_value = contrast_loss(similarity0, similarity1)
    # print(f"Patch Loss: {patch_loss_value}")
    # print(f"Contrast Loss: {contrast_loss_value}")
    # test MultiLoss
    sims = [torch.rand([17, 4]), torch.rand([17, 4]), torch.rand([17, 4])]
    loss = MultiLoss()
    print(loss(sims))
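    # Hypothetical smoke test for GroupConLoss (assumed shapes only): three
    # tumor groups over four classes, with normal at index 0.
    group_sims = torch.rand([3, 17, 4])
    group_loss = GroupConLoss()
    print(group_loss(group_sims))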