Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(openadapt): add progress bar in record.py and visualize.py #318

Merged
merged 29 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
29ec8ec
run `poetry lock --no-update`
Jun 26, 2023
2f29934
add alive-progress via poetry and in code
Jun 26, 2023
1ff5e9d
add progress bar in visualization
Jun 26, 2023
790a17e
add a check for MAX_EVENT = None
Jun 26, 2023
37f22c0
update the title for the Progress bAr
Jun 26, 2023
afbaa2c
update the requirement.txt
Jun 26, 2023
7227c5c
ran ` black --line-length 80 <file>`
Jun 26, 2023
0c5f4e9
remove all progress bar from record
Jun 26, 2023
d2adac1
Merge branch 'MLDSAI:main' into feature/record_pb
KrishPatel13 Jun 26, 2023
9e8ee4d
add tqdm progress bar in record.py
Jun 27, 2023
3de913f
add tqdm for visualization
Jun 27, 2023
e5b1601
remove alive-progress
Jun 27, 2023
f7c42d1
consistent tqdm api
Jun 27, 2023
03dc223
Update requirements.txt
KrishPatel13 Jun 28, 2023
2c26a45
Address comment:
Jun 28, 2023
d41d088
remove incorrect indent
Jun 28, 2023
2234eb9
remove rows
Jun 28, 2023
8f6d02b
try to fix distorted table in html
Jun 28, 2023
b2df437
Merge branch 'MLDSAI:main' into feature/record_pb
KrishPatel13 Jun 28, 2023
9c063ef
add custom queue class
Jun 29, 2023
857382a
lint --line-length 80
Jun 29, 2023
c6fbf3a
fix `NotImplementedError` for MacOs
Jun 29, 2023
416ce80
rename custom -> thirdparty_customization
Jun 29, 2023
00fdcc3
rename to something useful
Jun 29, 2023
a612535
address comments
Jun 29, 2023
5204148
rename dir to customized_imports
Jun 29, 2023
e2419cc
rename to extensions
Jun 29, 2023
cb9a3d8
Merge branch 'OpenAdaptAI:main' into feature/record_pb
KrishPatel13 Jun 29, 2023
aa61233
Merge branch 'main' into feature/record_pb
KrishPatel13 Jul 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions openadapt/customized_imports/synchronized_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Module for customizing multiprocessing.Queue to avoid NotImplementedError in MacOS
"""


from multiprocessing.queues import Queue
import multiprocessing

# Credit: https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9

# The following implementation of custom SynchronizedQueue to avoid NotImplementedError
# when calling queue.qsize() in MacOS X comes almost entirely from this github
# discussion: https://github.com/keras-team/autokeras/issues/368
# Necessary modification is made to make the code compatible with Python3.


class SharedCounter(object):
    """A process-safe integer counter backed by ``multiprocessing.Value``.

    ``multiprocessing.Value`` serializes each individual read and write, but a
    compound update such as ``n += 1`` is a read followed by a write, so a
    second process can observe the stale value between the two steps and an
    update can be lost.  Holding the value's lock around the whole
    read-modify-write makes the increment atomic.

    Adapted almost entirely from Eli Bendersky's blog:
    http://eli.thegreenplace.net/2012/01/04/
    shared-counter-with-pythons-multiprocessing/
    """

    def __init__(self, n=0):
        # 'i' -> a shared signed C int, initialized to n.
        self.count = multiprocessing.Value('i', n)

    def increment(self, n=1):
        """Atomically add ``n`` (default 1; may be negative) to the counter."""
        with self.count.get_lock():
            self.count.value += n

    @property
    def value(self):
        """Current value of the counter."""
        return self.count.value


class SynchronizedQueue(Queue):
    """A portable implementation of multiprocessing.Queue.

    Because of multithreading / multiprocessing semantics, Queue.qsize() may
    raise NotImplementedError on Unix platforms like Mac OS X where
    sem_getvalue() is not implemented. This subclass addresses the problem by
    keeping a synchronized shared counter (initialized to zero) and
    increasing / decreasing it every time put() and get() are called,
    respectively. This not only prevents NotImplementedError from being
    raised, but also allows a reliable version of both qsize() and empty().

    __getstate__ and __setstate__ make the instance picklable so it can be
    passed between processes; without them, unpickled copies lose the extra
    attribute and fail with "AttributeError: 'SynchronizedQueue' object has
    no attribute 'size'".  See https://stackoverflow.com/a/65513291/9723036
    and https://docs.python.org/3/library/pickle.html#pickling-class-instances
    """

    def __init__(self, maxsize=0):
        """Create the queue.

        Args:
            maxsize: maximum number of items the queue holds; 0 (the
                default) means unbounded, matching multiprocessing.Queue.
                The original implementation hard-coded an unbounded queue;
                exposing maxsize is backward-compatible.
        """
        super().__init__(maxsize, ctx=multiprocessing.get_context())
        self.size = SharedCounter(0)

    def __getstate__(self):
        """Serialize both the parent queue's state and the shared counter.

        self.size is a SharedCounter instance and is itself picklable.
        """
        return {
            'parent_state': super().__getstate__(),
            'size': self.size,
        }

    def __setstate__(self, state):
        """Restore the parent queue's state, then reattach the counter."""
        super().__setstate__(state['parent_state'])
        self.size = state['size']

    def put(self, *args, **kwargs):
        """Put an item, then bump the counter.

        Incrementing only after super().put() returns keeps the counter
        consistent if put() raises (e.g. queue.Full on a bounded queue).
        """
        super().put(*args, **kwargs)
        self.size.increment(1)

    def get(self, *args, **kwargs):
        """Get an item, then decrement the counter.

        Decrementing only after a successful get() keeps the counter
        consistent if get() raises (e.g. queue.Empty from get_nowait()).
        """
        item = super().get(*args, **kwargs)
        self.size.increment(-1)
        return item

    def qsize(self):
        """Reliable implementation of multiprocessing.Queue.qsize()."""
        return self.size.value

    def empty(self):
        """Reliable implementation of multiprocessing.Queue.empty()."""
        return not self.qsize()
88 changes: 67 additions & 21 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

from loguru import logger
from pynput import keyboard, mouse
from tqdm import tqdm
import fire
import mss.tools

from openadapt import config, crud, utils, window
from openadapt.customized_imports import synchronized_queue as sq

import functools

Expand Down Expand Up @@ -59,7 +61,9 @@ def wrapper_logging(*args, **kwargs):
func_kwargs = kwargs_to_str(**kwargs)

if func_kwargs != "":
logger.info(f" -> Enter: {func_name}({func_args}, {func_kwargs})")
logger.info(
f" -> Enter: {func_name}({func_args}, {func_kwargs})"
)
else:
logger.info(f" -> Enter: {func_name}({func_args})")

Expand All @@ -83,10 +87,10 @@ def process_event(event, write_q, write_fn, recording_timestamp, perf_q):
@trace(logger)
def process_events(
event_q: queue.Queue,
screen_write_q: multiprocessing.Queue,
action_write_q: multiprocessing.Queue,
window_write_q: multiprocessing.Queue,
perf_q: multiprocessing.Queue,
screen_write_q: sq.SynchronizedQueue,
action_write_q: sq.SynchronizedQueue,
window_write_q: sq.SynchronizedQueue,
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
):
Expand Down Expand Up @@ -165,7 +169,7 @@ def process_events(
def write_action_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write an action event to the database and update the performance queue.
Expand All @@ -184,7 +188,7 @@ def write_action_event(
def write_screen_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write a screen event to the database and update the performance queue.
Expand All @@ -206,7 +210,7 @@ def write_screen_event(
def write_window_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write a window event to the database and update the performance queue.
Expand All @@ -226,10 +230,11 @@ def write_window_event(
def write_events(
event_type: str,
write_fn: Callable,
write_q: multiprocessing.Queue,
perf_q: multiprocessing.Queue,
write_q: sq.SynchronizedQueue,
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
term_pipe: multiprocessing.Pipe,
):
"""
Write events of a specific type to the db using the provided write function.
Expand All @@ -241,20 +246,48 @@ def write_events(
perf_q: A queue for collecting performance data.
recording_timestamp: The timestamp of the recording.
terminate_event: An event to signal the termination of the process.
term_pipe: A pipe for communicating \
the number of events left to be written.
"""

utils.configure_logging(logger, LOG_LEVEL)
utils.set_start_time(recording_timestamp)
logger.info(f"{event_type=} starting")
signal.signal(signal.SIGINT, signal.SIG_IGN)
while not terminate_event.is_set() or not write_q.empty():

num_left = 0
progress = None
while (
not terminate_event.is_set() or
not write_q.empty()
):
if term_pipe.poll():
num_left = term_pipe.recv()
if num_left != 0 and progress is None:
progress = tqdm(
total=num_left,
desc="Writing to Database",
unit="event",
colour="green",
dynamic_ncols=True,
)
if (
terminate_event.is_set() and
num_left != 0 and
progress is not None
):
progress.update()
try:
event = write_q.get_nowait()
except queue.Empty:
continue
assert event.type == event_type, (event_type, event)
write_fn(recording_timestamp, event, perf_q)
logger.debug(f"{event_type=} written")

if progress is not None:
progress.close()

logger.info(f"{event_type=} done")


Expand Down Expand Up @@ -347,15 +380,18 @@ def handle_key(
"vk",
]
attrs = {
f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names
f"key_{attr_name}": getattr(key, attr_name, None)
for attr_name in attr_names
}
logger.debug(f"{attrs=}")
canonical_attrs = {
f"canonical_key_{attr_name}": getattr(canonical_key, attr_name, None)
for attr_name in attr_names
}
logger.debug(f"{canonical_attrs=}")
trigger_action_event(event_q, {"name": event_name, **attrs, **canonical_attrs})
trigger_action_event(
event_q, {"name": event_name, **attrs, **canonical_attrs}
)


def read_screen_events(
Expand Down Expand Up @@ -433,8 +469,8 @@ def read_window_events(


@trace(logger)
def performance_stats_writer (
perf_q: multiprocessing.Queue,
def performance_stats_writer(
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
):
Expand Down Expand Up @@ -562,13 +598,17 @@ def record(
recording_timestamp = recording.timestamp

event_q = queue.Queue()
screen_write_q = multiprocessing.Queue()
action_write_q = multiprocessing.Queue()
window_write_q = multiprocessing.Queue()
screen_write_q = sq.SynchronizedQueue()
action_write_q = sq.SynchronizedQueue()
window_write_q = sq.SynchronizedQueue()
# TODO: save write times to DB; display performance plot in visualize.py
perf_q = multiprocessing.Queue()
perf_q = sq.SynchronizedQueue()
terminate_event = multiprocessing.Event()


term_pipe_parent_window, term_pipe_child_window = multiprocessing.Pipe()
term_pipe_parent_screen, term_pipe_child_screen = multiprocessing.Pipe()
term_pipe_parent_action, term_pipe_child_action = multiprocessing.Pipe()

window_event_reader = threading.Thread(
target=read_window_events,
args=(event_q, terminate_event, recording_timestamp),
Expand Down Expand Up @@ -616,6 +656,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_screen,
),
)
screen_event_writer.start()
Expand All @@ -629,6 +670,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_action,
),
)
action_event_writer.start()
Expand All @@ -642,6 +684,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_window,
),
)
window_event_writer.start()
Expand All @@ -665,6 +708,10 @@ def record(
except KeyboardInterrupt:
terminate_event.set()

term_pipe_parent_window.send(window_write_q.qsize())
KrishPatel13 marked this conversation as resolved.
Show resolved Hide resolved
term_pipe_parent_action.send(action_write_q.qsize())
term_pipe_parent_screen.send(screen_write_q.qsize())

logger.info(f"joining...")
keyboard_event_reader.join()
mouse_event_reader.join()
Expand All @@ -674,7 +721,6 @@ def record(
screen_event_writer.join()
action_event_writer.join()
window_event_writer.join()

terminate_perf_event.set()

if PLOT_PERFORMANCE:
Expand Down
1 change: 1 addition & 0 deletions openadapt/scripts/scrub.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def scrub_mp4(
unit="frame",
bar_format=progress_bar_format,
colour="green",
dynamic_ncols=True,
)
progress_interval = 0.1 # Print progress every 10% of frames
progress_threshold = math.floor(frame_count * progress_interval)
Expand Down
Loading