# FaceWrapper.py — RetinaFace face-detection inference wrapper
# (forked from edusense/ClassID; GitHub page chrome removed from this scrape)
from __future__ import print_function
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import sys
import os
# sys.path.insert(0, "./retinaface/")
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from retinaface.data import cfg_mnet, cfg_re50
from retinaface.layers.functions.prior_box import PriorBox
from retinaface.utils.nms.py_cpu_nms import py_cpu_nms
from retinaface.models.retinaface import RetinaFace
from retinaface.utils.box_utils import decode, decode_landm
from retinaface.utils.timer import Timer
import cv2
def check_keys(model, pretrained_state_dict):
    """Verify that a checkpoint shares at least one parameter key with the model.

    Args:
        model: torch module whose ``state_dict()`` keys are compared.
        pretrained_state_dict: mapping of checkpoint parameter names to tensors.

    Returns:
        True when the key sets overlap.

    Raises:
        AssertionError: when the checkpoint and the model have no keys in common.
    """
    checkpoint_keys = set(pretrained_state_dict)
    network_keys = set(model.state_dict())
    # Only the overlap matters for the sanity check; keys present on one
    # side only were computed by the original purely for (disabled) logging.
    shared_keys = network_keys & checkpoint_keys
    assert shared_keys, 'load NONE from pretrained checkpoint'
    return True
def remove_prefix(state_dict, prefix):
    """Strip a leading *prefix* (e.g. ``'module.'``) from every key of *state_dict*.

    Old-style (``DataParallel``-saved) checkpoints store every parameter
    under a shared ``module.`` prefix; this rebuilds the dict without it.
    Keys lacking the prefix are kept as-is. A new dict is returned; the
    input is not modified.
    """
    def _strip(name):
        # Only the first occurrence is removed, matching split(prefix, 1).
        if name.startswith(prefix):
            return name.split(prefix, 1)[-1]
        return name

    return {_strip(key): value for key, value in state_dict.items()}
def load_model(model, pretrained_path, load_to_cpu, device):
    """Load pretrained weights from *pretrained_path* into *model*.

    The checkpoint is always materialized on CPU first; the caller is
    responsible for moving the model to its target device afterwards
    (see ``RetinaFaceInference.__init__``).

    Args:
        model: RetinaFace module to receive the weights.
        pretrained_path: filesystem path to the ``.pth`` checkpoint.
        load_to_cpu: kept for interface compatibility; currently unused —
            loading is unconditionally mapped to CPU.
        device: kept for interface compatibility; currently unused.

    Returns:
        *model* with the checkpoint weights loaded (``strict=False``, so
        keys missing from the checkpoint are tolerated).
    """
    # NOTE(review): torch.load deserializes a pickle — only load
    # checkpoints from trusted sources.
    pretrained_dict = torch.load(pretrained_path, map_location='cpu')
    # Checkpoints either wrap the weights under a 'state_dict' key or are
    # flat; both variants may carry a DataParallel 'module.' prefix.
    if "state_dict" in pretrained_dict:
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    # Fail fast if the checkpoint has nothing in common with the model.
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
class RetinaFaceInference(object):
    """Inference-only wrapper around the RetinaFace face detector.

    Loads a pretrained network once at construction and exposes ``run()``,
    which returns face detections (boxes + 5 landmarks) for a single image.
    """

    def __init__(self, threshold=0.5, network="mobile0.25", device='cpu'):
        """Build the detector.

        Args:
            threshold: confidence cut-off applied to detections in ``run()``.
            network: backbone config name; "mobile0.25" or "resnet50".
            device: torch device string ('cpu', 'cuda:0', 'mps') or None
                to auto-detect.

        NOTE(review): the auto-detect branch below only fires when the
        caller explicitly passes device=None; with the default of 'cpu'
        a GPU is never selected automatically — confirm this is intended.
        """
        # Auto-detect device if not specified
        if device is None:
            if torch.backends.mps.is_available():
                device = 'mps'  # Apple Silicon GPU
            elif torch.cuda.is_available():
                device = 'cuda:0'
            else:
                device = 'cpu'  # Fallback to CPU
        torch.set_grad_enabled(False)
        cfg = None
        if network == "mobile0.25":
            cfg = cfg_mnet
        elif network == "resnet50":
            cfg = cfg_re50
        # NOTE(review): an unrecognized network name leaves cfg as None and
        # only fails later inside RetinaFace(); consider validating here.
        # net and model
        net = RetinaFace(cfg=cfg, phase='test')
        # Determine if we should load to CPU
        load_to_cpu = (device == 'cpu')
        # NOTE(review): weights path is hard-coded and relative to the CWD,
        # and always points at the mobilenet checkpoint even when
        # network="resnet50" — confirm.
        net = load_model(net, "weights/mobilenet0.25_Final.pth", load_to_cpu, device)
        net.eval()
        # print('Finished loading model!')
        # print(net)
        # # Only set CUDNN benchmark for CUDA devices
        # if device.startswith('cuda'):
        #     cudnn.benchmark = True
        self.device = device
        net = net.to(self.device)
        self.net = net
        torch.set_grad_enabled(False)
        # Timers for per-stage profiling (forward pass vs. post-processing).
        self._t = {'forward_pass': Timer(), 'misc': Timer()}
        self.cfg = cfg
        self.threshold = threshold

    def run(self, img, frame_debug=None):
        """Detect faces in a single image.

        Args:
            img: HxWx3 image array (assumes BGR channel order matching the
                (104, 117, 123) per-channel means subtracted below — TODO
                confirm against the training pipeline).
            frame_debug: optional image (np.ndarray) to draw detections on.

        Returns:
            (dets, frame_debug): ``dets`` is an (N, 15) float32 array per
            detection: [x1, y1, x2, y2, score, then 5 landmark (x, y)
            pairs], already filtered by ``self.threshold``. On any error
            returns (empty array, None).

        NOTE(review): when frame_debug is None, the returned "frame_debug"
        is the preprocessed input *tensor*, not the original image — likely
        surprising for callers; confirm.
        """
        img = np.float32(img)
        im_height, im_width, _ = img.shape
        # Scale factor to map normalized box coords back to pixel coords.
        scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        # Per-channel mean subtraction (presumably the BGR training means).
        img -= (104, 117, 123)
        # HWC -> CHW, then add a batch dimension.
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.to(self.device)
        scale = scale.to(self.device)
        try:
            self._t['forward_pass'].tic()
            loc, conf, landms = self.net(img)  # forward pass
            self._t['forward_pass'].toc()
            self._t['misc'].tic()
            # Generate anchor priors for this image size and decode the raw
            # network offsets into absolute boxes / landmarks.
            priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
            priors = priorbox.forward()
            priors = priors.to(self.device)
            prior_data = priors.data
            boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
            boxes = boxes * scale / 1
            boxes = boxes.cpu().numpy()
            # Column 1 is the "face" class score.
            scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
            landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
            # 5 landmark points -> 10 coords, each scaled back to pixels.
            scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                                   img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                                   img.shape[3], img.shape[2]])
            scale1 = scale1.to(self.device)
            landms = landms * scale1 / 1
            landms = landms.cpu().numpy()
            # Loose pre-filter before NMS; final filtering uses self.threshold.
            confidence_threshold = 0.02
            # ignore low scores
            inds = np.where(scores > confidence_threshold)[0]
            boxes = boxes[inds]
            landms = landms[inds]
            scores = scores[inds]
            # keep top-K before NMS
            # order = scores.argsort()[::-1][:args.top_k]
            order = scores.argsort()[::-1]
            boxes = boxes[order]
            landms = landms[order]
            scores = scores[order]
            nms_threshold = 0.4
            # do NMS
            dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
            keep = py_cpu_nms(dets, nms_threshold)
            dets = dets[keep, :]
            landms = landms[keep]
            # keep top-K faster NMS
            # dets = dets[:args.keep_top_k, :]
            # landms = landms[:args.keep_top_k, :]
            # Final layout: [x1, y1, x2, y2, score, 10 landmark coords].
            dets = np.concatenate((dets, landms), axis=1)
            self._t['misc'].toc()
            conf = dets[:, 4]
            filtered_idx = np.where(conf > self.threshold)
            dets = dets[filtered_idx[0]]
            if frame_debug is not None:
                frame_debug = self.debug(dets, frame_debug)
            else:
                frame_debug = img
            return dets, frame_debug
        except Exception as e:
            # NOTE(review): all errors (bad input shapes, device failures)
            # are silently swallowed and reported as "no faces"; consider
            # at least logging `e` before returning.
            # print(f"Error in RetinaFace inference: {e}")
            return np.array([]), None

    def debug(self, dets, image):
        """Draw detections onto *image* in place and return it.

        Each row of *dets* is [x1, y1, x2, y2, score, 10 landmark coords];
        boxes are drawn in red with the score text, and the 5 landmarks as
        colored dots.
        """
        for b in dets:
            text = "{:.4f}".format(b[4])
            b = list(map(int, b))
            cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
            cx = b[0]
            cy = b[1] + 12
            cv2.putText(image, text, (cx, cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
            # landms
            cv2.circle(image, (b[5], b[6]), 1, (0, 0, 255), 4)
            cv2.circle(image, (b[7], b[8]), 1, (0, 255, 255), 4)
            cv2.circle(image, (b[9], b[10]), 1, (255, 0, 255), 4)
            cv2.circle(image, (b[11], b[12]), 1, (0, 255, 0), 4)
            cv2.circle(image, (b[13], b[14]), 1, (255, 0, 0), 4)
        return image