Skip to content

Commit

Permalink
move payload_list up, from Director instance scope to class scope, to…
Browse files Browse the repository at this point in the history
… reduce volume of payload scans
  • Loading branch information
leondz committed Oct 2, 2024
1 parent 40f0a79 commit 66ca95f
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions garak/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class Director:
manage enumeration of payloads (optionally given a payload type specification),
and load them up."""

payload_list = None

def _scan_payload_dir(self, dir) -> dict:
"""Look for .json entries in a dir, load them, check which are
payloads, return name:path dict. optionally filter by type prefixes"""
Expand Down Expand Up @@ -175,26 +177,30 @@ def _scan_payload_dir(self, dir) -> dict:
def _refresh_payloads(self) -> None:
"""Scan resources/payloads and the XDG_DATA_DIR/payloads for
payload objects, and refresh self.payload_list"""
self.payload_list = self._scan_payload_dir(PAYLOAD_DIR)
self.__class__.payload_list = self._scan_payload_dir(PAYLOAD_DIR)

def search(
self, types: Union[List[str], None] = None, include_children=True
) -> Generator[str, None, None]:
"""Return list of payload names, optionally filtered by types"""
for payload in self.payload_list:
for payload in self.__class__.payload_list:
if types is None:
yield payload
else:
if include_children is False:
matches = [
payload_type == type_prefix
for payload_type in self.payload_list[payload]["types"]
for payload_type in self.__class__.payload_list[payload][
"types"
]
for type_prefix in types
]
else:
matches = [
payload_type.startswith(type_prefix)
for payload_type in self.payload_list[payload]["types"]
for payload_type in self.__class__.payload_list[payload][
"types"
]
for type_prefix in types
]
if any(matches):
Expand All @@ -203,7 +209,7 @@ def search(
def load(self, name) -> PayloadGroup:
"""Return a PayloadGroup"""
try:
path = self.payload_list[name]["path"]
path = self.__class__.payload_list[name]["path"]
p = load_payload(name, path) # or raise KeyError

except KeyError as ke:
Expand All @@ -221,5 +227,5 @@ def load(self, name) -> PayloadGroup:
return p

def __init__(self) -> None:
self.payload_list = {} # name: {path:path, types:types}
self._refresh_payloads()
if self.__class__.payload_list is None:
self._refresh_payloads()

0 comments on commit 66ca95f

Please sign in to comment.