-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfps.py
More file actions
77 lines (63 loc) · 2.22 KB
/
fps.py
File metadata and controls
77 lines (63 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import time
import torch
from nets.crackformer2 import crackformer
from nets.Deepcrack import DeepCrack
from nets.HED import HED
from nets.SDDNet import SDDNet
from nets.RCF import RCF
from nets.STRNet import STRNet
from nets.zjd_unet_2plus_ori import * #要改
#from nets.zjd_unet_2plus import *
from nets.Unet_delinear import *
from nets.Unet_delinear_less import *
def _sync(device):
    """Block until all queued CUDA work on *device* has finished (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def measure_fps(model, inputs, iters=600, warmup=20):
    """Return the average forward-pass throughput of *model* in passes per second.

    For a batch size of 1 this equals frames per second.

    The model is switched to eval mode, which is left in place afterwards. All
    passes run under ``torch.no_grad()`` so no autograd graph is built, matching
    real inference. ``warmup`` untimed passes are run first so one-off costs
    (CUDA context/kernel loading, allocator growth) do not pollute the measurement.

    Args:
        model: A ``torch.nn.Module`` already placed on the same device as *inputs*.
        inputs: The tensor fed to ``model`` on every pass.
        iters: Number of timed forward passes; must be positive.
        warmup: Number of untimed forward passes run before timing starts.

    Returns:
        ``iters / elapsed_seconds`` as a float.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")
    model.eval()
    device = inputs.device
    with torch.no_grad():
        for _ in range(warmup):
            model(inputs)
        # CUDA kernel launches are asynchronous: drain the queue before starting
        # the clock, and again before stopping it, so we time the actual GPU work.
        _sync(device)
        start = time.perf_counter()
        for _ in range(iters):
            model(inputs)
        _sync(device)
        elapsed = time.perf_counter() - start
    return iters / elapsed


if __name__ == "__main__":
    # Other input resolutions used in earlier experiments:
    #   (1, 3, 544, 384), (1, 3, 448, 448), (1, 3, 480, 480), (1, 3, 480, 320)
    dummy_input = torch.randn(1, 3, 512, 512).cuda()
    # To benchmark a different network, swap in one of the other imported models,
    # e.g. DeepCrack(), HED(), SDDNet(3, 1), STRNet(3, 1), linDeform_Unetless(),
    # RCF() or crackformer().
    model = zjd_unet_2plus().cuda()
    print(measure_fps(model, dummy_input, iters=600))