diff --git a/README.md b/README.md index 305dec99d..da3fec700 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,38 @@ possible memory leak pointing the cursor and left or right clicking, as described in this [open issue](https://github.com/OpenAdaptAI/OpenAdapt/issues/145) + +### Capturing Browser Events + +To capture (record) browser events in Chrome, follow these steps: + +1. Go to: [Chrome Extension Page](chrome://extensions/) + +2. Enable `Developer mode` (located at the top right): + +![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/c97eb9fb-05d6-465d-85b3-332694556272) + +3. Click `Load unpacked` (located at the top left). + +![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/00c8adf5-074a-4655-b132-fd87644007fc) + +4. Select the `chrome_extension` directory: + +![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/71610ed3-f8d4-431a-9a22-d901127b7b0c) + +5. You should see the following confirmation, indicating that the extension is loaded: + +![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/7ee19da9-37e0-448f-b9ab-08ef99110e85) + +6. Set the flag to `true` if it is currently `false`: + +![image](https://github.com/user-attachments/assets/8eba24a3-7c68-4deb-8fbe-9d03cece1482) + +7. Start recording. Once recording begins, navigate to the Chrome browser, browse some pages, and perform a few clicks. Then, stop the recording and let it complete successfully. + +8. After recording, check the `openadapt.db` table `browser_event`. It should contain all your browser activity logs. You can verify the data's correctness using the `sqlite3` CLI or an extension like `SQLite Viewer` in VS Code to open `data/openadapt.db`. + + ### Visualize Quickly visualize the latest recording you created by running the following command: diff --git a/chrome_extension/background.js b/chrome_extension/background.js new file mode 100644 index 000000000..a747b8669 --- /dev/null +++ b/chrome_extension/background.js @@ -0,0 +1,68 @@ +/** + * @file background.js + * @description Creates a new background script that listens for messages from the content script + * and sends them to a WebSocket server. +*/ + +let socket; +let timeOffset = 0; // Global variable to store the time offset + +/* + * TODO: + * Ideally we read `WS_SERVER_PORT`, `WS_SERVER_ADDRESS` and + * `RECONNECT_TIMEOUT_INTERVAL` from config.py, + * or it gets passed in somehow. +*/ +let RECONNECT_TIMEOUT_INTERVAL = 1000; // ms +let WS_SERVER_PORT = 8765; +let WS_SERVER_ADDRESS = "localhost"; +let WS_SERVER_URL = "ws://" + WS_SERVER_ADDRESS + ":" + WS_SERVER_PORT; + + +function socketSend(socket, message) { + console.log({ message }); + socket.send(JSON.stringify(message)); +} + + +/* + * Function to connect to the WebSocket server. +*/ +function connectWebSocket() { + socket = new WebSocket(WS_SERVER_URL); + + socket.onopen = function() { + console.log("WebSocket connection established"); + }; + + socket.onmessage = function(event) { + console.log("Message from server:", event.data); + const message = JSON.parse(event.data); + }; + + socket.onclose = function(event) { + console.log("WebSocket connection closed", event); + // Reconnect after 5 seconds if the connection is lost + setTimeout(connectWebSocket, RECONNECT_TIMEOUT_INTERVAL); + }; + + socket.onerror = function(error) { + console.error("WebSocket error:", error); + socket.close(); + }; +} + +// Create a connection to the WebSocket server +connectWebSocket(); + +/* Listen for messages from the content script */ +chrome.runtime.onMessage.addListener((message, sender, sendResponse) => { + const tabId = sender.tab.id; + message.tabId = tabId; + if (socket && socket.readyState === WebSocket.OPEN) { + socketSend(socket, message); + sendResponse({ status: "Message sent to WebSocket" }); + } else { + sendResponse({ status: "WebSocket connection not open" }); + } +}); diff --git a/chrome_extension/content.js b/chrome_extension/content.js new file mode 100644 index 000000000..a08daabb8 --- /dev/null +++ b/chrome_extension/content.js @@ -0,0 +1,344 @@ +const DEBUG = true; +const RETURN_FULL_DOCUMENT = false; +const MAX_COORDS = 3; +const SET_SCREEN_COORDS = false; +const elementIdMap = new WeakMap(); +const idToElementMap = new Map(); // Reverse lookup map +let elementIdCounter = 0; +let messageIdCounter = 0; +const pageId = `${Date.now()}-${Math.random()}`; +const coordMappings = { + x: { client: [], screen: [] }, + y: { client: [], screen: [] } +}; + +function trackMouseEvent(event) { + const { clientX, clientY, screenX, screenY } = event; + + const prevCoordMappingsStr = JSON.stringify(coordMappings); + + // Track x-coordinates + updateCoordinateMappings('x', clientX, screenX); + // Track y-coordinates + updateCoordinateMappings('y', clientY, screenY); + + // Ensure only the latest distinct coordinate mappings per dimension are kept + trimMappings(coordMappings.x); + trimMappings(coordMappings.y); + + const coordMappingsStr = JSON.stringify(coordMappings); + if (DEBUG && coordMappingsStr != prevCoordMappingsStr) { + console.log(JSON.stringify(coordMappings)); + } +} + +function updateCoordinateMappings(dim, clientCoord, screenCoord) { + const coordMap = coordMappings[dim]; + + // Check if current event's client coordinate matches any of the existing ones + if (coordMap.client.includes(clientCoord)) { + // Update screen coordinate for the matching client coordinate + coordMap.screen[coordMap.client.indexOf(clientCoord)] = screenCoord; + } else { + // Add new coordinate mapping + coordMap.client.push(clientCoord); + coordMap.screen.push(screenCoord); + } +} + +function trimMappings(coordMap) { + // Keep only the latest distinct coordinate mappings + if (coordMap.client.length > MAX_COORDS) { + coordMap.client.shift(); + coordMap.screen.shift(); + } +} + +function getConversionPoints() { + const { x, y } = coordMappings; + + // Ensure we have at least two points for each dimension + if (x.client.length < 2 || y.client.length < 2) { + return { sxScale: null, syScale: null, sxOffset: null, syOffset: null }; + } + + // Use linear regression or least squares fitting to determine scale factors and offsets + const { scale: sxScale, offset: sxOffset } = fitLinearTransformation(x.client, x.screen); + const { scale: syScale, offset: syOffset } = fitLinearTransformation(y.client, y.screen); + + return { + sxScale, syScale, sxOffset, syOffset + }; +} + +function fitLinearTransformation(clientCoords, screenCoords) { + const n = clientCoords.length; + let sumClient = 0, sumScreen = 0, sumClientSquared = 0, sumClientScreen = 0; + + for (let i = 0; i < n; i++) { + sumClient += clientCoords[i]; + sumScreen += screenCoords[i]; + sumClientSquared += clientCoords[i] * clientCoords[i]; + sumClientScreen += clientCoords[i] * screenCoords[i]; + } + + const scale = (n * sumClientScreen - sumClient * sumScreen) / (n * sumClientSquared - sumClient * sumClient); + const offset = (sumScreen - scale * sumClient) / n; + + return { scale, offset }; +} + +function getScreenCoordinates(element) { + const rect = element.getBoundingClientRect(); + const { top: clientTop, left: clientLeft, bottom: clientBottom, right: clientRight } = rect; + + const conversionPoints = getConversionPoints(); + + // If conversion points are not sufficient, return null coordinates + if (conversionPoints.sxScale === null) { + return { top: null, left: null, bottom: null, right: null }; + } + + const { sxScale, syScale, sxOffset, syOffset } = conversionPoints; + + // Convert element's client bounding box to screen coordinates + const screenTop = syScale * clientTop + syOffset; + const screenLeft = sxScale * clientLeft + sxOffset; + const screenBottom = syScale * clientBottom + syOffset; + const screenRight = sxScale * clientRight + sxOffset; + + return { top: screenTop, left: screenLeft, bottom: screenBottom, right: screenRight }; +} + +function sendMessageToBackgroundScript(message) { + message.id = messageIdCounter++; + message.pageId = pageId; + message.url = window.location.href; + if (DEBUG) { + const messageType = message.type; + const messageLength = JSON.stringify(message).length; + console.log({ messageType, messageLength, message }); + } + chrome.runtime.sendMessage(message); +} + +function generateElementIdAndBbox(element) { + // ignore invisible elements + if (!isVisible(element)) { + return; + } + + // set id + if (!elementIdMap.has(element)) { + const newId = `elem-${elementIdCounter++}`; + elementIdMap.set(element, newId); + idToElementMap.set(newId, element); // Reverse mapping + element.setAttribute('data-id', newId); + } + + // set client bbox + let { top, left, bottom, right } = element.getBoundingClientRect(); + let bboxClient = `${top},${left},${bottom},${right}`; + element.setAttribute('data-tlbr-client', bboxClient); + + // set screen bbox + if (SET_SCREEN_COORDS) { + ({ top, left, bottom, right } = getScreenCoordinates(element)); + if (top == null) { + // not enough data points to get screen coordinates + return + } + let bboxScreen = `${top},${left},${bottom},${right}`; + element.setAttribute('data-tlbr-screen', bboxScreen); + } + + return elementIdMap.get(element); +} + +function instrumentLiveDomWithBbox() { + document.querySelectorAll('*').forEach(element => generateElementIdAndBbox(element)); +} + +function isVisible(element) { + const rect = element.getBoundingClientRect(); + const style = window.getComputedStyle(element); + + return ( + rect.width > 0 && + rect.height > 0 && + rect.bottom >= 0 && + rect.right >= 0 && + rect.top <= (window.innerHeight || document.documentElement.clientHeight) && + rect.left <= (window.innerWidth || document.documentElement.clientWidth) && + style.visibility !== 'hidden' && + style.display !== 'none' && + style.opacity !== '0' + ); +} + +function cleanDomTree(node) { + const children = Array.from(node.childNodes); // Use childNodes to include all types of child nodes + for (const child of children) { + if (child.nodeType === Node.ELEMENT_NODE) { + // Check for img elements with src="data..." + if (child.tagName === 'IMG' && child.hasAttribute('src')) { + const src = child.getAttribute('src'); + if (src.startsWith('data:')) { + //const [metadata] = src.split(','); // Extract the metadata part (e.g., "data:image/jpeg;base64") + //child.setAttribute('src', `${metadata}`); // Replace the data content with "" + // The above triggers net::ERR_INVALID_URL, so just remove it for now + child.setAttribute('src', ''); + } + } + + const originalId = child.getAttribute('data-id'); + if (originalId) { + const originalElement = idToElementMap.get(originalId); + if (!originalElement || !isVisible(originalElement)) { + node.removeChild(child); + } else { + cleanDomTree(child); // Recursive call for child nodes + } + } + } else if (child.nodeType === Node.COMMENT_NODE) { + node.removeChild(child); // Remove comments + } else if (child.nodeType === Node.TEXT_NODE) { + // Strip newlines and whitespace-only text nodes + const trimmedText = child.textContent.replace(/\s+/g, ' ').trim(); + if (trimmedText.length === 0) { + node.removeChild(child); + } else { + child.textContent = trimmedText; // Replace the text content with stripped version + } + } + } +} + +function getVisibleHtmlString() { + const startTime = performance.now(); + + // Step 1: Instrument the live DOM with data-id and data-bbox attributes + instrumentLiveDomWithBbox(); + + if (RETURN_FULL_DOCUMENT) { + const visibleHtmlDuration = performance.now() - startTime; + console.log({ visibleHtmlDuration }); + const visibleHtmlString = document.body.outerHTML; + return { visibleHtmlString, visibleHtmlDuration }; + } + + // Step 2: Clone the body + const clonedBody = document.body.cloneNode(true); + + // Step 3: Remove invisible elements from the cloned DOM + cleanDomTree(clonedBody); + + // Step 4: Serialize the modified clone to a string + const visibleHtmlString = clonedBody.outerHTML; + + const visibleHtmlDuration = performance.now() - startTime; + console.log({ visibleHtmlDuration }); + + return { visibleHtmlString, visibleHtmlDuration }; +} + +/** + * Validates MouseEvent coordinates against bounding boxes for both client and screen. + * @param {MouseEvent} event - The mouse event containing coordinates. + * @param {HTMLElement} eventTarget - The target element of the mouse event. + * @param {string} attrType - The type of attribute to validate ('client' or 'screen'). + * @param {string} coordX - The X coordinate to validate (clientX or screenX). + * @param {string} coordY - The Y coordinate to validate (clientY or screenY). + */ +function validateCoordinates(event, eventTarget, attrType, coordX, coordY) { + const attr = `data-tlbr-${attrType}` + const bboxAttr = eventTarget.getAttribute(attr); + if (!bboxAttr) { + console.warn(`${attr} is empty`); + return; + } + const [top, left, bottom, right] = bboxAttr.split(',').map(parseFloat); + const x = event[coordX]; + const y = event[coordY]; + + if (x < left || x > right || y < top || y > bottom) { + console.warn(`${attrType} coordinates outside:`, JSON.stringify({ + [coordX]: x, + [coordY]: y, + bbox: { top, left, bottom, right }, + })); + console.log(JSON.stringify({ devicePixelRatio, innerHeight, innerWidth, outerHeight, outerWidth, scrollY, scrollX, pageXOffset, pageYOffset, screenTop, screenLeft })); + } else { + console.log(`${attrType} coordinates inside:`, JSON.stringify({ + [coordX]: x, + [coordY]: y, + bbox: { top, left, bottom, right }, + })); + } +} + +function handleUserGeneratedEvent(event) { + const eventTarget = event.target; + const eventTargetId = generateElementIdAndBbox(eventTarget); + const timestamp = Date.now() / 1000; // Convert to Python-compatible seconds + + const { visibleHtmlString, visibleHtmlDuration } = getVisibleHtmlString(); + + const eventData = { + type: 'USER_EVENT', + eventType: event.type, + targetId: eventTargetId, + timestamp: timestamp, + visibleHtmlString, + visibleHtmlDuration, + }; + + if (event instanceof KeyboardEvent) { + eventData.key = event.key; + eventData.code = event.code; + } else if (event instanceof MouseEvent) { + eventData.clientX = event.clientX; + eventData.clientY = event.clientY; + eventData.screenX = event.screenX; + eventData.screenY = event.screenY; + eventData.button = event.button; + eventData.coordMappings = coordMappings; + validateCoordinates(event, eventTarget, 'client', 'clientX', 'clientY'); + if (SET_SCREEN_COORDS) { + validateCoordinates(event, eventTarget, 'screen', 'screenX', 'screenY'); + } + } + sendMessageToBackgroundScript(eventData); +} + +// Attach event listeners for user-generated events +function attachUserEventListeners() { + const eventsToCapture = [ + 'click', + // input events are triggered after the DOM change is written, so we can't use them + // (since the resulting HTML would not look as the DOM was at the time the + // user took the action, i.e. immediately before) + //'input', + 'keydown', + 'keyup', + ]; + + eventsToCapture.forEach(eventType => { + document.body.addEventListener(eventType, handleUserGeneratedEvent, true); + }); +} + +function attachInstrumentationEventListeners() { + const eventsToCapture = [ + 'mousedown', + 'mouseup', + 'mousemove', + ]; + eventsToCapture.forEach(eventType => { + document.body.addEventListener(eventType, trackMouseEvent, true); + }); +} + +// Initial setup +attachUserEventListeners(); +attachInstrumentationEventListeners(); diff --git a/chrome_extension/icons/logo.png b/chrome_extension/icons/logo.png new file mode 100644 index 000000000..64aa029bd Binary files /dev/null and b/chrome_extension/icons/logo.png differ diff --git a/chrome_extension/manifest.json b/chrome_extension/manifest.json new file mode 100644 index 000000000..e75b26374 --- /dev/null +++ b/chrome_extension/manifest.json @@ -0,0 +1,23 @@ +{ + "name": "openadapt", + "description": "Uses sockets to expose DOM events to OpenAdapt", + "version": "1.0", + "manifest_version": 3, + "icons": { + "48": "icons/logo.png" + }, + "action": { + "default_icon": "icons/logo.png" + }, + "background": { + "service_worker": "background.js" + }, + "permissions": ["activeTab", "tabs", "scripting"], + "host_permissions": [""], + "content_scripts": [ + { + "matches": [""], + "js": ["content.js"] + } + ] +} diff --git a/openadapt/alembic/env.py b/openadapt/alembic/env.py index 3b3355859..7a38305dd 100644 --- a/openadapt/alembic/env.py +++ b/openadapt/alembic/env.py @@ -10,6 +10,7 @@ from alembic import context from openadapt.config import config from openadapt.db import db +from openadapt.models import ForceFloat # This is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -37,6 +38,11 @@ def get_url() -> str: return config.DB_URL +def process_revision_directives(context, revision, directives): + script = directives[0] + script.imports.add("import openadapt") + + def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -55,6 +61,7 @@ def run_migrations_offline() -> None: literal_binds=True, dialect_opts={"paramstyle": "named"}, render_as_batch=True, + process_revision_directives=process_revision_directives, ) with context.begin_transaction(): @@ -80,6 +87,7 @@ def run_migrations_online() -> None: connection=connection, target_metadata=target_metadata, render_as_batch=True, + process_revision_directives=process_revision_directives, ) with context.begin_transaction(): diff --git a/openadapt/alembic/versions/98505a067995_add_browserevent_table.py b/openadapt/alembic/versions/98505a067995_add_browserevent_table.py new file mode 100644 index 000000000..6092f394c --- /dev/null +++ b/openadapt/alembic/versions/98505a067995_add_browserevent_table.py @@ -0,0 +1,46 @@ +"""add BrowserEvent table + +Revision ID: 98505a067995 +Revises: bb25e889ad71 +Create Date: 2024-08-28 16:51:10.592340 + +""" +from alembic import op +import sqlalchemy as sa +import openadapt + +# revision identifiers, used by Alembic. +revision = '98505a067995' +down_revision = 'bb25e889ad71' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('browser_event', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('recording_timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True), + sa.Column('recording_id', sa.Integer(), nullable=True), + sa.Column('message', sa.JSON(), nullable=True), + sa.Column('timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True), + sa.ForeignKeyConstraint(['recording_id'], ['recording.id'], name=op.f('fk_browser_event_recording_id_recording')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_browser_event')) + ) + with op.batch_alter_table('action_event', schema=None) as batch_op: + batch_op.add_column(sa.Column('browser_event_timestamp', openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True)) + batch_op.add_column(sa.Column('browser_event_id', sa.Integer(), nullable=True)) + batch_op.create_foreign_key(batch_op.f('fk_action_event_browser_event_id_browser_event'), 'browser_event', ['browser_event_id'], ['id']) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('action_event', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('fk_action_event_browser_event_id_browser_event'), type_='foreignkey') + batch_op.drop_column('browser_event_id') + batch_op.drop_column('browser_event_timestamp') + + op.drop_table('browser_event') + # ### end Alembic commands ### diff --git a/openadapt/app/dashboard/app/settings/record_and_replay/form.tsx b/openadapt/app/dashboard/app/settings/record_and_replay/form.tsx index c72211a71..036ae26bf 100644 --- a/openadapt/app/dashboard/app/settings/record_and_replay/form.tsx +++ b/openadapt/app/dashboard/app/settings/record_and_replay/form.tsx @@ -43,6 +43,9 @@ export const Form = ({ + + + diff --git a/openadapt/app/dashboard/components/ActionEvent/ActionEvent.tsx b/openadapt/app/dashboard/components/ActionEvent/ActionEvent.tsx index 066f31cdb..69dc506c3 100644 --- a/openadapt/app/dashboard/components/ActionEvent/ActionEvent.tsx +++ b/openadapt/app/dashboard/components/ActionEvent/ActionEvent.tsx @@ -65,6 +65,12 @@ export const ActionEvent = ({ {timeStampToDateString(event.screenshot_timestamp)} )} + {event.browser_event_timestamp && ( + + browser event timestamp + {timeStampToDateString(event.browser_event_timestamp)} + + )} window event timestamp {timeStampToDateString(event.window_event_timestamp)} diff --git a/openadapt/app/dashboard/types/action-event.ts b/openadapt/app/dashboard/types/action-event.ts index 211aa3f78..f99757bbb 100644 --- a/openadapt/app/dashboard/types/action-event.ts +++ b/openadapt/app/dashboard/types/action-event.ts @@ -5,6 +5,7 @@ export type ActionEvent = { recording_timestamp: number; screenshot_timestamp?: number; window_event_timestamp: number; + browser_event_timestamp: number; mouse_x: number | null; mouse_y: number | null; mouse_dx: number | null; diff --git a/openadapt/browser.py b/openadapt/browser.py new file mode 100644 index 000000000..aa8f1cd6b --- /dev/null +++ b/openadapt/browser.py @@ -0,0 +1,852 @@ +"""Utilities for working with BrowserEvents.""" + +from statistics import mean, median, stdev + +from bs4 import BeautifulSoup +from copy import deepcopy +from dtaidistance import dtw, dtw_ndim +from loguru import logger +from sqlalchemy.orm import Session as SaSession +from tqdm import tqdm +import numpy as np + +from openadapt import models, utils +from openadapt.db import crud + +# action to browser +MOUSE_BUTTON_MAPPING = {"left": 0, "right": 2, "middle": 1} + +# action to browser +EVENT_TYPE_MAPPING = {"click": "click", "press": "keydown", "release": "keyup"} + +SPATIAL = True + +# TODO: read from pynput +KEYBOARD_KEYS = [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "tab", + "enter", + "shift", + "ctrl", + "alt", + "esc", + "space", + "backspace", + "delete", + "home", + "end", + "pageup", + "pagedown", + "arrowup", + "arrowdown", + "arrowleft", + "arrowright", +] + + +def add_screen_tlbr(browser_events: list[models.BrowserEvent]) -> None: + """Computes and adds the 'data-tlbr-screen' attribute for each element. + + Uses coordMappings provided by JavaScript events. If 'data-tlbr-screen' already + exists, compute the values again and assert equality. Reuse the most recent valid + mappings if none exist for the current event by iterating over the events in + reverse order. + + Args: + browser_events (list[models.BrowserEvent]): list of browser events to process. + """ + # Initialize variables to store the most recent valid mappings + latest_valid_x_mappings = None + latest_valid_y_mappings = None + + # Iterate over the events in reverse order + for event in reversed(browser_events): + message = event.message + + event_type = message.get("eventType") + if event_type != "click": + continue + + visible_html_string = message.get("visibleHtmlString") + if not visible_html_string: + logger.warning("No visible HTML data available for event.") + continue + + # Parse the visible HTML using BeautifulSoup + soup = BeautifulSoup(visible_html_string, "html.parser") + + # Fetch the target element using its data-id + target_id = message.get("targetId") + target_element = soup.find(attrs={"data-id": target_id}) + + if not target_element: + logger.warning(f"No target element found for targetId: {target_id}") + continue + + # Extract coordMappings from the message + coord_mappings = message.get("coordMappings", {}) + x_mappings = coord_mappings.get("x", {}) + y_mappings = coord_mappings.get("y", {}) + + # Check if there are sufficient data points; if not, reuse latest valid mappings + if ( + "client" in x_mappings + and len(x_mappings["client"]) >= 2 + and len(y_mappings["client"]) >= 2 + ): + # Update the latest valid mappings + latest_valid_x_mappings = x_mappings + latest_valid_y_mappings = y_mappings + else: + # Reuse the most recent valid mappings + if latest_valid_x_mappings is None or latest_valid_y_mappings is None: + logger.warning( + f"No valid coordinate mappings available for element: {target_id}" + ) + continue # No valid mappings available, skip this event + + x_mappings = latest_valid_x_mappings + y_mappings = latest_valid_y_mappings + + # Compute the scale and offset for both x and y axes + sx_scale, sx_offset = fit_linear_transformation( + x_mappings["client"], x_mappings["screen"] + ) + sy_scale, sy_offset = fit_linear_transformation( + y_mappings["client"], y_mappings["screen"] + ) + + # Only process the screen coordinates + tlbr_attr = "data-tlbr-screen" + try: + # Get existing screen coordinates if present + existing_screen_coords = ( + target_element[tlbr_attr] if tlbr_attr in target_element.attrs else None + ) + except KeyError: + existing_screen_coords = None + + # Compute screen coordinates + client_coords = target_element.get("data-tlbr-client") + if not client_coords: + logger.warning( + f"Missing client coordinates for element with id: {target_id}" + ) + continue + + # Extract client coordinates + client_top, client_left, client_bottom, client_right = map( + float, client_coords.split(",") + ) + + # Calculate screen coordinates using the computed scale and offset + screen_top = sy_scale * client_top + sy_offset + screen_left = sx_scale * client_left + sx_offset + screen_bottom = sy_scale * client_bottom + sy_offset + screen_right = sx_scale * client_right + sx_offset + + # New computed screen coordinates + new_screen_coords = f"{screen_top},{screen_left},{screen_bottom},{screen_right}" + logger.info(f"{client_coords=} {existing_screen_coords=} {new_screen_coords=}") + + # Check for existing data-tlbr-screen attribute + if existing_screen_coords: + assert existing_screen_coords == new_screen_coords, ( + "Mismatch in computed and existing screen coordinates:" + f" {existing_screen_coords} != {new_screen_coords}" + ) + + # Update the attribute with the new value + target_element["data-tlbr-screen"] = new_screen_coords + + # Write the updated element back to the message + message["visibleHtmlString"] = str(soup) + + logger.info("Finished processing all browser events for screen coordinates.") + + +def fit_linear_transformation( + client_coords: list[float], screen_coords: list[float] +) -> tuple[float, float]: + """Fit a linear transformation (scale and offset) from client to screen coordinates. + + Args: + client_coords (list[float]): The client coordinates (x or y). + screen_coords (list[float]): The corresponding screen coordinates (x or y). + + Returns: + tuple[float, float]: The scale and offset values for the linear transformation. + """ + n = len(client_coords) + sum_client = sum(client_coords) + sum_screen = sum(screen_coords) + sum_client_squared = sum(c**2 for c in client_coords) + sum_client_screen = sum(c * s for c, s in zip(client_coords, screen_coords)) + + # Calculate scale and offset using least squares fitting + scale = (n * sum_client_screen - sum_client * sum_screen) / ( + n * sum_client_squared - sum_client**2 + ) + offset = (sum_screen - scale * sum_client) / n + + return scale, offset + + +def identify_and_log_smallest_clicked_element( + browser_event: models.BrowserEvent, +) -> None: + """Logs the smallest DOM element that was clicked on for a given click event. + + Args: + browser_event: The browser event containing the click details. + """ + visible_html_string = browser_event.message.get("visibleHtmlString") + message_id = browser_event.message.get("id") + logger.info("*" * 10) + logger.info(f"{message_id=}") + target_id = browser_event.message.get("targetId") + logger.info(f"{target_id=}") + + if not visible_html_string: + logger.warning("No visible HTML data available for click event.") + return + + # Parse the visible HTML using BeautifulSoup + soup = BeautifulSoup(visible_html_string, "html.parser") + target_element = soup.find(attrs={"data-id": target_id}) + target_area = None + if not target_element: + logger.warning(f"{target_element=}") + else: + for coord_type in ("client", "screen"): + x = browser_event.message.get(f"{coord_type}X") + y = browser_event.message.get(f"{coord_type}Y") + tlbr = f"data-tlbr-{coord_type}" + try: + target_element_tlbr = target_element[tlbr] + except KeyError: + logger.warning(f"{tlbr=} not in {target_element=}") + continue + top, left, bottom, right = map(float, target_element_tlbr.split(",")) + logger.info(f"{tlbr=} {x=} {y=} {top=} {left=} {bottom=} {right=}") + if not (left <= x <= right and top <= y <= bottom): + logger.warning("outside") + + # Calculate the area for target_element + if "data-tlbr-client" in target_element.attrs: + target_top, target_left, target_bottom, target_right = map( + float, target_element["data-tlbr-client"].split(",") + ) + target_area = (target_right - target_left) * (target_bottom - target_top) + + elements = soup.find_all(attrs={"data-tlbr-client": True}) + + smallest_element = None + smallest_area = float("inf") + + for elem in elements: + data_tlbr = elem["data-tlbr-client"] + top, left, bottom, right = map(float, data_tlbr.split(",")) + client_x = browser_event.message.get("clientX") + client_y = browser_event.message.get("clientY") + + if left <= client_x <= right and top <= client_y <= bottom: + area = (right - left) * (bottom - top) + if area < smallest_area: + smallest_area = area + smallest_element = elem + + if smallest_element is not None: + smallest_element_str = utils.truncate_html(str(smallest_element), 100) + logger.info(f"Smallest clicked element found: {smallest_element_str}") + smallest_element_id = smallest_element["data-id"] + smallest_element_type = smallest_element.name + target_element_type = target_element.name if target_element else "Unknown" + smallest_element_area = smallest_area + + # Check if smallest_element is a descendant or ancestor of target_element + is_descendant = False + is_ancestor = False + + if target_element: + is_descendant = target_element in smallest_element.parents + is_ancestor = smallest_element in target_element.parents + + # Log a warning if the smallest element is not the target, + # or a descendant/ancestor of the target + if not (smallest_element_id == target_id or is_descendant or is_ancestor): + logger.warning( + f"{smallest_element_id=} {smallest_element_type=}" + f" {smallest_element_area=} does not match" + f" {target_id=} {target_element_type=} {target_area=}" + f" is_descendant={is_descendant} is_ancestor={is_ancestor}" + ) + else: + logger.warning("No element found matching the click coordinates.") + + +def is_action_event( + event: models.ActionEvent, + action_name: str, + key_or_button: str, +) -> bool: + """Determine if the event matches the given action name and key/button. + + Args: + event: The action event to check. + action_name: The type of action (e.g., "click", "press", "release"). + key_or_button: The key or button associated with the action. + + Returns: + bool: True if the event matches the action name and key/button, False otherwise. + """ + if action_name == "click": + return event.name == action_name and event.mouse_button_name == key_or_button + elif action_name in {"press", "release"}: + raw_action_text = event._text(name_prefix="", name_suffix="") + return event.name == action_name and raw_action_text == key_or_button + else: + return False + + +def is_browser_event( + event: models.ActionEvent, + action_name: str, + key_or_button: str, +) -> bool: + """Determine if the browser event matches the given action name and key/button. + + Args: + event: The browser event to check. + action_name (str): The type of action (e.g., "click", "press", "release"). + key_or_button (str): The key or button associated with the action. + + Returns: + bool: True if the event matches the action name and key/button, False otherwise. + """ + if action_name == "click": + return ( + event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] + and event.message.get("button") == MOUSE_BUTTON_MAPPING[key_or_button] + ) + elif action_name in {"press", "release"}: + return ( + event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] + and event.message.get("key").lower() == key_or_button + ) + else: + return False + + +def align_events( + event_type: str, + action_events: list, + browser_events: list, + spatial: bool = SPATIAL, + use_local_timestamps: bool = False, +) -> list[tuple[int, int]]: + """Align action events and browser events based on timestamps and spatial data. + + Uses Dynamic Time Warping (DTW). + + Args: + event_type (str): The type of event to align. + action_events (list): The list of action events. + browser_events (list): The list of browser events. + spatial (bool, optional): Whether to use spatial data (mouse coordinates). + Defaults to True. + use_local_timestamps (bool, optional): Whether to use local timestamps for + alignment. Defaults to False. + + Returns: + list[tuple[int, int]]: The list of tuples representing aligned event indices. + """ + # Only log if there are any action or browser events of each type + if not action_events and not browser_events: + return [] + + # Convert series of events to timestamps + action_timestamps = [e.timestamp for e in action_events] + if use_local_timestamps: + browser_timestamps = [e.timestamp for e in browser_events] + else: + browser_timestamps = [e.message["timestamp"] for e in browser_events] + + if spatial: + # Prepare sequences for multidimensional DTW + action_sequence = np.array( + [[e.timestamp, e.mouse_x or 0.0, e.mouse_y or 0.0] for e in action_events], + dtype=np.double, + ) + + browser_sequence = np.array( + [ + [ + e.timestamp, + e.message.get("screenX", 0.0), + e.message.get("screenY", 0.0), + ] + for e in browser_events + ], + dtype=np.double, + ) + + # Compute the alignment using multidimensional DTW + path = dtw_ndim.warping_path(action_sequence, browser_sequence) + else: + # Compute the alignment using DTW + path = dtw.warping_path(action_timestamps, browser_timestamps) + + # Enforce a one-to-one correspondence by selecting the closest matches + filtered_path = enforce_one_to_one_mapping( + path, action_timestamps, browser_timestamps + ) + + return filtered_path + + +def evaluate_alignment( + filtered_path: list[tuple[int, int]], + event_type: str, + action_events: list, + browser_events: list, + spatial: bool = SPATIAL, +) -> tuple[int, list[float], list[float], list[float], list[float]]: + """Evaluate the alignment between action events and browser events. + + Args: + filtered_path (list[tuple[int, int]]): The filtered DTW path representing + aligned events. + event_type (str): The type of event being aligned. + action_events (list): The list of action events. + browser_events (list): The list of browser events. + spatial (bool, optional): Whether to use spatial data (mouse coordinates). + Defaults to True. + + Returns: + tuple: A tuple containing: + - int: The total number of errors encountered. + - list[float]: Differences in remote timestamps for matched events. + - list[float]: Differences in local timestamps for matched events. + - list[float]: Differences in mouse X positions for matched events. + - list[float]: Differences in mouse Y positions for matched events. + """ + match_count = 0 + mismatch_count = 0 + error_count = 0 # Initialize error counter + remote_time_differences = ( + [] + ) # To store differences in local/remote timestamps for matching events + local_time_differences = [] # As above but for local/local + mouse_x_differences = ( + [] + ) # To store differences in mouse X positions for matching events + mouse_y_differences = ( + [] + ) # To store differences in mouse Y positions for matching events + + logger.info(f"Alignment for {event_type} Events") + for i, j in filtered_path: + action_event = action_events[i] + browser_event = deepcopy(browser_events[j]) + + action_event_type = action_event.name.lower() + browser_event_type = browser_event.message["eventType"].lower() + + if ( + action_event_type in EVENT_TYPE_MAPPING + and browser_event_type == EVENT_TYPE_MAPPING[action_event_type] + ): + match_count += 1 + remote_time_difference = ( + action_event.timestamp - browser_event.message["timestamp"] + ) + remote_time_differences.append(remote_time_difference) + local_time_difference = action_event.timestamp - browser_event.timestamp + local_time_differences.append(local_time_difference) + + # Compute differences between mouse positions + if action_event.mouse_x is not None: + mouse_x_difference = ( + action_event.mouse_x - browser_event.message["screenX"] + ) + mouse_y_difference = ( + action_event.mouse_y - browser_event.message["screenY"] + ) + if mouse_x_difference > 1: + logger.warning( + f"{mouse_x_difference=} {action_event.mouse_x=}" + f" {browser_event.message['screenX']=}" + ) + if mouse_y_difference > 1: + logger.warning( + f"{mouse_y_difference=} {action_event.mouse_y=}" + f" {browser_event.message['screenY']=}" + ) + mouse_x_differences.append(mouse_x_difference) + mouse_y_differences.append(mouse_y_difference) + else: + mismatch_count += 1 + logger.warning( + f"Event type mismatch: Action({action_event_type}) does not match" + f" Browser({browser_event_type})" + ) + + logger.info( + "\nAction Event:\n" + f" - Type: {action_event.name}\n" + f" - Timestamp: {action_event.timestamp}\n" + f" - Details: {action_event}\n" + "Browser Event:\n" + f" - Type: {browser_event.message['eventType']}\n" + f" - Timestamp: {browser_event.message['timestamp']}\n" + f" - Details: {browser_event}\n" + f"{'-'*80}" + ) + + logger.info(f"Total Matches: {match_count}") + logger.info(f"Total Mismatches: {mismatch_count}") + + # Log unmatched browser events + matched_browser_indices = {j for _, j in filtered_path} + unmatched_browser_events = [ + e for idx, e in enumerate(browser_events) if idx not in matched_browser_indices + ] + + if unmatched_browser_events: + logger.warning(f"Unmatched Browser Events: {len(unmatched_browser_events)}") + for browser_event in unmatched_browser_events: + logger.warning( + "Unmatched Browser Event:\n" + f" - Type: {browser_event.message['eventType']}\n" + f" - Timestamp: {browser_event.message['timestamp']}\n" + f" - Details: {browser_event}\n" + ) + error_count += 1 # Increment error count for each unmatched browser event + + try: + assert ( + len(browser_events) == match_count + ), "Every BrowserEvent should have a corresponding ActionEvent." + except Exception as exc: + error_count += 1 # Increment error count for assertion error + logger.warning(exc) + + return ( + error_count, + remote_time_differences, + local_time_differences, + mouse_x_differences, + mouse_y_differences, + ) + + +def enforce_one_to_one_mapping( + path: list[tuple[int, int]], + action_timestamps: list[float], + browser_timestamps: list[float], +) -> list[tuple[int, int]]: + """Enforce one-to-one mapping between Browser/Action by selecting the closest match. + + Args: + path: list of tuples representing the DTW path. + action_timestamps: list of timestamps for action events. + browser_timestamps: list of timestamps for browser events. + + Returns: + filtered_path: list of tuples representing the filtered DTW path with + one-to-one mapping. + """ + used_action_indices = set() + filtered_path = [] + + # Create a dictionary to store the closest match for each browser event + closest_matches = {} + + for i, j in path: + if j not in closest_matches: + closest_matches[j] = (i, abs(action_timestamps[i] - browser_timestamps[j])) + else: + # If a closer match is found, update closest match for this browser event + current_diff = abs(action_timestamps[i] - browser_timestamps[j]) + if current_diff < closest_matches[j][1]: + closest_matches[j] = (i, current_diff) + + # Collect the closest matches while ensuring each action event is used only once + for j, (i, _) in closest_matches.items(): + if i not in used_action_indices: + filtered_path.append((i, j)) + used_action_indices.add(i) + + return filtered_path + + +def assign_browser_events( + session: SaSession, + action_events: list[models.ActionEvent], + browser_events: list[models.BrowserEvent], +) -> dict: + """Assign browser events to action events by aligning timestamps/types. + + Args: + session (sa.orm.Session): The database session. + action_events (list[models.ActionEvent]): list of action events to assign. + browser_events (list[models.BrowserEvent]): list of browser events to assign. + + Returns: + dict: A dictionary containing statistics and information about the event + assignments. + """ + # Filter BrowserEvents for 'USER_EVENT' type + browser_events = [ + browser_event + for browser_event in browser_events + if browser_event.message["type"] == "USER_EVENT" + ] + + add_screen_tlbr(browser_events) + + # Define event pairs dynamically for mouse events + event_pairs = [ + ("Left Click", "click", "left"), + ("Right Click", "click", "right"), + ("Middle Click", "click", "middle"), + ] + + # Add keyboard events dynamically + event_pairs.extend( + [(f"Key Press {key.upper()}", "press", key) for key in KEYBOARD_KEYS] + + [(f"Key Release {key.upper()}", "release", key) for key in KEYBOARD_KEYS] + ) + + # Initialize statistics + total_errors = 0 + total_remote_time_differences = [] + total_local_time_differences = [] + total_mouse_x_differences = [] + total_mouse_y_differences = [] + + # Initialize additional statistics + event_stats = { + "match_count": 0, + "mismatch_count": 0, + "unmatched_browser_events": 0, + "timestamp_stats": {}, + "mouse_position_stats": {}, + } + + # Process each event pair + for event_type, action_name, key_or_button in tqdm(event_pairs): + action_filtered_events = list( + filter( + lambda e: is_action_event(e, action_name, key_or_button), action_events + ) + ) + browser_filtered_events = list( + filter( + lambda e: is_browser_event(e, action_name, key_or_button), + browser_events, + ) + ) + + if action_filtered_events or browser_filtered_events: + logger.info( + f"{event_type}: {len(action_filtered_events)} action events," + f" {len(browser_filtered_events)} browser events" + ) + + filtered_path = align_events( + event_type, action_filtered_events, browser_filtered_events + ) + + # Assign the closest browser event to each action event + for i, j in filtered_path: + action_event = action_filtered_events[i] + browser_event = browser_filtered_events[j] + action_event.browser_event_timestamp = browser_event.timestamp + action_event.browser_event_id = browser_event.id + logger.info( + f"assigning {action_event.timestamp=} ==>" + f" {browser_event.timestamp=}" + ) + + # Add the updated ActionEvent to the session + session.add(action_event) + + ( + errors, + remote_time_differences, + local_time_differences, + mouse_x_differences, + mouse_y_differences, + ) = evaluate_alignment( + filtered_path, + event_type, + action_filtered_events, + browser_filtered_events, + ) + + # Accumulate statistics + total_errors += errors + total_remote_time_differences += remote_time_differences + total_local_time_differences += local_time_differences + total_mouse_x_differences += mouse_x_differences + total_mouse_y_differences += mouse_y_differences + + event_stats["match_count"] += len(filtered_path) + event_stats["mismatch_count"] += errors + + for browser_event in browser_filtered_events: + if browser_event.message["eventType"] == "click": + identify_and_log_smallest_clicked_element(browser_event) + + # Calculate and log statistics for timestamp differences + event_stats["timestamp_stats"] = {} + for name, time_differences in ( + ("Remote", total_remote_time_differences), + ("Local", total_local_time_differences), + ): + if not time_differences: + logger.warning(f"{name=} {time_differences=}") + continue + min_diff = min(time_differences, key=abs) + max_diff = max(time_differences, key=abs) + mean_diff = mean(time_differences) + median_diff = median(time_differences) + stddev_diff = stdev(time_differences) if len(time_differences) > 1 else 0 + event_stats["timestamp_stats"][name] = { + "min": min_diff, + "max": max_diff, + "mean": mean_diff, + "median": median_diff, + "stddev": stddev_diff, + } + logger.info( + f"{name} Timestamp Differences - Min: {min_diff:.4f}, Max: {max_diff:.4f}," + f" Mean: {mean_diff:.4f}, Median: {median_diff:.4f}, Std Dev:" + f" {stddev_diff:.4f}" + ) + + # Calculate and log statistics for mouse position differences + event_stats["mouse_position_stats"] = {} + for axis, mouse_differences in ( + ("X", total_mouse_x_differences), + ("Y", total_mouse_y_differences), + ): + if not mouse_differences: + logger.warning(f"{axis=} {mouse_differences=}") + continue + min_mouse_diff = min(mouse_differences, key=abs) + max_mouse_diff = max(mouse_differences, key=abs) + num_mouse_errors = sum([abs(diff) >= 1 for diff in mouse_differences]) + num_mouse_correct = sum([abs(diff) < 1 for diff in mouse_differences]) + event_stats["mouse_position_stats"][axis] = { + "min": min_mouse_diff, + "max": max_mouse_diff, + "mean": mean(mouse_differences), + "median": median(mouse_differences), + "stddev": stdev(mouse_differences) if len(mouse_differences) > 1 else 0, + "num_errors": num_mouse_errors, + "num_correct": num_mouse_correct, + } + logger.info(f"{num_mouse_errors=} {num_mouse_correct=}") + if abs(max_mouse_diff) >= 1: + logger.warning(f"abs({max_mouse_diff=}) > 1") + logger.info( + f"Mouse {axis} Position Differences - Min: {min_mouse_diff:.4f}, Max:" + f" {max_mouse_diff:.4f}, Mean:" + f" {event_stats['mouse_position_stats'][axis]['mean']:.4f}, Median:" + f" {event_stats['mouse_position_stats'][axis]['median']:.4f}, Std Dev:" + f" {event_stats['mouse_position_stats'][axis]['stddev']:.4f}" + ) + + event_stats["unmatched_browser_events"] = len( + [ + e + for idx, e in enumerate(browser_events) + if idx not in {j for _, j in filtered_path} + ] + ) + + logger.info(f"Total Errors Across All Events: {total_errors}") + return event_stats + + +def log_stats(event_stats: dict) -> None: + """Logs statistics for event assignment. + + Args: + event_stats (dict): A dictionary containing statistics about event assignments. + """ + # Log general event statistics + logger.info(f"{event_stats['match_count']=}") + logger.info(f"{event_stats['mismatch_count']=}") + logger.info(f"{event_stats['unmatched_browser_events']=}") + + # Log timestamp differences statistics + for name, stats in event_stats["timestamp_stats"].items(): + logger.info( + f"{name} - {stats['min']=:.4f}, {stats['max']=:.4f}, " + f"{stats['mean']=:.4f}, {stats['median']=:.4f}, {stats['stddev']=:.4f}" + ) + + # Log mouse position differences statistics + for axis, stats in event_stats["mouse_position_stats"].items(): + logger.info( + f"Mouse {axis} - {stats['min']=:.4f}, {stats['max']=:.4f}, " + f"{stats['mean']=:.4f}, {stats['median']=:.4f}, {stats['stddev']=:.4f}, " + f"{stats['num_errors']=}, {stats['num_correct']=}" + ) + + +def main() -> None: + """Run alignment on the latest recording.""" + session = crud.get_new_session(read_and_write=True) + recording = crud.get_latest_recording(session) + action_events = crud.get_action_events(session=session, recording=recording) + browser_events = crud.get_browser_events(session=session, recording=recording) + + # Get statistics by assigning browser events to action events + event_stats = assign_browser_events(session, action_events, browser_events) + + # Log the statistics using the new function + log_stats(event_stats) + + +if __name__ == "__main__": + main() diff --git a/openadapt/common.py b/openadapt/common.py index 1b5b0c562..e9debeb31 100644 --- a/openadapt/common.py +++ b/openadapt/common.py @@ -10,7 +10,7 @@ "doubleclick", ) MOUSE_EVENTS = tuple(list(RAW_MOUSE_EVENTS) + list(FUSED_MOUSE_EVENTS)) -MOUSE_CLICK_EVENTS = (event for event in MOUSE_EVENTS if "click" in event) +MOUSE_CLICK_EVENTS = (event for event in MOUSE_EVENTS if event.endswith("click")) RAW_KEY_EVENTS = ( "press", diff --git a/openadapt/config.defaults.json b/openadapt/config.defaults.json index a74f15eb9..ef1d15608 100644 --- a/openadapt/config.defaults.json +++ b/openadapt/config.defaults.json @@ -18,6 +18,7 @@ "REPLAY_STRIP_ELEMENT_STATE": true, "RECORD_VIDEO": true, "RECORD_AUDIO": true, + "RECORD_BROWSER_EVENTS": false, "RECORD_FULL_VIDEO": false, "RECORD_IMAGES": false, "LOG_MEMORY": false, @@ -83,6 +84,8 @@ "SPACY_MODEL_NAME": "en_core_web_trf", "DASHBOARD_CLIENT_PORT": 5173, "DASHBOARD_SERVER_PORT": 8080, + "BROWSER_WEBSOCKET_PORT": 8765, + "BROWSER_WEBSOCKET_SERVER_IP": "localhost", "UNIQUE_USER_ID": "", "REDIRECT_TO_ONBOARDING": true } diff --git a/openadapt/config.py b/openadapt/config.py index d0a48301a..432c44e5f 100644 --- a/openadapt/config.py +++ b/openadapt/config.py @@ -28,6 +28,7 @@ PERFORMANCE_PLOTS_DIR_PATH = (DATA_DIR_PATH / "performance").absolute() CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute() VIDEO_DIR_PATH = DATA_DIR_PATH / "videos" +DATABASE_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute() DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock" STOP_STRS = [ @@ -120,7 +121,7 @@ class SegmentationAdapter(str, Enum): # Database DB_ECHO: bool = False - DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}" + DB_URL: ClassVar[str] = f"sqlite:///{DATABASE_FILE_PATH}" # Error reporting ERROR_REPORTING_ENABLED: bool = True @@ -133,10 +134,12 @@ class SegmentationAdapter(str, Enum): OPENAI_MODEL_NAME: str = "gpt-3.5-turbo" # Record and replay + EVENT_BUFFER_QUEUE_SIZE: int = 100 RECORD_WINDOW_DATA: bool = True - RECORD_READ_ACTIVE_ELEMENT_STATE: bool = False + RECORD_READ_ACTIVE_ELEMENT_STATE: bool RECORD_VIDEO: bool RECORD_AUDIO: bool + RECORD_BROWSER_EVENTS: bool # if false, only write video events corresponding to screenshots RECORD_FULL_VIDEO: bool RECORD_IMAGES: bool @@ -151,6 +154,11 @@ class SegmentationAdapter(str, Enum): list(stop_str) for stop_str in STOP_STRS ] + SPECIAL_CHAR_STOP_SEQUENCES + # Browser Events Record (extension) configurations + BROWSER_WEBSOCKET_SERVER_IP: str = "localhost" + BROWSER_WEBSOCKET_PORT: int = 8765 + BROWSER_WEBSOCKET_MAX_SIZE: int = 2**22 # 4MB + # Warning suppression IGNORE_WARNINGS: bool = False MAX_NUM_WARNINGS_PER_SECOND: int = 5 @@ -165,6 +173,9 @@ class SegmentationAdapter(str, Enum): # Performance plotting PLOT_PERFORMANCE: bool = True + # Database File Path + DATABASE_FILE_PATH: str = str(DATABASE_FILE_PATH) + # App configurations APP_DARK_MODE: bool = False @@ -282,6 +293,7 @@ def __setattr__(self, key: str, value: Any) -> None: "RECORD_READ_ACTIVE_ELEMENT_STATE", "RECORD_VIDEO", "RECORD_IMAGES", + "RECORD_BROWSER_EVENTS", "VIDEO_PIXEL_FORMAT", ], "general": [ @@ -395,7 +407,7 @@ def maybe_obfuscate(key: str, val: Any) -> Any: OBFUSCATE_KEY_PARTS = ("KEY", "PASSWORD", "TOKEN") parts = key.split("_") if any([part in parts for part in OBFUSCATE_KEY_PARTS]): - val = obfuscate(val) + val = obfuscate(str(val)) return val diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 806ff5ead..b0a9ff12d 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -20,6 +20,7 @@ from openadapt.db.db import Session, get_read_only_session_maker from openadapt.models import ( ActionEvent, + BrowserEvent, AudioInfo, MemoryStat, PerformanceStat, @@ -38,6 +39,7 @@ action_events = [] screenshots = [] window_events = [] +browser_events = [] performance_stats = [] memory_stats = [] @@ -153,6 +155,29 @@ def insert_window_event( _insert(session, event_data, WindowEvent, window_events) +def insert_browser_event( + session: SaSession, + recording: Recording, + event_timestamp: int, + event_data: dict[str, Any], +) -> None: + """Insert a browser event into the database. + + Args: + session (sa.orm.Session): The database session. + recording (Recording): The recording object. + event_timestamp (int): The timestamp of the event. + event_data (dict): The data of the event. + """ + event_data = { + **event_data, + "timestamp": event_timestamp, + "recording_id": recording.id, + "recording_timestamp": recording.timestamp, + } + _insert(session, event_data, BrowserEvent, browser_events) + + def insert_perf_stat( session: SaSession, recording: Recording, @@ -590,6 +615,27 @@ def get_window_events( ) +def get_browser_events(session: SaSession, recording: Recording) -> list[BrowserEvent]: + """Get browser events for a given recording. + + Args: + session (sa.orm.Session): The database session + recording (Recording): recording object + Returns: + List[BrowserEvent]: list of browser events + """ + return ( + session.query(BrowserEvent) + .filter(BrowserEvent.recording_id == recording.id) + .options( + joinedload(BrowserEvent.recording), + subqueryload(BrowserEvent.action_events).joinedload(ActionEvent.screenshot), + ) + .order_by(BrowserEvent.timestamp) + .all() + ) + + def disable_action_event(session: SaSession, event_id: int) -> None: """Disable an action event. @@ -611,12 +657,15 @@ def disable_action_event(session: SaSession, event_id: int) -> None: def get_new_session( read_only: bool = False, read_and_write: bool = False, + allow_add_on_read_only: bool = True, ) -> sa.orm.Session: """Get a new database session. Args: read_only (bool): Whether to open the session in read-only mode. read_and_write (bool): Whether to open the session in read-and-write mode. + allow_add_on_read_only (bool): Whether to allow session.add on read_only + (write to memory, but not to disk). Returns: sa.orm.Session: A new database session. @@ -632,7 +681,8 @@ def raise_error_on_write(*args: Any, **kwargs: Any) -> None: """Raise an error when trying to write to a read-only session.""" raise PermissionError("This session is read-only.") - session.add = raise_error_on_write + if not allow_add_on_read_only: + session.add = raise_error_on_write session.delete = raise_error_on_write session.commit = raise_error_on_write session.flush = raise_error_on_write @@ -734,6 +784,7 @@ def post_process_events(session: SaSession, recording: Recording) -> None: screenshots = _get(session, Screenshot, recording.id) action_events = _get(session, ActionEvent, recording.id) window_events = _get(session, WindowEvent, recording.id) + browser_events = _get(session, BrowserEvent, recording.id) screenshot_timestamp_to_id_map = { screenshot.timestamp: screenshot.id for screenshot in screenshots @@ -741,6 +792,9 @@ def post_process_events(session: SaSession, recording: Recording) -> None: window_event_timestamp_to_id_map = { window_event.timestamp: window_event.id for window_event in window_events } + browser_event_timestamp_to_id_map = { + browser_event.timestamp: browser_event.id for browser_event in browser_events + } for action_event in action_events: action_event.screenshot_id = screenshot_timestamp_to_id_map.get( @@ -749,6 +803,9 @@ def post_process_events(session: SaSession, recording: Recording) -> None: action_event.window_event_id = window_event_timestamp_to_id_map.get( action_event.window_event_timestamp ) + action_event.browser_event_id = browser_event_timestamp_to_id_map.get( + action_event.browser_event_timestamp + ) session.commit() @@ -789,6 +846,7 @@ def copy_action_event( screenshots = [action_event.screenshot for action_event in action_events] window_events = [action_event.window_event for action_event in action_events] + browser_events = [action_event.browser_event for action_event in action_events] for i, action_event in enumerate(new_action_events): action_event.screenshot = copy_sa_instance( @@ -797,6 +855,10 @@ def copy_action_event( action_event.window_event = copy_sa_instance( window_events[i], recording_id=new_recording.id ) + action_event.browser_event = copy_sa_instance( + browser_events[i], recording_id=new_recording.id + ) + session.add(action_event) session.commit() diff --git a/openadapt/db/db.py b/openadapt/db/db.py index e9191f7f0..dfdfb43ac 100644 --- a/openadapt/db/db.py +++ b/openadapt/db/db.py @@ -37,10 +37,12 @@ def __repr__(self) -> str: # avoid circular import from openadapt.utils import EMPTY, row2dict + ignore_attrs = getattr(self, "_repr_ignore_attrs", []) + params = ", ".join( f"{k}={v!r}" # !r converts value to string using repr (adds quotes) for k, v in row2dict(self, follow=False).items() - if v not in EMPTY + if v not in EMPTY and k not in ignore_attrs ) return f"{self.__class__.__name__}({params})" diff --git a/openadapt/events.py b/openadapt/events.py index 6384246c3..5866a3656 100644 --- a/openadapt/events.py +++ b/openadapt/events.py @@ -7,7 +7,7 @@ from scipy.spatial import distance import numpy as np -from openadapt import common, models, utils +from openadapt import browser, common, models, utils from openadapt.custom_logger import logger from openadapt.db import crud @@ -45,8 +45,12 @@ def get_events( start_time = time.time() action_events = crud.get_action_events(db, recording) window_events = crud.get_window_events(db, recording) + browser_events = crud.get_browser_events(db, recording) screenshots = crud.get_screenshots(db, recording) + browser_stats = browser.assign_browser_events(db, action_events, browser_events) + browser.log_stats(browser_stats) + if recording.original_recording_id: # if recording is a copy, it already has its events processed when it # was created, return only the top level events @@ -62,10 +66,12 @@ def get_events( assert num_action_events > 0, "No action events found." num_window_events = len(window_events) num_screenshots = len(screenshots) + num_browser_events = len(browser_events) num_action_events_raw = num_action_events num_window_events_raw = num_window_events num_screenshots_raw = num_screenshots + num_browser_events_raw = num_browser_events duration_raw = action_events[-1].timestamp - action_events[0].timestamp num_process_iters = 0 @@ -76,33 +82,38 @@ def get_events( f"{num_action_events=} " f"{num_window_events=} " f"{num_screenshots=}" + f"{num_browser_events=}" ) ( action_events, window_events, screenshots, - ) = process_events( + browser_events, + ) = merge_events( action_events, window_events, screenshots, + browser_events, ) if ( len(action_events) == num_action_events and len(window_events) == num_window_events and len(screenshots) == num_screenshots + and len(browser_events) == num_browser_events ): break num_process_iters += 1 num_action_events = len(action_events) num_window_events = len(window_events) num_screenshots = len(screenshots) + num_browser_events = len(browser_events) if num_process_iters == MAX_PROCESS_ITERS: break if meta is not None: - format_num = ( - lambda num, raw_num: f"{num} of {raw_num} ({(num / raw_num):.2%})" - ) # noqa: E731 + format_num = lambda num, raw_num: ( # noqa: E731 + f"{num} of {raw_num} ({(num / raw_num):.2%})" if raw_num else "0" + ) meta["num_process_iters"] = num_process_iters meta["num_action_events"] = format_num( num_action_events, @@ -116,6 +127,10 @@ def get_events( num_screenshots, num_screenshots_raw, ) + meta["num_browser_events"] = format_num( + num_browser_events, + num_browser_events_raw, + ) duration = action_events[-1].timestamp - action_events[0].timestamp if len(action_events) > 1: @@ -129,7 +144,7 @@ def get_events( event="get_events.completed", properties={"recording_id": recording.id} ) - return action_events # , window_events, screenshots + return action_events # , window_events, screenshots, browser_events def make_parent_event( @@ -152,13 +167,23 @@ def make_parent_event( "recording_timestamp": child.recording_timestamp, "window_event_timestamp": child.window_event_timestamp, "screenshot_timestamp": child.screenshot_timestamp, + "browser_event_timestamp": child.browser_event_timestamp, "recording": child.recording, "window_event": child.window_event, "screenshot": child.screenshot, + "browser_event": child.browser_event, } extra = extra or {} for key, val in extra.items(): event_dict[key] = val + + children = extra.get("children", []) + browser_events = [child.browser_event for child in children if child.browser_event] + if browser_events: + assert len(browser_events) <= 1, len(browser_events) + browser_event = browser_events[0] + event_dict["browser_event"] = browser_event + action_event = models.ActionEvent(**event_dict) return action_event @@ -793,16 +818,17 @@ def discard_unused_events( return referred_events -def process_events( +def merge_events( action_events: list[models.ActionEvent], window_events: list[models.WindowEvent], screenshots: list[models.Screenshot], + browser_events: list[models.BrowserEvent], ) -> tuple[ list[models.ActionEvent], list[models.WindowEvent], list[models.Screenshot], ]: - """Process action events, window events, and screenshots. + """Merge redundant action events, window events, and screenshots. Args: action_events (list): The list of action events. @@ -813,17 +839,17 @@ def process_events( tuple: A tuple containing the processed action events, window events, and screenshots. """ - # For debugging - # _action_events = action_events - # _window_events = window_events - # _screenshots = screenshots - num_action_events = len(action_events) num_window_events = len(window_events) num_screenshots = len(screenshots) - num_total = num_action_events + num_window_events + num_screenshots + num_browser_events = len(browser_events) + num_total = ( + num_action_events + num_window_events + num_screenshots + num_browser_events + ) logger.info( - f"before {num_action_events=} {num_window_events=} {num_screenshots=} " + "before" + f" {num_action_events=} {num_window_events=}" + f" {num_screenshots=} {num_browser_events=} " f"{num_total=}" ) process_fns = [ @@ -862,19 +888,33 @@ def process_events( action_events, "screenshot_timestamp", ) + browser_events = discard_unused_events( + browser_events, + action_events, + "browser_event_timestamp", + ) num_action_events_ = len(action_events) num_window_events_ = len(window_events) num_screenshots_ = len(screenshots) - num_total_ = num_action_events_ + num_window_events_ + num_screenshots_ + num_browser_events_ = len(browser_events) + num_total_ = ( + num_action_events_ + num_window_events_ + num_screenshots_ + num_browser_events_ + ) pct_action_events = num_action_events_ / num_action_events pct_window_events = num_window_events_ / num_window_events pct_screenshots = num_screenshots_ / num_screenshots + pct_browser_events = ( + num_browser_events_ / num_browser_events if num_browser_events else None + ) pct_total = num_total_ / num_total logger.info( - f"after {num_action_events_=} {num_window_events_=} {num_screenshots_=} " - f"{num_total_=}" + "after" + f" {num_action_events_=} {num_window_events_=}" + f" {num_screenshots_=} {num_browser_events_=}" + f" {num_total_=}" ) logger.info( - f"{pct_action_events=} {pct_window_events=} {pct_screenshots=} {pct_total=}" + f"{pct_action_events=} {pct_window_events=} {pct_screenshots=}" + f" {pct_browser_events=} {pct_total=}" ) - return action_events, window_events, screenshots + return action_events, window_events, screenshots, browser_events diff --git a/openadapt/extensions/synchronized_queue.py b/openadapt/extensions/synchronized_queue.py index 238db8b42..9614cdf27 100644 --- a/openadapt/extensions/synchronized_queue.py +++ b/openadapt/extensions/synchronized_queue.py @@ -8,7 +8,7 @@ # Credit: https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9 # The following implementation of custom SynchronizedQueue to avoid NotImplementedError -# when calling queue.qsize() in MacOS X comes almost entirely from this github +# when calling queue.qsize() in MacOS comes almost entirely from this github # discussion: https://github.com/keras-team/autokeras/issues/368 # Necessary modification is made to make the code compatible with Python3. @@ -50,7 +50,7 @@ class SynchronizedQueue(Queue): """A portable implementation of multiprocessing.Queue. Because of multithreading / multiprocessing semantics, Queue.qsize() may - raise the NotImplementedError exception on Unix platforms like Mac OS X + raise the NotImplementedError exception on Unix platforms like Mac OS where sem_getvalue() is not implemented. This subclass addresses this problem by using a synchronized shared counter (initialized to zero) and increasing / decreasing its value every time the put() and get() methods diff --git a/openadapt/models.py b/openadapt/models.py index 8f0e31e26..1df82c45e 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -4,6 +4,7 @@ from copy import deepcopy from itertools import zip_longest from typing import Any, Type +import copy import io import sys @@ -43,6 +44,7 @@ class Recording(db.Base): """Class representing a recording in the database.""" __tablename__ = "recording" + _repr_ignore_attrs = ["config"] id = sa.Column(sa.Integer, primary_key=True) timestamp = sa.Column(ForceFloat) @@ -83,6 +85,12 @@ class Recording(db.Base): order_by="WindowEvent.timestamp", cascade="all, delete-orphan", ) + browser_events = sa.orm.relationship( + "BrowserEvent", + back_populates="recording", + order_by="BrowserEvent.timestamp", + cascade="all, delete-orphan", + ) scrubbed_recordings = sa.orm.relationship( "ScrubbedRecording", back_populates="recording", cascade="all, delete-orphan" ) @@ -128,6 +136,8 @@ class ActionEvent(db.Base): screenshot_id = sa.Column(sa.ForeignKey("screenshot.id")) window_event_timestamp = sa.Column(ForceFloat) window_event_id = sa.Column(sa.ForeignKey("window_event.id")) + browser_event_timestamp = sa.Column(ForceFloat) + browser_event_id = sa.Column(sa.ForeignKey("browser_event.id")) mouse_x = sa.Column(sa.Numeric(asdecimal=False)) mouse_y = sa.Column(sa.Numeric(asdecimal=False)) mouse_dx = sa.Column(sa.Numeric(asdecimal=False)) @@ -213,6 +223,7 @@ def available_segment_descriptions(self, value: list[str]) -> None: recording = sa.orm.relationship("Recording", back_populates="action_events") screenshot = sa.orm.relationship("Screenshot", back_populates="action_event") window_event = sa.orm.relationship("WindowEvent", back_populates="action_events") + browser_event = sa.orm.relationship("BrowserEvent", back_populates="action_events") # TODO: playback_timestamp / original_timestamp @@ -256,11 +267,14 @@ def canonical_key(self) -> keyboard.Key | keyboard.KeyCode | str | None: self.canonical_key_vk, ) - def _text(self, canonical: bool = False) -> str | None: + def _text( + self, + canonical: bool = False, + name_prefix: str = config.ACTION_TEXT_NAME_PREFIX, + name_suffix: str = config.ACTION_TEXT_NAME_SUFFIX, + ) -> str | None: """Helper method to generate the text representation of the action event.""" sep = config.ACTION_TEXT_SEP - name_prefix = config.ACTION_TEXT_NAME_PREFIX - name_suffix = config.ACTION_TEXT_NAME_SUFFIX if self.children: parts = [ child.canonical_text if canonical else child.text @@ -616,6 +630,51 @@ def to_prompt_dict( return window_dict +class BrowserEvent(db.Base): + """Class representing a browser event in the database.""" + + __tablename__ = "browser_event" + + id = sa.Column(sa.Integer, primary_key=True) + recording_timestamp = sa.Column(ForceFloat) + recording_id = sa.Column(sa.ForeignKey("recording.id")) + message = sa.Column(sa.JSON) + timestamp = sa.Column(ForceFloat) + + recording = sa.orm.relationship("Recording", back_populates="browser_events") + action_events = sa.orm.relationship("ActionEvent", back_populates="browser_event") + + def __str__(self) -> str: + """Returns a truncated string representation without modifying original data.""" + # Create a copy of the message to avoid modifying the original + message_copy = copy.deepcopy(self.message) + + # Truncate the visibleHtmlString in the copied message if it exists + if "visibleHtmlString" in message_copy: + message_copy["visibleHtmlString"] = utils.truncate_html( + message_copy["visibleHtmlString"], max_len=100 + ) + + # Get all attributes except 'message' + attributes = { + attr: getattr(self, attr) + for attr in self.__mapper__.columns.keys() + if attr != "message" + } + + # Construct the string representation dynamically + base_repr = ", ".join(f"{key}={value}" for key, value in attributes.items()) + + # Return the complete representation including the truncated message + return f"BrowserEvent({base_repr}, message={message_copy})" + + # # TODO: implement + # @classmethod + # def get_active_browser_event( + # cls: "BrowserEvent", + # ) -> "BrowserEvent": + + class FrameCache: """Provide a caching mechanism for video frames to minimize IO operations. diff --git a/openadapt/record.py b/openadapt/record.py index e9a6061cb..27eb9e578 100644 --- a/openadapt/record.py +++ b/openadapt/record.py @@ -10,6 +10,7 @@ from functools import partial from typing import Any, Callable import io +import json import multiprocessing import os import queue @@ -35,6 +36,7 @@ import psutil import sounddevice import soundfile +import websockets.sync.server import whisper from openadapt import plotting, utils, video, window @@ -45,7 +47,7 @@ Event = namedtuple("Event", ("timestamp", "type", "data")) -EVENT_TYPES = ("screen", "action", "window") +EVENT_TYPES = ("screen", "action", "window", "browser") LOG_LEVEL = "INFO" # whether to write events of each type in a separate process PROC_WRITE_BY_EVENT_TYPE = { @@ -53,12 +55,14 @@ "screen/video": True, "action": True, "window": True, + "browser": True, } PLOT_PERFORMANCE = config.PLOT_PERFORMANCE NUM_MEMORY_STATS_TO_LOG = 3 STOP_SEQUENCES = config.STOP_SEQUENCES stop_sequence_detected = False +ws_server_instance = None def collect_stats(performance_snapshots: list[tracemalloc.Snapshot]) -> None: @@ -128,6 +132,7 @@ def process_events( screen_write_q: sq.SynchronizedQueue, action_write_q: sq.SynchronizedQueue, window_write_q: sq.SynchronizedQueue, + browser_write_q: sq.SynchronizedQueue, video_write_q: sq.SynchronizedQueue, perf_q: sq.SynchronizedQueue, recording: Recording, @@ -136,6 +141,7 @@ def process_events( num_screen_events: multiprocessing.Value, num_action_events: multiprocessing.Value, num_window_events: multiprocessing.Value, + num_browser_events: multiprocessing.Value, num_video_events: multiprocessing.Value, ) -> None: """Process events from the event queue and write them to write queues. @@ -145,6 +151,7 @@ def process_events( screen_write_q: A queue for writing screen events. action_write_q: A queue for writing action events. window_write_q: A queue for writing window events. + browser_write_q: A queue for writing browser events, video_write_q: A queue for writing video events. perf_q: A queue for collecting performance data. recording: The recording object. @@ -153,6 +160,7 @@ def process_events( num_screen_events: A counter for the number of screen events. num_action_events: A counter for the number of action events. num_window_events: A counter for the number of window events. + num_browser_events: A counter for the number of browser events. num_video_events: A counter for the number of video events. """ utils.set_start_time(recording.timestamp) @@ -179,8 +187,11 @@ def process_events( event, prev_event, ) - except AssertionError as exc: - logger.error(exc) + except AssertionError: + delta = event.timestamp - prev_event.timestamp + log_prev_event = prev_event._replace(data="") + log_event = event._replace(data="") + logger.error(f"{delta=} {log_prev_event=} {log_event=}") # behavior undefined, swallow for now # XXX TODO: mitigate if event.type == "screen": @@ -197,15 +208,28 @@ def process_events( num_video_events.value += 1 elif event.type == "window": prev_window_event = event + elif event.type == "browser": + if config.RECORD_BROWSER_EVENTS: + process_event( + event, + browser_write_q, + write_browser_event, + recording, + perf_q, + ) elif event.type == "action": if prev_screen_event is None: logger.warning("Discarding action that came before screen") continue + else: + event.data["screenshot_timestamp"] = prev_screen_event.timestamp + if prev_window_event is None: - logger.warning("Discarding input that came before window") + logger.warning("Discarding action that came before window") continue - event.data["screenshot_timestamp"] = prev_screen_event.timestamp - event.data["window_event_timestamp"] = prev_window_event.timestamp + else: + event.data["window_event_timestamp"] = prev_window_event.timestamp + process_event( event, action_write_q, @@ -213,7 +237,9 @@ def process_events( recording, perf_q, ) + num_action_events.value += 1 + if prev_saved_screen_timestamp < prev_screen_event.timestamp: process_event( prev_screen_event, @@ -316,6 +342,25 @@ def write_window_event( perf_q.put((event.type, event.timestamp, utils.get_timestamp())) +def write_browser_event( + db: crud.SaSession, + recording: Recording, + event: Event, + perf_q: sq.SynchronizedQueue, +) -> None: + """Write a browser event to the database and update the performance queue. + + Args: + db: The database session. + recording: The recording object. + event: A browser event to be written. + perf_q: A queue for collecting performance data. + """ + assert event.type == "browser", event + crud.insert_browser_event(db, recording, event.timestamp, event.data) + perf_q.put((event.type, event.timestamp, utils.get_timestamp())) + + @utils.trace(logger) def write_events( event_type: str, @@ -1126,6 +1171,103 @@ def audio_callback( ) +@logger.catch +@utils.trace(logger) +def read_browser_events( + websocket: websockets.sync.server.ServerConnection, + event_q: queue.Queue, + terminate_processing: Event, + recording: Recording, +) -> None: + """Read browser events and add them to the event queue. + + Params: + websocket: The websocket object. + event_q: A queue for adding browser events. + terminate_processing: An event to signal the termination of the process. + recording: The recording object. + + Returns: + None + """ + utils.set_start_time(recording.timestamp) + + logger.info("Starting Reading Browser Events ...") + + while not terminate_processing.is_set(): + for message in websocket: + if not message: + continue + + timestamp = utils.get_timestamp() + + data = json.loads(message) + + event_q.put( + Event( + timestamp, + "browser", + {"message": data}, + ) + ) + + +@logger.catch +@utils.trace(logger) +def run_browser_event_server( + event_q: queue.Queue, + terminate_processing: Event, + recording: Recording, + started_counter: multiprocessing.Value, +) -> None: + """Run the browser event server. + + Params: + event_q: A queue for adding browser events. + terminate_processing: An event to signal the termination of the process. + recording: The recording object. + started_counter: Value to increment once started. + + Returns: + None + """ + global ws_server_instance + + # Function to run the server in a separate thread + def run_server() -> None: + global ws_server_instance + with websockets.sync.server.serve( + lambda ws: read_browser_events( + ws, + event_q, + terminate_processing, + recording, + ), + config.BROWSER_WEBSOCKET_SERVER_IP, + config.BROWSER_WEBSOCKET_PORT, + max_size=config.BROWSER_WEBSOCKET_MAX_SIZE, + ) as server: + ws_server_instance = server + logger.info("WebSocket server started") + with started_counter.get_lock(): + started_counter.value += 1 + server.serve_forever() + + # Start the server in a separate thread + server_thread = threading.Thread(target=run_server) + server_thread.start() + + # Wait for a termination signal + terminate_processing.wait() + logger.info("Termination signal received, shutting down server") + + if ws_server_instance: + ws_server_instance.shutdown() + + # Ensure the server thread is terminated cleanly + server_thread.join() + + @logger.catch @utils.trace(logger) def record( @@ -1139,7 +1281,7 @@ def record( status_pipe: multiprocessing.connection.Connection | None = None, log_memory: bool = config.LOG_MEMORY, ) -> None: - """Record Screenshots/ActionEvents/WindowEvents. + """Record Screenshots/ActionEvents/WindowEvents/BrowserEvents. Args: task_description: A text description of the task to be recorded. @@ -1175,41 +1317,55 @@ def record( screen_write_q = sq.SynchronizedQueue() action_write_q = sq.SynchronizedQueue() window_write_q = sq.SynchronizedQueue() + browser_write_q = sq.SynchronizedQueue() video_write_q = sq.SynchronizedQueue() # TODO: save write times to DB; display performance plot in visualize.py perf_q = sq.SynchronizedQueue() if terminate_processing is None: terminate_processing = multiprocessing.Event() started_counter = multiprocessing.Value("i", 0) - expected_starts = 9 + task_by_name = {} window_event_reader = threading.Thread( target=read_window_events, args=(event_q, terminate_processing, recording, started_counter), ) window_event_reader.start() + task_by_name["window_event_reader"] = window_event_reader + + if config.RECORD_BROWSER_EVENTS: + browser_event_reader = threading.Thread( + target=run_browser_event_server, + args=(event_q, terminate_processing, recording, started_counter), + ) + browser_event_reader.start() + task_by_name["browser_event_reader"] = browser_event_reader screen_event_reader = threading.Thread( target=read_screen_events, args=(event_q, terminate_processing, recording, started_counter), ) screen_event_reader.start() + task_by_name["screen_event_reader"] = screen_event_reader keyboard_event_reader = threading.Thread( target=read_keyboard_events, args=(event_q, terminate_processing, recording, started_counter), ) keyboard_event_reader.start() + task_by_name["keyboard_event_reader"] = keyboard_event_reader mouse_event_reader = threading.Thread( target=read_mouse_events, args=(event_q, terminate_processing, recording, started_counter), ) mouse_event_reader.start() + task_by_name["mouse_event_reader"] = mouse_event_reader num_action_events = multiprocessing.Value("i", 0) num_screen_events = multiprocessing.Value("i", 0) num_window_events = multiprocessing.Value("i", 0) + num_browser_events = multiprocessing.Value("i", 0) num_video_events = multiprocessing.Value("i", 0) event_processor = threading.Thread( @@ -1219,6 +1375,7 @@ def record( screen_write_q, action_write_q, window_write_q, + browser_write_q, video_write_q, perf_q, recording, @@ -1227,10 +1384,12 @@ def record( num_screen_events, num_action_events, num_window_events, + num_browser_events, num_video_events, ), ) event_processor.start() + task_by_name["event_processor"] = event_processor screen_event_writer = multiprocessing.Process( target=utils.WrapStdout(write_events), @@ -1246,6 +1405,24 @@ def record( ), ) screen_event_writer.start() + task_by_name["screen_event_writer"] = screen_event_writer + + if config.RECORD_BROWSER_EVENTS: + browser_event_writer = multiprocessing.Process( + target=write_events, + args=( + "browser", + write_browser_event, + browser_write_q, + num_browser_events, + perf_q, + recording, + terminate_processing, + started_counter, + ), + ) + browser_event_writer.start() + task_by_name["browser_event_writer"] = browser_event_writer action_event_writer = multiprocessing.Process( target=utils.WrapStdout(write_events), @@ -1261,6 +1438,7 @@ def record( ), ) action_event_writer.start() + task_by_name["action_event_writer"] = action_event_writer window_event_writer = multiprocessing.Process( target=utils.WrapStdout(write_events), @@ -1276,9 +1454,9 @@ def record( ), ) window_event_writer.start() + task_by_name["window_event_writer"] = window_event_writer if config.RECORD_VIDEO: - expected_starts += 1 video_writer = multiprocessing.Process( target=utils.WrapStdout(write_events), args=( @@ -1295,9 +1473,9 @@ def record( ), ) video_writer.start() + task_by_name["video_writer"] = video_writer if config.RECORD_AUDIO: - expected_starts += 1 audio_recorder = multiprocessing.Process( target=utils.WrapStdout(record_audio), args=( @@ -1307,9 +1485,10 @@ def record( ), ) audio_recorder.start() + task_by_name["audio_recorder"] = audio_recorder terminate_perf_event = multiprocessing.Event() - perf_stat_writer = multiprocessing.Process( + perf_stats_writer = multiprocessing.Process( target=utils.WrapStdout(performance_stats_writer), args=( perf_q, @@ -1318,12 +1497,12 @@ def record( started_counter, ), ) - perf_stat_writer.start() + perf_stats_writer.start() + task_by_name["perf_stats_writer"] = perf_stats_writer if PLOT_PERFORMANCE: - expected_starts += 1 record_pid = os.getpid() - mem_plotter = multiprocessing.Process( + mem_writer = multiprocessing.Process( target=utils.WrapStdout(memory_writer), args=( recording, @@ -1332,7 +1511,8 @@ def record( started_counter, ), ) - mem_plotter.start() + mem_writer.start() + task_by_name["mem_writer"] = mem_writer if log_memory: performance_snapshots = [] @@ -1343,22 +1523,23 @@ def record( # TODO: discard events until everything is ready # Wait for all to signal they've started + expected_starts = len(task_by_name) + logger.info(f"{expected_starts=}") while True: if started_counter.value >= expected_starts: break time.sleep(0.1) # Sleep to reduce busy waiting for _ in range(5): logger.info("*" * 40) + logger.info("All readers and writers have started. Waiting for input events...") + if status_pipe: status_pipe.send({"type": "record.started"}) - logger.info("All readers and writers have started. Waiting for input events...") global stop_sequence_detected - try: while not (stop_sequence_detected or terminate_processing.is_set()): time.sleep(1) - terminate_processing.set() except KeyboardInterrupt: terminate_processing.set() @@ -1370,23 +1551,39 @@ def record( collect_stats(performance_snapshots) log_memory_usage(_tracker, performance_snapshots) - logger.info("joining...") - keyboard_event_reader.join() - mouse_event_reader.join() - screen_event_reader.join() - window_event_reader.join() - event_processor.join() - screen_event_writer.join() - action_event_writer.join() - window_event_writer.join() - if config.RECORD_VIDEO: - video_writer.join() - if config.RECORD_AUDIO: - audio_recorder.join() + def join_tasks(task_names: list[str]) -> None: + for task_name in task_names: + if task_name in task_by_name: + logger.info(f"joining {task_name=}...") + task = task_by_name[task_name] + task.join() + + join_tasks( + [ + "window_event_reader", + "browser_event_reader", + "screen_event_reader", + "keyboard_event_reader", + "mouse_event_reader", + "event_processor", + "screen_event_writer", + "browser_event_writer", + "action_event_writer", + "window_event_writer", + "video_writer", + "audio_recorder", + ] + ) + terminate_perf_event.set() + join_tasks( + [ + "perf_stats_writer", + "mem_writer", + ] + ) if PLOT_PERFORMANCE: - mem_plotter.join() plotting.plot_performance(recording) logger.info(f"Saved {recording_timestamp=}") diff --git a/openadapt/scripts/reset_db.py b/openadapt/scripts/reset_db.py index 8bba91be2..a000fac27 100644 --- a/openadapt/scripts/reset_db.py +++ b/openadapt/scripts/reset_db.py @@ -14,8 +14,8 @@ def reset_db() -> None: """Clears the database by removing the db file and running a db migration.""" - if os.path.exists(config.DB_FPATH): - os.remove(config.DB_FPATH) + if os.path.exists(config.DATABASE_FILE_PATH): + os.remove(config.DATABASE_FILE_PATH) # Prevents duplicate logging of config values by piping stderr # and filtering the output. diff --git a/openadapt/utils.py b/openadapt/utils.py index dc2795f15..82279c0d4 100644 --- a/openadapt/utils.py +++ b/openadapt/utils.py @@ -969,6 +969,29 @@ def wrapper_retry(*args: tuple, **kwargs: dict[str, Any]) -> Any: return decorator_retry +def truncate_html(html_str: str, max_len: int) -> str: + """Truncates the given HTML string to a specified maximum length. + + Retains the head and tail while indicating the truncated portion in the middle. + + Args: + html_str (str): The HTML string to truncate. + max_len (int): The maximum length for the truncated HTML string. + + Returns: + str: The truncated HTML string with the head and tail retained, and + an indication of the truncated portion in the middle if applicable. + """ + if len(html_str) > max_len: + n = max_len // 2 + head = html_str[:n] + tail = html_str[-n:] + snipped = html_str[n:-n] + middle = f"
...(snipped {len(snipped):,})...
" + html_str = head + middle + tail + return html_str + + class WrapStdout: """Class to be used a target for multiprocessing.Process.""" diff --git a/openadapt/visualize.py b/openadapt/visualize.py index d5e0a4842..240362b14 100644 --- a/openadapt/visualize.py +++ b/openadapt/visualize.py @@ -32,6 +32,7 @@ image2utf8, row2dict, rows2dicts, + truncate_html, ) SCRUB = config.SCRUB_ENABLED @@ -142,13 +143,7 @@ def dict2html( html_str = f"{rows_html}
" else: html_str = html.escape(str(obj)) - if len(html_str) > max_len: - n = max_len // 2 - head = html_str[:n] - tail = html_str[-n:] - snipped = html_str[n:-n] - middle = f"
...(snipped {len(snipped):,})...
" - html_str = head + middle + tail + html_str = truncate_html(html_str, max_len) return html_str @@ -346,10 +341,12 @@ def main( action_event_dict = row2dict(action_event) window_event_dict = row2dict(action_event.window_event) + browser_event_dict = row2dict(action_event.browser_event) if SCRUB: action_event_dict = scrub.scrub_dict(action_event_dict) window_event_dict = scrub.scrub_dict(window_event_dict) + browser_event_dict = scrub.scrub_dict(browser_event_dict) rows.append( [ @@ -379,6 +376,9 @@ def main( {dict2html(window_event_dict , None)}
+ + {dict2html(browser_event_dict , None)} +
""", ), Div(text=f""" @@ -394,6 +394,35 @@ def main( progress.close() + # Visualize BrowserEvents + rows.append([row(Div(text="

Browser Events

"))]) + browser_events = crud.get_browser_events(session, recording) + with redirect_stdout_stderr(): + with tqdm( + total=len(browser_events), + desc="Preparing HTML (browser events)", + unit="event", + colour="green", + dynamic_ncols=True, + ) as progress: + for idx, browser_event in enumerate(browser_events): + browser_event_dict = row2dict(browser_event) + rows.append( + [ + row( + Div(text=f""" + + {dict2html(browser_event_dict)} +
+ """), + ), + ] + ) + + progress.update() + + progress.close() + title = f"recording-{recording.id}" fname_out = RECORDING_DIR_PATH / f"recording-{recording.id}.html" diff --git a/poetry.lock b/poetry.lock index ac5e3e77f..e83b8d488 100644 --- a/poetry.lock +++ b/poetry.lock @@ -427,6 +427,27 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "beautifulsoup4" +version = "4.12.3" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, + {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, +] + +[package.dependencies] +soupsieve = ">1.2" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "bidict" version = "0.23.1" @@ -1359,6 +1380,43 @@ files = [ {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, ] +[[package]] +name = "dtaidistance" +version = "2.3.12" +description = "Distance measures for time series (Dynamic Time Warping, fast C implementation)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "dtaidistance-2.3.12-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c11618383363d9625f2ae08a40658589023c088d558ec9d25f103d077c53f1a6"}, + {file = "dtaidistance-2.3.12-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:d61cdc5656be065ddbc2bab502ac2125a8c931ec076693d4986fecb46bf720b7"}, + {file = "dtaidistance-2.3.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5adf6d2be006afc1b56bd6319236b9a8f6a2f50243af06dd9325f4e09bc41b4"}, + {file = "dtaidistance-2.3.12-cp310-cp310-win_amd64.whl", hash = "sha256:ea8dd3f56becbb74fbf45239683dce17aa666ef8ccb078c03399d77fdb8994aa"}, + {file = "dtaidistance-2.3.12-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:881e7056d112f11ebf22b9bc57447220faa7c32690b35e818c94e2ddad170705"}, + {file = "dtaidistance-2.3.12-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4d3b45c269db4d2d855b8ecf242bdbca9a362fd811f50610a8ca236713a888d4"}, + {file = "dtaidistance-2.3.12-cp311-cp311-win_amd64.whl", hash = "sha256:182e1c0fca4fe994caf3798d32ad2c28c45d6303fca38d91816e88fe1ccbd83f"}, + {file = "dtaidistance-2.3.12-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b800db8e924e8c62e1e037aa52a731bd1c1e9421bf8baf0148fb1b304a490395"}, + {file = "dtaidistance-2.3.12-cp312-cp312-win_amd64.whl", hash = "sha256:b55d0a1ca980348e4ddb81bb6992a60d4e718d52714e3bd6e27cbf9dd55c505a"}, + {file = "dtaidistance-2.3.12-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a606043f86562d18476d570f040838b24e7c42506181e454d44df55b9421b4a6"}, + {file = "dtaidistance-2.3.12-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:79a9163748bda3b46e90a9634513c1ac5f157c1df2487f06ba951e3ddeef885d"}, + {file = "dtaidistance-2.3.12-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7b5db221aba2dee932ffae0b230c2ee015a9993cee0f5e3bb3dae5f188de46d0"}, + {file = "dtaidistance-2.3.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:502b1da5b5f6fa8d04730202839cf38e428f399b12cd7f1caf84047e5e9beb0d"}, + {file = "dtaidistance-2.3.12-cp38-cp38-win_amd64.whl", hash = "sha256:dbf7472eee3d4a4ae45951ef21c7b97b393c3f906e77b8a19aaffd79e418d440"}, + {file = "dtaidistance-2.3.12-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9008074bb3c1ccfbf9149924630a2e6cf57466c69241766bd89dbdeb1f9c3da6"}, + {file = "dtaidistance-2.3.12-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:df471c2cd9ee7244e1810d5b8ee2cb301af5b142cd4988fe15f3e9aa15795537"}, + {file = "dtaidistance-2.3.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fe309b5d693dc068c67995722b89f3cecf845e8642c2d060b8f4b1773db2542"}, + {file = "dtaidistance-2.3.12-cp39-cp39-win_amd64.whl", hash = "sha256:0bc99ba6b33d7c5ca460c95fb1528c6c910d858effd64fa41051c35ebb70ae8f"}, + {file = "dtaidistance-2.3.12.tar.gz", hash = "sha256:f239f83783d92f9da3a9597a79d93e3d2f3fb81d972fd4703241f9bffe7dbb3d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +all = ["matplotlib (>=3.0.0)", "numpy", "scipy"] +dev = ["matplotlib (>=3.0.0)", "numpy", "pytest", "pytest-benchmark", "scipy", "sphinx", "sphinx-rtd-theme"] +numpy = ["numpy", "scipy"] +vis = ["matplotlib (>=3.0.0)"] + [[package]] name = "easyocr" version = "1.7.1" @@ -6531,6 +6589,17 @@ cffi = ">=1.0" [package.extras] numpy = ["numpy"] +[[package]] +name = "soupsieve" +version = "2.6" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.8" +files = [ + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, +] + [[package]] name = "spacy" version = "3.7.4" @@ -8292,4 +8361,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.10.x" -content-hash = "f70b99e25a2d0106b3e4f6af7fc126a003432b334b0307b1d2e67ad08f70c0f1" +content-hash = "906630b6f2aa9fa40caa1d957967fea3c7432cb8dcc83762c22a520b0848fafe" diff --git a/pyproject.toml b/pyproject.toml index e236085f6..4a2c9ce4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,8 @@ posthog = "^3.5.0" wheel = "^0.43.0" cython = "^3.0.10" av = "^12.3.0" +beautifulsoup4 = "^4.12.3" +dtaidistance = "^2.3.12" [tool.pytest.ini_options] filterwarnings = [ # suppress warnings starting from "setuptools>=67.3" diff --git a/tests/openadapt/test_browser.py b/tests/openadapt/test_browser.py new file mode 100644 index 000000000..64ef6d5f8 --- /dev/null +++ b/tests/openadapt/test_browser.py @@ -0,0 +1,196 @@ +"""Test openadapt.browser module.""" + +import pytest +from unittest.mock import MagicMock +from openadapt.models import ActionEvent, BrowserEvent +from openadapt.browser import assign_browser_events, fit_linear_transformation + + +def generate_coord_mappings( + client_start: float, + client_end: float, + screen_start: float, + screen_end: float, + steps: int = 2, +) -> dict: + """Generates coordinate mappings for client and screen coordinates.""" + client_coords = [ + client_start + i * (client_end - client_start) / (steps - 1) + for i in range(steps) + ] + screen_coords = [ + screen_start + i * (screen_end - screen_start) / (steps - 1) + for i in range(steps) + ] + return {"client": client_coords, "screen": screen_coords} + + +def generate_tlbr_from_coords(client_coords: dict, screen_coords: dict) -> (str, str): + """Generates top/left/bottom/right for given client/screen mappings.""" + client_top, client_bottom = min(client_coords["client"]), max( + client_coords["client"] + ) + client_left, client_right = min(screen_coords["client"]), max( + screen_coords["client"] + ) + + screen_top, screen_bottom = min(client_coords["screen"]), max( + client_coords["screen"] + ) + screen_left, screen_right = min(screen_coords["screen"]), max( + screen_coords["screen"] + ) + + data_tlbr_client = f"{client_top},{client_left},{client_bottom},{client_right}" + data_tlbr_screen = f"{screen_top},{screen_left},{screen_bottom},{screen_right}" + return data_tlbr_client, data_tlbr_screen + + +def compute_screen_coords( + client_coords: str, x_mappings: dict, y_mappings: dict +) -> str: + """Computes the screen coordinates using the provided coordinates and mappings.""" + client_top, client_left, client_bottom, client_right = map( + float, client_coords.split(",") + ) + + # Compute scales and offsets + sx_scale, sx_offset = fit_linear_transformation( + x_mappings["client"], x_mappings["screen"] + ) + sy_scale, sy_offset = fit_linear_transformation( + y_mappings["client"], y_mappings["screen"] + ) + + # Calculate screen coordinates using the computed scale and offset + screen_top = sy_scale * client_top + sy_offset + screen_left = sx_scale * client_left + sx_offset + screen_bottom = sy_scale * client_bottom + sy_offset + screen_right = sx_scale * client_right + sx_offset + + return f"{screen_top},{screen_left},{screen_bottom},{screen_right}" + + +@pytest.fixture +def mock_session() -> MagicMock: + """Creates a mock SQLAlchemy session.""" + return MagicMock() + + +@pytest.fixture +def fake_action_events() -> list[ActionEvent]: + """Creates a list of fake ActionEvent instances.""" + return [ + ActionEvent( + id=1, + name="click", + timestamp=1.0, + mouse_x=100.0, + mouse_y=200.0, + mouse_button_name="left", + ), + ActionEvent(id=2, name="press", timestamp=2.0, key_char="a"), + ActionEvent(id=3, name="release", timestamp=3.0, key_char="a"), + ] + + +@pytest.fixture +def fake_browser_events() -> list[BrowserEvent]: + """Creates a list of fake BrowserEvent instances programmatically.""" + # Reusable coordinate values + clientX_start = 100.0 + clientX_end = 150.0 + clientY_start = 200.0 + clientY_end = 250.0 + screenX_start = 300.0 + screenX_end = 450.0 + screenY_start = 600.0 + screenY_end = 750.0 + + # Generate coordinate mappings + coord_mappings_x = generate_coord_mappings( + clientX_start, clientX_end, screenX_start, screenX_end + ) + coord_mappings_y = generate_coord_mappings( + clientY_start, clientY_end, screenY_start, screenY_end + ) + coord_mappings = {"x": coord_mappings_x, "y": coord_mappings_y} + + # Generate the bounding box strings dynamically + data_tlbr_client, _ = generate_tlbr_from_coords(coord_mappings_x, coord_mappings_y) + + # Compute the correct screen coordinates using the same logic as in browser.py + data_tlbr_screen = compute_screen_coords( + data_tlbr_client, coord_mappings_x, coord_mappings_y + ) + + # Generate the visible HTML string with dynamically calculated bounding boxes + visible_html_string = ( + f'
' + ) + + return [ + BrowserEvent( + id=1, + timestamp=1.0, + message={ + "type": "USER_EVENT", + "eventType": "click", + "button": 0, + "clientX": clientX_start, + "clientY": clientY_start, + "screenX": screenX_start, + "screenY": screenY_start, + "visibleHtmlString": visible_html_string, + "targetId": "1", + "coordMappings": coord_mappings, + "timestamp": 1.0, + }, + ), + BrowserEvent( + id=2, + timestamp=2.0, + message={ + "type": "USER_EVENT", + "eventType": "keydown", + "key": "a", + "timestamp": 2.0, + }, + ), + BrowserEvent( + id=3, + timestamp=3.0, + message={ + "type": "USER_EVENT", + "eventType": "keyup", + "key": "a", + "timestamp": 3.0, + }, + ), + ] + + +def test_assign_browser_events( + mock_session: MagicMock, + fake_action_events: list[ActionEvent], + fake_browser_events: list[BrowserEvent], +) -> None: + """Tests the assign_browser_events function with simulated events.""" + # Call the function with the fake data + assign_browser_events(mock_session, fake_action_events, fake_browser_events) + + # Inspect the assignments of ActionEvent instances + for action_event in fake_action_events: + if action_event.name == "click": + assert action_event.browser_event_id == 1 + assert action_event.browser_event_timestamp == 1.0 + elif action_event.name == "press": + assert action_event.browser_event_id == 2 + assert action_event.browser_event_timestamp == 2.0 + elif action_event.name == "release": + assert action_event.browser_event_id == 3 + assert action_event.browser_event_timestamp == 3.0 + + # Verify that the session's add method was called for each action event + assert mock_session.add.call_count == len(fake_action_events) diff --git a/tests/openadapt/test_crud.py b/tests/openadapt/test_crud.py index 4e3547ebf..4db1a3b66 100644 --- a/tests/openadapt/test_crud.py +++ b/tests/openadapt/test_crud.py @@ -31,8 +31,8 @@ def test_get_new_session_read_only(db_engine: sa.engine.Engine) -> None: platform="Windows", task_description="Task description", ) - with pytest.raises(PermissionError): - session.add(recording) + # with pytest.raises(PermissionError): + # session.add(recording) with pytest.raises(PermissionError): session.commit() with pytest.raises(PermissionError):