-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_base.py
More file actions
47 lines (39 loc) · 1.19 KB
/
model_base.py
File metadata and controls
47 lines (39 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from abc import ABC, abstractmethod
from typing import List
from timm.models.layers import trunc_normal_
from torch import Tensor, einsum, nn
class BaseModel(nn.Module, ABC):
    """Abstract base class shared by the models in this project.

    Combines ``nn.Module`` with ``abc.ABC`` so concrete subclasses must
    implement ``forward``. Also provides:

    * ``_init_weights`` - a per-submodule initialisation hook,
    * ``diffquick_or_papsmear`` - a filename keyword matcher,
    * ``get_cam_1d`` - a helper that projects features onto the final
      linear layer's weights (class activation maps).
    """

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise the parameters of a single submodule ``m`` in place.

        Written to be used with ``nn.Module.apply`` (presumably
        ``self.apply(self._init_weights)`` in subclasses - verify).

        * ``nn.Linear``: weight ~ truncated normal with ``std=0.02``; bias,
          if present, set to 0.
        * ``nn.BatchNorm1d`` / ``nn.GroupNorm``: weight set to 1 and bias
          set to 0, when those affine parameters exist.
        * Any other module is left untouched.

        Args:
            m: The submodule to initialise.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
            # With affine=False these norms have weight/bias == None, and
            # nn.init.constant_ would fail on them.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def diffquick_or_papsmear(
        filename: str,
        keywords: List[str]
    ) -> bool:
        """Return True if any of ``keywords`` is a substring of ``filename``.

        Matching is plain, case-sensitive substring containment. An empty
        ``keywords`` list always yields False.
        NOTE(review): the name suggests this tells Diff-Quik and Pap-smear
        slides apart by filename - confirm against callers.

        Args:
            filename: The string to search in.
            keywords: Substrings to look for.

        Returns:
            Whether at least one keyword occurs in ``filename``.
        """
        return any(keyword in filename for keyword in keywords)

    @staticmethod
    def get_cam_1d(classifier: nn.Module,
                   features: Tensor
                   ) -> Tensor:
        """Project ``features`` onto the classifier's final weight matrix.

        Computes ``einsum('gf, cf -> cg', features, weight)``: with
        ``features`` of shape ``(g, f)`` and the final weight of shape
        ``(c, f)``, the result has shape ``(c, g)``. Presumably ``g``
        indexes instances/patches and ``c`` classes - confirm with callers.

        The weight is taken from the last ``nn.Linear`` found in
        ``classifier.modules()`` (registration order), so bias-less final
        layers work too. If the classifier contains no ``nn.Linear``, the
        legacy heuristic is used: the second-to-last registered parameter.
        That heuristic assumes the final layer registers weight then bias.

        Args:
            classifier: The classification head (a module or container).
            features: 2-D feature tensor whose last dim matches the final
                weight's input dimension.

        Returns:
            The ``(c, g)`` activation map tensor.
        """
        linear_layers = [
            module for module in classifier.modules()
            if isinstance(module, nn.Linear)
        ]
        if linear_layers:
            final_weight = linear_layers[-1].weight
        else:
            # Legacy fallback: raises IndexError with fewer than 2 params.
            final_weight = list(classifier.parameters())[-2]
        cam_maps = einsum(
            'gf, cf -> cg',
            features,
            final_weight
        )
        return cam_maps

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Run the model; must be implemented by concrete subclasses.

        The varargs signature only keeps subclass overrides (which take
        inputs) signature-compatible with this abstract declaration.
        """