Skip to content

Commit 3e12fd4

Browse files
KrishPatel130dm
andauthored
feat: add progress bar in record.py and visualize.py
* run `poetry lock --no-update` * add alive-progress via poetry and in code * add progress bar in visualization * add a check for MAX_EVENT = None * update the title for the Progress Bar (better for User pov) * update the requirements.txt * ran ` black --line-length 80 <file>` on record.py and visualize.py * remove all progress bar from record * add tqdm progress bar in record.py * add tqdm for visualization * remove alive-progress * consistent tqdm api --add dynamic_cols: to enable adjustments when window is resized Order: --total -description --unit --Optional[bar_format] --colour --dynamic_ncols * Update requirements.txt Co-authored-by: Aaron <[email protected]> * Address comment: #318 (comment) * remove incorrect indent * remove rows * try to fix distorted table in html * add custom queue class * lint --line-length 80 * fix `NotImplementedError` for macOS -- using custom MyQueue class * rename custom -> thirdparty_customization * rename to something useful * address comments * rename dir to customized_imports * rename to extensions #318 (comment) --------- Co-authored-by: Aaron <[email protected]>
1 parent d15f683 commit 3e12fd4

File tree

4 files changed

+248
-83
lines changed

4 files changed

+248
-83
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
Module for customizing multiprocessing.Queue to avoid NotImplementedError on macOS
3+
"""
4+
5+
6+
from multiprocessing.queues import Queue
7+
import multiprocessing
8+
9+
# Credit: https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9
10+
11+
# The following implementation of custom SynchronizedQueue to avoid NotImplementedError
12+
# when calling queue.qsize() in MacOS X comes almost entirely from this github
13+
# discussion: https://github.com/keras-team/autokeras/issues/368
14+
# Necessary modification is made to make the code compatible with Python3.
15+
16+
17+
class SharedCounter(object):
18+
""" A synchronized shared counter.
19+
The locking done by multiprocessing.Value ensures that only a single
20+
process or thread may read or write the in-memory ctypes object. However,
21+
in order to do n += 1, Python performs a read followed by a write, so a
22+
second process may read the old value before the new one is written by the
23+
first process. The solution is to use a multiprocessing.Lock to guarantee
24+
the atomicity of the modifications to Value.
25+
This class comes almost entirely from Eli Bendersky's blog:
26+
http://eli.thegreenplace.net/2012/01/04/
27+
shared-counter-with-pythons-multiprocessing/
28+
"""
29+
30+
def __init__(self, n=0):
31+
self.count = multiprocessing.Value('i', n)
32+
33+
def increment(self, n=1):
34+
""" Increment the counter by n (default = 1) """
35+
with self.count.get_lock():
36+
self.count.value += n
37+
38+
@property
39+
def value(self):
40+
""" Return the value of the counter """
41+
return self.count.value
42+
43+
44+
class SynchronizedQueue(Queue):
45+
""" A portable implementation of multiprocessing.Queue.
46+
Because of multithreading / multiprocessing semantics, Queue.qsize() may
47+
raise the NotImplementedError exception on Unix platforms like Mac OS X
48+
where sem_getvalue() is not implemented. This subclass addresses this
49+
problem by using a synchronized shared counter (initialized to zero) and
50+
increasing / decreasing its value every time the put() and get() methods
51+
are called, respectively. This not only prevents NotImplementedError from
52+
being raised, but also allows us to implement a reliable version of both
53+
qsize() and empty().
54+
Note the implementation of __getstate__ and __setstate__ which help to
55+
serialize SynchronizedQueue when it is passed between processes. If these functions
56+
are not defined, SynchronizedQueue cannot be serialized, which will lead to the error
57+
of "AttributeError: 'SynchronizedQueue' object has no attribute 'size'".
58+
See the answer provided here: https://stackoverflow.com/a/65513291/9723036
59+
60+
For documentation of using __getstate__ and __setstate__
61+
to serialize objects, refer to here:
62+
https://docs.python.org/3/library/pickle.html#pickling-class-instances
63+
"""
64+
65+
def __init__(self):
66+
super().__init__(ctx=multiprocessing.get_context())
67+
self.size = SharedCounter(0)
68+
69+
def __getstate__(self):
70+
"""Help to make SynchronizedQueue instance serializable.
71+
Note that we record the parent class state, which is the state of the
72+
actual queue, and the size of the queue, which is the state of SynchronizedQueue.
73+
self.size is a SharedCounter instance. It is itself serializable.
74+
"""
75+
return {
76+
'parent_state': super().__getstate__(),
77+
'size': self.size,
78+
}
79+
80+
def __setstate__(self, state):
81+
super().__setstate__(state['parent_state'])
82+
self.size = state['size']
83+
84+
def put(self, *args, **kwargs):
85+
super().put(*args, **kwargs)
86+
self.size.increment(1)
87+
88+
def get(self, *args, **kwargs):
89+
item = super().get(*args, **kwargs)
90+
self.size.increment(-1)
91+
return item
92+
93+
def qsize(self):
94+
""" Reliable implementation of multiprocessing.Queue.qsize() """
95+
return self.size.value
96+
97+
def empty(self):
98+
""" Reliable implementation of multiprocessing.Queue.empty() """
99+
return not self.qsize()

openadapt/record.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from loguru import logger
2222
from pympler import tracker
2323
from pynput import keyboard, mouse
24+
from tqdm import tqdm
2425
import fire
2526
import mss.tools
2627
import psutil
2728

2829
from openadapt import config, crud, utils, window
30+
from openadapt.extensions import synchronized_queue as sq
2931

3032
Event = namedtuple("Event", ("timestamp", "type", "data"))
3133

@@ -86,7 +88,9 @@ def wrapper_logging(*args, **kwargs):
8688
func_kwargs = kwargs_to_str(**kwargs)
8789

8890
if func_kwargs != "":
89-
logger.info(f" -> Enter: {func_name}({func_args}, {func_kwargs})")
91+
logger.info(
92+
f" -> Enter: {func_name}({func_args}, {func_kwargs})"
93+
)
9094
else:
9195
logger.info(f" -> Enter: {func_name}({func_args})")
9296

@@ -110,10 +114,10 @@ def process_event(event, write_q, write_fn, recording_timestamp, perf_q):
110114
@trace(logger)
111115
def process_events(
112116
event_q: queue.Queue,
113-
screen_write_q: multiprocessing.Queue,
114-
action_write_q: multiprocessing.Queue,
115-
window_write_q: multiprocessing.Queue,
116-
perf_q: multiprocessing.Queue,
117+
screen_write_q: sq.SynchronizedQueue,
118+
action_write_q: sq.SynchronizedQueue,
119+
window_write_q: sq.SynchronizedQueue,
120+
perf_q: sq.SynchronizedQueue,
117121
recording_timestamp: float,
118122
terminate_event: multiprocessing.Event,
119123
):
@@ -193,7 +197,7 @@ def process_events(
193197
def write_action_event(
194198
recording_timestamp: float,
195199
event: Event,
196-
perf_q: multiprocessing.Queue,
200+
perf_q: sq.SynchronizedQueue,
197201
):
198202
"""
199203
Write an action event to the database and update the performance queue.
@@ -212,7 +216,7 @@ def write_action_event(
212216
def write_screen_event(
213217
recording_timestamp: float,
214218
event: Event,
215-
perf_q: multiprocessing.Queue,
219+
perf_q: sq.SynchronizedQueue,
216220
):
217221
"""
218222
Write a screen event to the database and update the performance queue.
@@ -234,7 +238,7 @@ def write_screen_event(
234238
def write_window_event(
235239
recording_timestamp: float,
236240
event: Event,
237-
perf_q: multiprocessing.Queue,
241+
perf_q: sq.SynchronizedQueue,
238242
):
239243
"""
240244
Write a window event to the database and update the performance queue.
@@ -254,10 +258,11 @@ def write_window_event(
254258
def write_events(
255259
event_type: str,
256260
write_fn: Callable,
257-
write_q: multiprocessing.Queue,
258-
perf_q: multiprocessing.Queue,
261+
write_q: sq.SynchronizedQueue,
262+
perf_q: sq.SynchronizedQueue,
259263
recording_timestamp: float,
260264
terminate_event: multiprocessing.Event,
265+
term_pipe: multiprocessing.Pipe,
261266
):
262267
"""
263268
Write events of a specific type to the db using the provided write function.
@@ -269,20 +274,48 @@ def write_events(
269274
perf_q: A queue for collecting performance data.
270275
recording_timestamp: The timestamp of the recording.
271276
terminate_event: An event to signal the termination of the process.
277+
term_pipe: A pipe for communicating \
278+
the number of events left to be written.
272279
"""
273280

274281
utils.configure_logging(logger, LOG_LEVEL)
275282
utils.set_start_time(recording_timestamp)
276283
logger.info(f"{event_type=} starting")
277284
signal.signal(signal.SIGINT, signal.SIG_IGN)
278-
while not terminate_event.is_set() or not write_q.empty():
285+
286+
num_left = 0
287+
progress = None
288+
while (
289+
not terminate_event.is_set() or
290+
not write_q.empty()
291+
):
292+
if term_pipe.poll():
293+
num_left = term_pipe.recv()
294+
if num_left != 0 and progress is None:
295+
progress = tqdm(
296+
total=num_left,
297+
desc="Writing to Database",
298+
unit="event",
299+
colour="green",
300+
dynamic_ncols=True,
301+
)
302+
if (
303+
terminate_event.is_set() and
304+
num_left != 0 and
305+
progress is not None
306+
):
307+
progress.update()
279308
try:
280309
event = write_q.get_nowait()
281310
except queue.Empty:
282311
continue
283312
assert event.type == event_type, (event_type, event)
284313
write_fn(recording_timestamp, event, perf_q)
285314
logger.debug(f"{event_type=} written")
315+
316+
if progress is not None:
317+
progress.close()
318+
286319
logger.info(f"{event_type=} done")
287320

288321

@@ -375,15 +408,18 @@ def handle_key(
375408
"vk",
376409
]
377410
attrs = {
378-
f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names
411+
f"key_{attr_name}": getattr(key, attr_name, None)
412+
for attr_name in attr_names
379413
}
380414
logger.debug(f"{attrs=}")
381415
canonical_attrs = {
382416
f"canonical_key_{attr_name}": getattr(canonical_key, attr_name, None)
383417
for attr_name in attr_names
384418
}
385419
logger.debug(f"{canonical_attrs=}")
386-
trigger_action_event(event_q, {"name": event_name, **attrs, **canonical_attrs})
420+
trigger_action_event(
421+
event_q, {"name": event_name, **attrs, **canonical_attrs}
422+
)
387423

388424

389425
def read_screen_events(
@@ -463,7 +499,7 @@ def read_window_events(
463499

464500
@trace(logger)
465501
def performance_stats_writer(
466-
perf_q: multiprocessing.Queue,
502+
perf_q: sq.SynchronizedQueue,
467503
recording_timestamp: float,
468504
terminate_event: multiprocessing.Event,
469505
):
@@ -660,13 +696,17 @@ def record(
660696
recording_timestamp = recording.timestamp
661697

662698
event_q = queue.Queue()
663-
screen_write_q = multiprocessing.Queue()
664-
action_write_q = multiprocessing.Queue()
665-
window_write_q = multiprocessing.Queue()
699+
screen_write_q = sq.SynchronizedQueue()
700+
action_write_q = sq.SynchronizedQueue()
701+
window_write_q = sq.SynchronizedQueue()
666702
# TODO: save write times to DB; display performance plot in visualize.py
667-
perf_q = multiprocessing.Queue()
703+
perf_q = sq.SynchronizedQueue()
668704
terminate_event = multiprocessing.Event()
669-
705+
706+
term_pipe_parent_window, term_pipe_child_window = multiprocessing.Pipe()
707+
term_pipe_parent_screen, term_pipe_child_screen = multiprocessing.Pipe()
708+
term_pipe_parent_action, term_pipe_child_action = multiprocessing.Pipe()
709+
670710
window_event_reader = threading.Thread(
671711
target=read_window_events,
672712
args=(event_q, terminate_event, recording_timestamp),
@@ -714,6 +754,7 @@ def record(
714754
perf_q,
715755
recording_timestamp,
716756
terminate_event,
757+
term_pipe_child_screen,
717758
),
718759
)
719760
screen_event_writer.start()
@@ -727,6 +768,7 @@ def record(
727768
perf_q,
728769
recording_timestamp,
729770
terminate_event,
771+
term_pipe_child_action,
730772
),
731773
)
732774
action_event_writer.start()
@@ -740,6 +782,7 @@ def record(
740782
perf_q,
741783
recording_timestamp,
742784
terminate_event,
785+
term_pipe_child_window,
743786
),
744787
)
745788
window_event_writer.start()
@@ -776,9 +819,14 @@ def record(
776819
except KeyboardInterrupt:
777820
terminate_event.set()
778821

822+
779823
collect_stats()
780824
log_memory_usage()
781825

826+
term_pipe_parent_window.send(window_write_q.qsize())
827+
term_pipe_parent_action.send(action_write_q.qsize())
828+
term_pipe_parent_screen.send(screen_write_q.qsize())
829+
782830
logger.info(f"joining...")
783831
keyboard_event_reader.join()
784832
mouse_event_reader.join()
@@ -788,7 +836,6 @@ def record(
788836
screen_event_writer.join()
789837
action_event_writer.join()
790838
window_event_writer.join()
791-
792839
terminate_perf_event.set()
793840

794841
if PLOT_PERFORMANCE:

openadapt/scripts/scrub.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def scrub_mp4(
113113
unit="frame",
114114
bar_format=progress_bar_format,
115115
colour="green",
116+
dynamic_ncols=True,
116117
)
117118
progress_interval = 0.1 # Print progress every 10% of frames
118119
progress_threshold = math.floor(frame_count * progress_interval)

0 commit comments

Comments
 (0)