-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel_Test.py
More file actions
68 lines (50 loc) · 1.94 KB
/
Model_Test.py
File metadata and controls
68 lines (50 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import cv2 as cv
import numpy as np
from numpy.lib.function_base import disp
import tensorflow as tf
import matplotlib.pyplot as plt
# Webcam age-estimation demo.
#
# Each frame: detect faces with a Haar cascade, draw a box around every face,
# run the Keras model on the first detected face, then overlay and print the
# prediction. Press 'q' in the 'frame' window to stop; the average of the
# non-zero predictions is printed on exit.
#
# NOTE(review): the `disp` (numpy.lib.function_base) and matplotlib imports at
# the top of this file are unused; `numpy.lib.function_base.disp` no longer
# exists in NumPy 2.x, so that import line fails on current NumPy.

model_path = r'./checkpoints/age_range_model2_classification_relu.epoch16-loss1.78.hdf5'
other_path = r'./checkpoints/my_best_model.epoch15-loss106.01.hdf5'
# Uncomment to evaluate the model stored at other_path instead.
# model_path = other_path

# The checkpoint file name decides how the model output is interpreted:
#   "classification" in the name -> argmax over the output vector (class index);
#                                   otherwise output[0] is used as a scalar.
#   "age_range" in the name      -> the value is treated as an age-bucket index,
#                                   each bucket AGE_BUCKET_YEARS wide;
#                                   otherwise it is treated as an age in years.
is_age_range = "age_range" in model_path
is_classification = "classification" in model_path

INPUT_SIZE = 48         # the model is fed one INPUT_SIZE x INPUT_SIZE x 1 image
CROP_MARGIN_X = 30      # pixels trimmed from the left and right of each face box
CROP_MARGIN_Y = 5       # pixels trimmed from the top and bottom of each face box
AGE_BUCKET_YEARS = 5    # years covered by one age-range class
BOX_COLOR = (255, 0, 0)


def _crop_face(gray_frame, box):
    """Return the margin-trimmed region of `gray_frame` inside `box`.

    `box` is an (x, y, w, h) rectangle as produced by detectMultiScale.
    Returns None when the margins leave no pixels (small detections); an
    empty crop would otherwise make cv.resize raise.
    """
    x, y, w, h = (int(v) for v in box)
    top, bottom = y + CROP_MARGIN_Y, y + h - CROP_MARGIN_Y
    left, right = x + CROP_MARGIN_X, x + w - CROP_MARGIN_X
    # Checking the bounds explicitly also prevents a negative `right` from
    # wrapping around and slicing the wrong part of the frame.
    if bottom <= top or right <= left:
        return None
    return gray_frame[top:bottom, left:right]


def _predict(face):
    """Run the model on one grayscale face crop.

    Returns a Python int (argmax class index) for classification checkpoints,
    otherwise a Python float taken from the first output element.
    """
    small = cv.resize(face, (INPUT_SIZE, INPUT_SIZE))
    img = small.reshape((INPUT_SIZE, INPUT_SIZE, 1))
    cv.imshow('small ', img)
    # verbose=0: otherwise Keras prints a progress bar for every frame.
    output = model.predict(np.array([img]), verbose=0)[0]
    if is_classification:
        return int(np.argmax(output))
    # Plain float so round() below yields a plain int for display.
    return float(output[0])


def _describe(pred):
    """Return (display text, amount added to the running total) for `pred`."""
    if is_age_range:
        # Derive both ends from the same rounded bucket so the shown range is
        # always exactly AGE_BUCKET_YEARS wide, also for non-integer `pred`.
        lower = round(pred) * AGE_BUCKET_YEARS
        upper = lower + AGE_BUCKET_YEARS
        text = "predicted age range is: " + str(lower) + " to " + str(upper)
        return text, AGE_BUCKET_YEARS * pred
    return "Predicted age: " + str(round(pred)), pred


model = tf.keras.models.load_model(model_path)
face_cascade = cv.CascadeClassifier('./cascade/haarcascade_frontalface.xml')
if face_cascade.empty():
    # CascadeClassifier does not raise on a bad path; fail early and clearly
    # instead of with an opaque error from detectMultiScale on the first frame.
    raise FileNotFoundError("could not load './cascade/haarcascade_frontalface.xml'")

count = 0        # number of frames whose prediction was non-zero
age_total = 0    # running total of those predictions (renamed from `sum`, which shadowed the builtin)

cap = cv.VideoCapture(0)
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame (camera unavailable or stream ended): frame is None and
            # cvtColor would raise, so stop and report what was collected.
            break
        gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        for (x, y, w, h) in faces:
            cv.rectangle(frame, (x, y), (x + w, y + h), BOX_COLOR, 2)

        if len(faces) >= 1:
            face = _crop_face(gray, faces[0])
            if face is not None:
                pred = _predict(face)
                # Zero predictions are neither counted nor displayed.
                # NOTE(review): for a classification model class 0 is a real
                # class (youngest bucket); confirm excluding it is intended.
                if pred != 0:
                    display_text, contribution = _describe(pred)
                    count += 1
                    age_total += contribution
                    print(display_text)
                    org = (int(faces[0][0]), int(faces[0][1]))
                    frame = cv.putText(frame, display_text, org, cv.FONT_HERSHEY_COMPLEX, 0.5, BOX_COLOR, 1)

        cv.imshow('frame', frame)
        if cv.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the camera and close windows even if a frame raised.
    cap.release()
    cv.destroyAllWindows()

if count:
    print("Avg age: ", age_total // count)
else:
    # Avoid ZeroDivisionError when no prediction was ever recorded.
    print("Avg age: no predictions were recorded")