Skip to content

Commit

Permalink
Add uuid to ast dir generation in context retrieval to prevent race c…
Browse files Browse the repository at this point in the history
…ondition (#39)

Add uuid to AST directory generation to prevent race condition when
multiple experiments are run with the same project, causing a potential
race when one experiment is downloading and copying the ast files to the
project ast directory and when the other experiment is iterating over
the files within that ast directory.

Fixes #20
  • Loading branch information
trashvisor committed Jan 31, 2024
1 parent 1b6b583 commit 94ab393
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
16 changes: 11 additions & 5 deletions data_prep/project_context/context_retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
import re
import shutil
import subprocess
import uuid
from collections import defaultdict
from typing import List, Tuple

Expand All @@ -24,10 +26,11 @@ def __init__(self, project_name: str, function_signature: str):
self._project_name = project_name
self._function_signature = function_signature
self._download_from_path = f'{self.AST_BASE_PATH}/{self._project_name}/*'
self._ast_path = f'{self.DOWNLOAD_TO_PATH}/{self._project_name}'
self._uuid = uuid.uuid4()
self._ast_path = f'{self.DOWNLOAD_TO_PATH}/{self._project_name}-{self._uuid}'

def _get_function_name(self, target_function_signature: str) -> str:
"""Retrieve the function name from the target function signature."""
"""Retrieves the function name from the target function signature."""
# Grabs the function name by getting anything before '(' and then remove the type by grabbing any character after space.
target_function = target_function_signature.split('(')[0].split(' ')[-1]
# Removes possible pointer.
Expand Down Expand Up @@ -229,8 +232,7 @@ def _get_header_from_file(self, fully_qualified_path: str) -> str:
Retrieve that node's loc->file to get the file where the FunctionDecl exists."""
target_function = self._get_function_name(self._function_signature)

with open(f'{fully_qualified_path}') as ast_file:
print(f'Opening for...{fully_qualified_path}')
with open(fully_qualified_path) as ast_file:
ast_json = json.load(ast_file)
# AST nodes are all wrapped in an inner node.
ast_nodes = ast_json.get('inner', [])
Expand All @@ -250,7 +252,7 @@ def _get_header_from_file(self, fully_qualified_path: str) -> str:
return ''

def retrieve_asts(self):
"""Downloads ASTS for the given project."""
"""Downloads ASTs for the given project."""
os.makedirs(self._ast_path, exist_ok=True)

download_command = [
Expand All @@ -263,6 +265,10 @@ def retrieve_asts(self):
stderr=subprocess.PIPE,
)

def cleanup_asts(self):
"""Removes ASTs for the given project."""
shutil.rmtree(self._ast_path)

def generate_lookups(self):
"""Goes through all AST files downloaded.
Generates a lookup so that RecordDecl/TypedefDecl/EnumDecl nodes can be found by name."""
Expand Down
1 change: 1 addition & 0 deletions run_one_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def run(benchmark: Benchmark,
print(context_header)
context_types = '\n'.join(retriever.get_type_info())
context_info = (context_header, context_types)
retriever.cleanup_asts()

model.prompt_path = model.prepare_generate_prompt(
work_dirs.prompt,
Expand Down

0 comments on commit 94ab393

Please sign in to comment.