diff --git a/openadapt/extensions/synchronized_queue.py b/openadapt/extensions/synchronized_queue.py
new file mode 100644
index 000000000..0c5e0da66
--- /dev/null
+++ b/openadapt/extensions/synchronized_queue.py
@@ -0,0 +1,99 @@
+"""
+    Module customizing multiprocessing.Queue to avoid NotImplementedError on macOS.
+"""
+
+
+from multiprocessing.queues import Queue
+import multiprocessing
+
+# Credit: https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9
+
+# The following implementation of a custom SynchronizedQueue, which avoids
+# NotImplementedError when calling queue.qsize() on macOS, comes almost entirely
+# from this GitHub discussion: https://github.com/keras-team/autokeras/issues/368
+# Necessary modifications were made to make the code compatible with Python 3.
+
+
+class SharedCounter(object):
+    """ A synchronized shared counter.
+    The locking done by multiprocessing.Value ensures that only a single
+    process or thread may read or write the in-memory ctypes object. However,
+    in order to do n += 1, Python performs a read followed by a write, so a
+    second process may read the old value before the new one is written by the
+    first process. The solution is to take the lock returned by
+    Value.get_lock() around each update, which guarantees that modifications
+    to Value are atomic.
+    This class comes almost entirely from Eli Bendersky's blog:
+    http://eli.thegreenplace.net/2012/01/04/
+    shared-counter-with-pythons-multiprocessing/
+    """
+
+    def __init__(self, n=0):
+        self.count = multiprocessing.Value('i', n)
+
+    def increment(self, n=1):
+        """ Increment the counter by n (default = 1). """
+        with self.count.get_lock():
+            self.count.value += n
+
+    @property
+    def value(self):
+        """ Return the value of the counter. """
+        return self.count.value
+
+
+class SynchronizedQueue(Queue):
+    """ A portable implementation of multiprocessing.Queue.
+    Because of multithreading / multiprocessing semantics, Queue.qsize() may
+    raise NotImplementedError on Unix platforms like macOS, where
+    sem_getvalue() is not implemented. This subclass addresses the problem by
+    using a synchronized shared counter (initialized to zero) and increasing /
+    decreasing its value every time put() and get() are called, respectively.
+    This not only prevents NotImplementedError from being raised, but also
+    allows us to implement a reliable version of both qsize() and empty().
+    Note the implementation of __getstate__ and __setstate__, which help
+    serialize SynchronizedQueue when it is passed between processes. If these
+    functions are not defined, SynchronizedQueue cannot be serialized, which
+    leads to "AttributeError: 'SynchronizedQueue' object has no attribute
+    'size'". See the answer provided here:
+    https://stackoverflow.com/a/65513291/9723036
+
+    For documentation on using __getstate__ and __setstate__
+    to serialize objects, see:
+    https://docs.python.org/3/library/pickle.html#pickling-class-instances
+    """
+
+    def __init__(self):
+        super().__init__(ctx=multiprocessing.get_context())
+        self.size = SharedCounter(0)
+
+    def __getstate__(self):
+        """Help to make SynchronizedQueue instances serializable.
+        Note that we record the parent class state, which is the state of the
+        actual queue, and the size of the queue, which is the state of
+        SynchronizedQueue. self.size is a SharedCounter instance, which is
+        itself serializable.
+ """ + return { + 'parent_state': super().__getstate__(), + 'size': self.size, + } + + def __setstate__(self, state): + super().__setstate__(state['parent_state']) + self.size = state['size'] + + def put(self, *args, **kwargs): + super().put(*args, **kwargs) + self.size.increment(1) + + def get(self, *args, **kwargs): + item = super().get(*args, **kwargs) + self.size.increment(-1) + return item + + def qsize(self): + """ Reliable implementation of multiprocessing.Queue.qsize() """ + return self.size.value + + def empty(self): + """ Reliable implementation of multiprocessing.Queue.empty() """ + return not self.qsize() diff --git a/openadapt/record.py b/openadapt/record.py index 3d92c4c0c..1668e205d 100644 --- a/openadapt/record.py +++ b/openadapt/record.py @@ -21,11 +21,13 @@ from loguru import logger from pympler import tracker from pynput import keyboard, mouse +from tqdm import tqdm import fire import mss.tools import psutil from openadapt import config, crud, utils, window +from openadapt.extensions import synchronized_queue as sq Event = namedtuple("Event", ("timestamp", "type", "data")) @@ -86,7 +88,9 @@ def wrapper_logging(*args, **kwargs): func_kwargs = kwargs_to_str(**kwargs) if func_kwargs != "": - logger.info(f" -> Enter: {func_name}({func_args}, {func_kwargs})") + logger.info( + f" -> Enter: {func_name}({func_args}, {func_kwargs})" + ) else: logger.info(f" -> Enter: {func_name}({func_args})") @@ -110,10 +114,10 @@ def process_event(event, write_q, write_fn, recording_timestamp, perf_q): @trace(logger) def process_events( event_q: queue.Queue, - screen_write_q: multiprocessing.Queue, - action_write_q: multiprocessing.Queue, - window_write_q: multiprocessing.Queue, - perf_q: multiprocessing.Queue, + screen_write_q: sq.SynchronizedQueue, + action_write_q: sq.SynchronizedQueue, + window_write_q: sq.SynchronizedQueue, + perf_q: sq.SynchronizedQueue, recording_timestamp: float, terminate_event: multiprocessing.Event, ): @@ -193,7 +197,7 @@ def process_events( def write_action_event( recording_timestamp: float, event: Event, - perf_q: multiprocessing.Queue, + perf_q: sq.SynchronizedQueue, ): """ Write an action event to the database and update the performance queue. @@ -212,7 +216,7 @@ def write_action_event( def write_screen_event( recording_timestamp: float, event: Event, - perf_q: multiprocessing.Queue, + perf_q: sq.SynchronizedQueue, ): """ Write a screen event to the database and update the performance queue. @@ -234,7 +238,7 @@ def write_screen_event( def write_window_event( recording_timestamp: float, event: Event, - perf_q: multiprocessing.Queue, + perf_q: sq.SynchronizedQueue, ): """ Write a window event to the database and update the performance queue. @@ -254,10 +258,11 @@ def write_window_event( def write_events( event_type: str, write_fn: Callable, - write_q: multiprocessing.Queue, - perf_q: multiprocessing.Queue, + write_q: sq.SynchronizedQueue, + perf_q: sq.SynchronizedQueue, recording_timestamp: float, terminate_event: multiprocessing.Event, + term_pipe: multiprocessing.Pipe, ): """ Write events of a specific type to the db using the provided write function. @@ -269,13 +274,37 @@ def write_events( perf_q: A queue for collecting performance data. recording_timestamp: The timestamp of the recording. terminate_event: An event to signal the termination of the process. + term_pipe: A pipe for communicating \ + the number of events left to be written. 
""" utils.configure_logging(logger, LOG_LEVEL) utils.set_start_time(recording_timestamp) logger.info(f"{event_type=} starting") signal.signal(signal.SIGINT, signal.SIG_IGN) - while not terminate_event.is_set() or not write_q.empty(): + + num_left = 0 + progress = None + while ( + not terminate_event.is_set() or + not write_q.empty() + ): + if term_pipe.poll(): + num_left = term_pipe.recv() + if num_left != 0 and progress is None: + progress = tqdm( + total=num_left, + desc="Writing to Database", + unit="event", + colour="green", + dynamic_ncols=True, + ) + if ( + terminate_event.is_set() and + num_left != 0 and + progress is not None + ): + progress.update() try: event = write_q.get_nowait() except queue.Empty: @@ -283,6 +312,10 @@ def write_events( assert event.type == event_type, (event_type, event) write_fn(recording_timestamp, event, perf_q) logger.debug(f"{event_type=} written") + + if progress is not None: + progress.close() + logger.info(f"{event_type=} done") @@ -375,7 +408,8 @@ def handle_key( "vk", ] attrs = { - f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names + f"key_{attr_name}": getattr(key, attr_name, None) + for attr_name in attr_names } logger.debug(f"{attrs=}") canonical_attrs = { @@ -383,7 +417,9 @@ def handle_key( for attr_name in attr_names } logger.debug(f"{canonical_attrs=}") - trigger_action_event(event_q, {"name": event_name, **attrs, **canonical_attrs}) + trigger_action_event( + event_q, {"name": event_name, **attrs, **canonical_attrs} + ) def read_screen_events( @@ -463,7 +499,7 @@ def read_window_events( @trace(logger) def performance_stats_writer( - perf_q: multiprocessing.Queue, + perf_q: sq.SynchronizedQueue, recording_timestamp: float, terminate_event: multiprocessing.Event, ): @@ -660,13 +696,17 @@ def record( recording_timestamp = recording.timestamp event_q = queue.Queue() - screen_write_q = multiprocessing.Queue() - action_write_q = multiprocessing.Queue() - window_write_q = multiprocessing.Queue() + screen_write_q = sq.SynchronizedQueue() + action_write_q = sq.SynchronizedQueue() + window_write_q = sq.SynchronizedQueue() # TODO: save write times to DB; display performance plot in visualize.py - perf_q = multiprocessing.Queue() + perf_q = sq.SynchronizedQueue() terminate_event = multiprocessing.Event() - + + term_pipe_parent_window, term_pipe_child_window = multiprocessing.Pipe() + term_pipe_parent_screen, term_pipe_child_screen = multiprocessing.Pipe() + term_pipe_parent_action, term_pipe_child_action = multiprocessing.Pipe() + window_event_reader = threading.Thread( target=read_window_events, args=(event_q, terminate_event, recording_timestamp), @@ -714,6 +754,7 @@ def record( perf_q, recording_timestamp, terminate_event, + term_pipe_child_screen, ), ) screen_event_writer.start() @@ -727,6 +768,7 @@ def record( perf_q, recording_timestamp, terminate_event, + term_pipe_child_action, ), ) action_event_writer.start() @@ -740,6 +782,7 @@ def record( perf_q, recording_timestamp, terminate_event, + term_pipe_child_window, ), ) window_event_writer.start() @@ -776,9 +819,14 @@ def record( except KeyboardInterrupt: terminate_event.set() + collect_stats() log_memory_usage() + term_pipe_parent_window.send(window_write_q.qsize()) + term_pipe_parent_action.send(action_write_q.qsize()) + term_pipe_parent_screen.send(screen_write_q.qsize()) + logger.info(f"joining...") keyboard_event_reader.join() mouse_event_reader.join() @@ -788,7 +836,6 @@ def record( screen_event_writer.join() action_event_writer.join() window_event_writer.join() - 
     terminate_perf_event.set()

     if PLOT_PERFORMANCE:
diff --git a/openadapt/scripts/scrub.py b/openadapt/scripts/scrub.py
index 486f28725..b4058f28c 100644
--- a/openadapt/scripts/scrub.py
+++ b/openadapt/scripts/scrub.py
@@ -113,6 +113,7 @@ def scrub_mp4(
         unit="frame",
         bar_format=progress_bar_format,
         colour="green",
+        dynamic_ncols=True,
     )
     progress_interval = 0.1  # Print progress every 10% of frames
     progress_threshold = math.floor(frame_count * progress_interval)
diff --git a/openadapt/visualize.py b/openadapt/visualize.py
index 1219f7f5d..9a2e96f3d 100644
--- a/openadapt/visualize.py
+++ b/openadapt/visualize.py
@@ -8,6 +8,7 @@
 from bokeh.layouts import layout, row
 from bokeh.models.widgets import Div
 from loguru import logger
+from tqdm import tqdm

 from openadapt.crud import (
     get_latest_recording,
@@ -188,70 +189,87 @@ def main():
         ),
     ]
     logger.info(f"{len(action_events)=}")
-    for idx, action_event in enumerate(action_events):
-        if idx == MAX_EVENTS:
-            break
-        image = display_event(action_event)
-        diff = display_event(action_event, diff=True)
-        mask = action_event.screenshot.diff_mask
-
-        if SCRUB:
-            image = scrub.scrub_image(image)
-            diff = scrub.scrub_image(diff)
-            mask = scrub.scrub_image(mask)
-
-        image_utf8 = image2utf8(image)
-        diff_utf8 = image2utf8(diff)
-        mask_utf8 = image2utf8(mask)
-        width, height = image.size
-
-        action_event_dict = row2dict(action_event)
-        window_event_dict = row2dict(action_event.window_event)
-
-        if SCRUB:
-            action_event_dict = scrub.scrub_dict(action_event_dict)
-            window_event_dict = scrub.scrub_dict(window_event_dict)
-
-        rows.append(
-            [
-                row(
-                    Div(
-                        text=f"""
-                            <div class="screenshot">
-                                <img
-                                    src="{image_utf8}"
-                                    style="
-                                        width: {width}px;
-                                        height: {height}px;
-                                    "
-                                    data-src="{image_utf8}"
-                                    data-diff-src="{diff_utf8}"
-                                    data-mask-src="{mask_utf8}"
-                                >
-                            </div>
-                            <table>
-                                {dict2html(window_event_dict , None)}
-                            </table>
-                        """,
-                    ),
-                    Div(
-                        text=f"""
-                            <table>
-                                {dict2html(action_event_dict)}
-                            </table>
-                        """
-                    ),
-                ),
-            ]
-        )
+
+    num_events = (
+        min(MAX_EVENTS, len(action_events))
+        if MAX_EVENTS is not None
+        else len(action_events)
+    )
+    with tqdm(
+        total=num_events,
+        desc="Preparing HTML",
+        unit="event",
+        colour="green",
+        dynamic_ncols=True,
+    ) as progress:
+        for idx, action_event in enumerate(action_events):
+            if idx == MAX_EVENTS:
+                break
+            image = display_event(action_event)
+            diff = display_event(action_event, diff=True)
+            mask = action_event.screenshot.diff_mask
+
+            if SCRUB:
+                image = scrub.scrub_image(image)
+                diff = scrub.scrub_image(diff)
+                mask = scrub.scrub_image(mask)
+
+            image_utf8 = image2utf8(image)
+            diff_utf8 = image2utf8(diff)
+            mask_utf8 = image2utf8(mask)
+            width, height = image.size
+
+            action_event_dict = row2dict(action_event)
+            window_event_dict = row2dict(action_event.window_event)
+
+            if SCRUB:
+                action_event_dict = scrub.scrub_dict(action_event_dict)
+                window_event_dict = scrub.scrub_dict(window_event_dict)
+
+            rows.append(
+                [
+                    row(
+                        Div(
+                            text=f"""
+                                <div class="screenshot">
+                                    <img
+                                        src="{image_utf8}"
+                                        style="
+                                            width: {width}px;
+                                            height: {height}px;
+                                        "
+                                        data-src="{image_utf8}"
+                                        data-diff-src="{diff_utf8}"
+                                        data-mask-src="{mask_utf8}"
+                                    >
+                                </div>
+                                <table>
+                                    {dict2html(window_event_dict , None)}
+                                </table>
+                            """,
+                        ),
+                        Div(
+                            text=f"""
+                                <table>
+                                    {dict2html(action_event_dict)}
+                                </table>
+ """ + ), ), - ), - ] - ) + ] + ) + + progress.update() + + progress.close() title = f"recording-{recording.id}" fname_out = f"recording-{recording.id}.html"
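
For reference, a minimal sketch of the locking behavior that SharedCounter's docstring describes. The worker function and iteration counts below are illustrative, not part of this change; the sketch only assumes the SharedCounter class added in openadapt/extensions/synchronized_queue.py above.

```python
# Illustrative only: without Value.get_lock(), the read-then-write of += can
# lose updates when several processes increment concurrently.
import multiprocessing

from openadapt.extensions.synchronized_queue import SharedCounter


def bump(counter: SharedCounter, n: int) -> None:
    for _ in range(n):
        counter.increment()  # read-modify-write guarded by Value.get_lock()


if __name__ == "__main__":
    counter = SharedCounter(0)
    procs = [
        multiprocessing.Process(target=bump, args=(counter, 10_000))
        for _ in range(4)
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
    print(counter.value)  # always 40000; an unlocked += could print less
```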
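Similarly, a small usage sketch of SynchronizedQueue itself. The drain() helper and item values are made up for illustration; the point is that qsize() and empty() keep working on macOS, and that the queue survives pickling into a child process via __getstate__/__setstate__.

```python
# Illustrative only: qsize()/empty() stay usable where sem_getvalue() is missing.
import multiprocessing

from openadapt.extensions.synchronized_queue import SynchronizedQueue


def drain(q: SynchronizedQueue) -> None:
    while not q.empty():  # empty() is backed by the shared counter
        print(q.get())    # get() decrements the counter


if __name__ == "__main__":
    q = SynchronizedQueue()
    for i in range(3):
        q.put(i)          # put() increments the counter
    print(q.qsize())      # 3; a plain multiprocessing.Queue raises here on macOS
    proc = multiprocessing.Process(target=drain, args=(q,))  # pickled via __getstate__
    proc.start()
    proc.join()
```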
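Finally, a condensed sketch of the shutdown handshake that write_events() in record.py now performs: the parent snapshots qsize() when recording stops and sends it down a Pipe, and the writer process uses it as the tqdm total for its "Writing to Database" bar. The writer() function, event count, and sleep are simplified stand-ins; the real code writes each event to the database where this sketch merely drains the queue, and additionally guards against a zero count.

```python
# Illustrative only: condensed version of the write_events() progress handshake.
import multiprocessing
import queue
import time

from tqdm import tqdm

from openadapt.extensions.synchronized_queue import SynchronizedQueue


def writer(write_q, term_pipe, terminate_event):
    progress = None
    while not terminate_event.is_set() or not write_q.empty():
        if term_pipe.poll():
            # The parent sends the backlog size once, at shutdown.
            progress = tqdm(
                total=term_pipe.recv(), desc="Writing to Database", unit="event"
            )
        try:
            write_q.get_nowait()  # the real code writes the event to the db here
        except queue.Empty:
            continue
        if progress is not None:
            progress.update()
    if progress is not None:
        progress.close()


if __name__ == "__main__":
    write_q = SynchronizedQueue()
    terminate_event = multiprocessing.Event()
    parent_conn, child_conn = multiprocessing.Pipe()
    for i in range(1000):
        write_q.put(i)
    proc = multiprocessing.Process(
        target=writer, args=(write_q, child_conn, terminate_event)
    )
    proc.start()
    time.sleep(0.5)
    terminate_event.set()
    parent_conn.send(write_q.qsize())  # reliable thanks to SynchronizedQueue
    proc.join()
```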