-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrealtime_test.py
More file actions
106 lines (80 loc) · 3.75 KB
/
realtime_test.py
File metadata and controls
106 lines (80 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import cv2
import imageio.v3 as iio
import numpy as np
import torch.nn as nn
import torch
from fire import Fire
from tqdm import tqdm
from hydra.utils import instantiate
from typing import List, Optional, Union
from pytorch_toolbelt.utils import transfer_weights
from SiamABC_tracker import SiamABCTracker
from core.utils.hydra import load_hydra_config_from_path
from core.models.custom_bn import replace_layers
def load_model(
    model: nn.Module, checkpoint_path: str, map_location: Optional[Union[int, str]] = None, strict: bool = True
) -> nn.Module:
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    The checkpoint is read as a flat mapping of parameter names to values.
    Keys that start with ``"module."`` have that prefix removed and all other
    keys are dropped (presumably the prefix comes from a wrapped model such as
    ``nn.DataParallel`` -- confirm against the training code). If *no* key
    carries the prefix, the checkpoint is used unchanged instead of being
    filtered down to an empty state dict.

    Args:
        model: Module that receives the weights (modified in place).
        checkpoint_path: Path of the file passed to ``torch.load``.
        map_location: Device to load tensors onto. An ``int`` is turned into
            ``"cuda:<int>"``; any other value is forwarded as given.
        strict: When true, use ``model.load_state_dict(..., strict=True)``;
            otherwise hand the state dict to ``transfer_weights``.

    Returns:
        The same ``model`` object.
    """
    prefix = "module."
    if isinstance(map_location, int):
        map_location = f"cuda:{map_location}"

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Slice the prefix off explicitly: str.lstrip() strips a *character set*,
    # not a prefix, which only worked because "." is not in "module".
    state_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
    if not state_dict:
        # No prefixed keys at all: treat the file as a plain state dict rather
        # than loading nothing.
        state_dict = dict(checkpoint)

    if strict:
        model.load_state_dict(state_dict, strict=True)
    else:
        transfer_weights(model, state_dict)
    return model
def get_tracker(config, weights_path: str, lambda_tta: float = 0.1) -> SiamABCTracker:
    """Build the model described by ``config``, load its weights and wrap it in a tracker.

    Args:
        config: Mapping with a ``"model"`` and a ``"tracker"`` entry, each
            passed to ``hydra.utils.instantiate``.
        weights_path: Checkpoint file handed to ``load_model``.
        lambda_tta: Value forwarded as the second argument of
            ``replace_layers`` for each patched sub-module.

    Returns:
        The tracker instantiated from ``config["tracker"]`` around the loaded
        model, which has been moved to CUDA and put in eval mode.
    """
    model = instantiate(config["model"])

    # Patch the same four sub-modules of connect_model, in the original order.
    for name in ("cls_dw", "reg_dw", "bbox_tower", "cls_tower"):
        replace_layers(getattr(model.connect_model, name), lambda_tta, False)
    print(model)

    # Non-strict load: weights are copied over via transfer_weights.
    model = load_model(model, weights_path, strict=False).cuda().eval()
    tracker: SiamABCTracker = instantiate(config["tracker"], model=model)
    return tracker
def track(tracker: SiamABCTracker, frames: List[np.ndarray], initial_bbox: np.ndarray) -> List[np.ndarray]:
    """Run ``tracker`` over ``frames`` and collect one bounding box per frame.

    Args:
        tracker: Tracker exposing ``initialize(frame, bbox)`` and
            ``update(frame)``; ``update`` must return a 2-tuple whose first
            item is the box for that frame.
        frames: Sequence of frames; the first one initialises the tracker.
        initial_bbox: Box for the first frame.

    Returns:
        A list starting with ``initial_bbox`` followed by the box returned for
        every subsequent frame.

    Raises:
        IndexError: If ``frames`` is empty.
    """
    tracker.initialize(frames[0], initial_bbox)
    tracked_bboxes = [initial_bbox]
    # Iterate the slice directly (not enumerate) so tqdm can read its length
    # and display a total / ETA.
    for frame in tqdm(frames[1:]):
        # The second return value (a score) is not used here.
        tracked_bbox, _ = tracker.update(frame)
        tracked_bboxes.append(tracked_bbox)
    return tracked_bboxes
def draw_bbox(image: np.ndarray, bbox: np.ndarray, width: int = 5) -> np.ndarray:
    """Return a copy of ``image`` with ``bbox`` outlined in green.

    Args:
        image: Image to draw on; it is copied, the input is left untouched.
        bbox: Four values unpacked as ``x, y, w, h``; the rectangle spans
            ``(x, y)`` to ``(x + w, y + h)``.
        width: Line thickness passed to ``cv2.rectangle``.

    Returns:
        The annotated copy.
    """
    # cv2.rectangle wants integer pixel coordinates; NumPy scalars or floats
    # may be rejected by some OpenCV builds, so convert to plain ints
    # (rounded to the nearest pixel).
    x, y, w, h = (int(round(float(v))) for v in bbox)
    canvas = image.copy()
    return cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 255, 0), width)
def visualize(frames: List[np.ndarray], tracked_bboxes: List[np.ndarray]):
    """Draw each tracked box onto its frame and return the annotated frames.

    Frames and boxes are paired positionally with ``zip``, so the result is as
    long as the shorter of the two inputs.
    """
    return [draw_bbox(frame, bbox) for frame, bbox in zip(frames, tracked_bboxes)]
import os
def main(
    initial_bbox: List[int] = [416, 414, 61, 97],
    video_path: str = "assets/penguin_in_fog.mp4",
    output_path: str = "outputs/penguin_in_fog.mp4",
    config_path: str = "core/config",
    config_name: str = "SiamABC_tracker",
    model_size: str = "S_Tiny",
    weights_path: str = "assets/S_Tiny/model_S_Tiny_v1.pt",
):
    """Track an object through a video, then save the annotated video and the boxes.

    Args:
        initial_bbox: Box of the target in the first frame (four integers).
            The list default is never mutated; it is copied into an array.
        video_path: Input video read with ``imageio``.
        output_path: Destination of the annotated video.
        config_path: Directory passed to ``load_hydra_config_from_path``.
        config_name: Config name passed to ``load_hydra_config_from_path``.
        model_size: ``"S_Tiny"`` selects model size ``'S'``; any other value
            selects ``'M'``.
        weights_path: Checkpoint loaded into the model.

    Side effects:
        Writes the annotated video to ``output_path`` and one line per frame
        to ``<output dir>/bboxes/<output stem>.txt``.
    """
    config = load_hydra_config_from_path(config_path=config_path, config_name=config_name)
    config["model"]["model_size"] = 'S' if model_size == "S_Tiny" else 'M'
    tracker = get_tracker(config=config, weights_path=weights_path)

    video = iio.imread(video_path)
    metadata = iio.immeta(video_path, exclude_applied=False)
    initial_bbox = np.array(initial_bbox).astype(int)
    tracked_bboxes = track(tracker, video, initial_bbox)
    visualized_video = visualize(video, tracked_bboxes)

    out_dir, out_name = os.path.split(output_path)
    # A bare file name has an empty directory part, and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    iio.imwrite(output_path, visualized_video, fps=metadata["fps"])

    bbox_dir = os.path.join(out_dir, 'bboxes')
    os.makedirs(bbox_dir, exist_ok=True)
    bbox_path = os.path.join(bbox_dir, os.path.splitext(out_name)[0] + '.txt')
    with open(bbox_path, 'w', encoding='utf-8') as f:
        for bbox in tracked_bboxes:
            f.write(f'{bbox[0]} {bbox[1]} {bbox[2]} {bbox[3]} \n')
# Script entry point: hand ``main`` to Fire when this file is run directly.
if __name__ == '__main__':
    Fire(main)