Skip to content

Commit

Permalink
Add --disable-cache flag
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Feb 6, 2024
1 parent 8772e2f commit a6fc79a
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 53 deletions.
4 changes: 3 additions & 1 deletion src/helm/benchmark/adaptation/adapters/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import shutil
import tempfile

from helm.common.authentication import Authentication
from helm.proxy.services.service import CACHE_DIR
from helm.proxy.services.server_service import ServerService
from helm.benchmark.window_services.tokenizer_service import TokenizerService

Expand All @@ -13,7 +15,7 @@ class TestAdapter:

def setup_method(self):
self.path: str = tempfile.mkdtemp()
service = ServerService(base_path=self.path, root_mode=True)
service = ServerService(base_path=self.path, root_mode=True, cache_path=os.path.join(self.path, CACHE_DIR))
self.tokenizer_service = TokenizerService(service, Authentication("test"))

def teardown_method(self, _):
Expand Down
14 changes: 8 additions & 6 deletions src/helm/benchmark/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class ExecutionSpec:
# Whether to skip execution
dry_run: bool = False

# URL to the MongoDB database.
# If non-empty, the MongoDB database will be used for caching instead of SQLite.
# Example format: mongodb://[username:password@]host1[:port1]/[dbname]
# For full format, see: https://www.mongodb.com/docs/manual/reference/connection-string/
mongo_uri: str = ""
# Where to store data for the cache.
# If None, the cache will be disabled.
# If a path to the directory, this specifies the directory in which the SQLite cache will store files.
# If a MongoDB URI starting with , this specifies the MongoDB database to be used by the MongoDB cache.
cache_path: Optional[str] = None


class Executor:
Expand All @@ -58,7 +58,9 @@ def __init__(self, execution_spec: ExecutionSpec):
elif execution_spec.local_path:
hlog(f"Running in local mode with base path: {execution_spec.local_path}")
self.service = ServerService(
base_path=execution_spec.local_path, root_mode=True, mongo_uri=execution_spec.mongo_uri
base_path=execution_spec.local_path,
root_mode=True,
cache_path=execution_spec.cache_path,
)
else:
raise ValueError("Either the proxy server URL or the local path must be set")
Expand Down
22 changes: 19 additions & 3 deletions src/helm/benchmark/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from dataclasses import replace
import os
from typing import List, Optional


Expand All @@ -9,6 +10,7 @@
from helm.common.authentication import Authentication
from helm.common.object_spec import parse_object_spec, get_class_by_name
from helm.proxy.services.remote_service import create_authentication, add_service_args
from helm.proxy.services.service import CACHE_DIR

from helm.benchmark.config_registry import (
register_configs_from_directory,
Expand Down Expand Up @@ -83,7 +85,7 @@ def run_benchmarking(
skip_completed_runs: bool,
exit_on_error: bool,
runner_class_name: Optional[str],
mongo_uri: str = "",
cache_path: Optional[str] = "",
) -> List[RunSpec]:
"""Runs RunSpecs given a list of RunSpec descriptions."""
execution_spec = ExecutionSpec(
Expand All @@ -92,7 +94,7 @@ def run_benchmarking(
local_path=local_path,
parallelism=num_threads,
dry_run=dry_run,
mongo_uri=mongo_uri,
cache_path=cache_path,
)
with htrack_block("run_specs"):
for run_spec in run_specs:
Expand Down Expand Up @@ -171,6 +173,11 @@ def add_run_args(parser: argparse.ArgumentParser):
help="If running locally, the path for `ServerService`.",
default="prod_env",
)
parser.add_argument(
"--disable-cache",
action="store_true",
help="If true, the request-response cache for model clients and tokenizers will be disabled.",
)
parser.add_argument(
"--mongo-uri",
type=str,
Expand Down Expand Up @@ -294,6 +301,15 @@ def main():
Authentication("") if args.skip_instances or not args.server_url else create_authentication(args)
)

cache_path: Optional[str]
if args.disable_cache:
cache_path = None
elif args.mongo_uri:
cache_path = args.mongo_uri
else:
cache_path = os.path.join(args.local_path, CACHE_DIR)
ensure_directory_exists(cache_path)

run_benchmarking(
run_specs=run_specs,
auth=auth,
Expand All @@ -309,7 +325,7 @@ def main():
skip_completed_runs=args.skip_completed_runs,
exit_on_error=args.exit_on_error,
runner_class_name=args.runner_class_name,
mongo_uri=args.mongo_uri,
cache_path=cache_path,
)

if args.local:
Expand Down
4 changes: 2 additions & 2 deletions src/helm/benchmark/test_model_deployment_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class TestModelProperties:
@pytest.mark.parametrize("deployment_name", [deployment.name for deployment in ALL_MODEL_DEPLOYMENTS])
def test_models_has_window_service(self, deployment_name: str):
with TemporaryDirectory() as tmpdir:
auto_client = AutoClient({}, tmpdir, "")
auto_tokenizer = AutoTokenizer({}, tmpdir, "")
auto_client = AutoClient({}, None)
auto_tokenizer = AutoTokenizer({}, None)
tokenizer_service = get_tokenizer_service(tmpdir)

# Loading the TokenizerConfig and ModelMetadat ensures that they are valid.
Expand Down
4 changes: 3 additions & 1 deletion src/helm/benchmark/window_services/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from typing import List

from helm.common.authentication import Authentication
from helm.proxy.services.server_service import ServerService
from helm.proxy.services.service import CACHE_DIR
from helm.benchmark.metrics.metric_service import MetricService
from .tokenizer_service import TokenizerService

Expand Down Expand Up @@ -228,5 +230,5 @@


def get_tokenizer_service(local_path: str) -> TokenizerService:
service = ServerService(base_path=local_path, root_mode=True)
service = ServerService(base_path=local_path, root_mode=True, cache_path=os.path.join(local_path, CACHE_DIR))
return MetricService(service, Authentication("test"))
14 changes: 13 additions & 1 deletion src/helm/common/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sqlite3

from helm.common.general import hlog, htrack
from helm.common.key_value_store import KeyValueStore, SqliteKeyValueStore
from helm.common.key_value_store import BlackHoleKeyValueStore, KeyValueStore, SqliteKeyValueStore
from helm.proxy.retry import get_retry_decorator

try:
Expand Down Expand Up @@ -51,6 +51,16 @@ def cache_stats_key(self) -> str:
return self.path


@dataclass(frozen=True)
class BlackHoleCacheConfig(KeyValueStoreCacheConfig):
"""Configuration for a cache that does not save any data."""

@property
def cache_stats_key(self) -> str:
"""The string key used by CacheStats to identify this cache."""
return "disabled_cache"


@dataclass(frozen=True)
class MongoCacheConfig(KeyValueStoreCacheConfig):
"""Configuration for a cache backed by a MongoDB collection."""
Expand Down Expand Up @@ -113,6 +123,8 @@ def create_key_value_store(config: KeyValueStoreCacheConfig) -> KeyValueStore:
return MongoKeyValueStore(config.uri, config.collection_name)
elif isinstance(config, SqliteCacheConfig):
return SqliteKeyValueStore(config.path)
elif isinstance(config, BlackHoleCacheConfig):
return BlackHoleKeyValueStore()
else:
raise ValueError(f"KeyValueStoreCacheConfig with unknown type: {config}")

Expand Down
17 changes: 9 additions & 8 deletions src/helm/common/cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Functions used for caching."""

import os
from typing import Optional

from helm.common.cache import CacheConfig, MongoCacheConfig, SqliteCacheConfig
from helm.common.cache import BlackHoleCacheConfig, CacheConfig, MongoCacheConfig, SqliteCacheConfig


def build_cache_config(cache_path: str, mongo_uri: str, organization: str) -> CacheConfig:
if mongo_uri:
return MongoCacheConfig(mongo_uri, collection_name=organization)

client_cache_path: str = os.path.join(cache_path, f"{organization}.sqlite")
# TODO: Allow setting CacheConfig.follower_cache_path from a command line flag.
return SqliteCacheConfig(client_cache_path)
def build_cache_config(cache_path: Optional[str], organization: str) -> CacheConfig:
if cache_path is None:
return BlackHoleCacheConfig()
elif cache_path.startswith("mongodb:"):
return MongoCacheConfig(cache_path, collection_name=organization)
else:
return SqliteCacheConfig(os.path.join(cache_path, f"{organization}.sqlite"))
29 changes: 29 additions & 0 deletions src/helm/common/key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,32 @@ def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
def remove(self, key: Dict) -> None:
del self._sqlite_dict[key]
self._sqlite_dict.commit()


class BlackHoleKeyValueStore(KeyValueStore):
"""Key value store that discards all data."""

def __enter__(self) -> "BlackHoleKeyValueStore":
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

def contains(self, key: Dict) -> bool:
return False

def get(self, key: Dict) -> Optional[Dict]:
return None

def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:
return
yield

def put(self, key: Dict, value: Dict) -> None:
return None

def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
return None

def remove(self, key: Dict) -> None:
return None
16 changes: 7 additions & 9 deletions src/helm/proxy/clients/auto_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from dataclasses import replace
import os
from typing import Any, Dict, Mapping, Optional

from retrying import Attempt, RetryError
Expand Down Expand Up @@ -27,15 +27,13 @@ class AuthenticationError(NonRetriableException):
class AutoClient(Client):
"""Automatically dispatch to the proper `Client` based on the model deployment name."""

def __init__(self, credentials: Mapping[str, Any], cache_path: str, mongo_uri: str = ""):
self._auto_tokenizer = AutoTokenizer(credentials, cache_path, mongo_uri)
def __init__(self, credentials: Mapping[str, Any], cache_path: Optional[str]):
self._auto_tokenizer = AutoTokenizer(credentials, cache_path)
self.credentials = credentials
self.cache_path = cache_path
self.mongo_uri = mongo_uri
self.clients: Dict[str, Client] = {}
self._critique_client: Optional[CritiqueClient] = None
hlog(f"AutoClient: cache_path = {cache_path}")
hlog(f"AutoClient: mongo_uri = {mongo_uri}")

def _get_client(self, model_deployment_name: str) -> Client:
"""Return a client based on the model, creating it if necessary."""
Expand All @@ -62,7 +60,7 @@ def _get_client(self, model_deployment_name: str) -> Client:

# Prepare a cache
host_organization: str = model_deployment.host_organization
cache_config: CacheConfig = build_cache_config(self.cache_path, self.mongo_uri, host_organization)
cache_config: CacheConfig = build_cache_config(self.cache_path, host_organization)

client_spec = inject_object_spec_args(
model_deployment.client_spec,
Expand Down Expand Up @@ -144,7 +142,7 @@ def get_toxicity_classifier_client(self) -> ToxicityClassifierClient:
"""Get the toxicity classifier client. We currently only support Perspective API."""
from helm.proxy.clients.perspective_api_client import PerspectiveAPIClient

cache_config: CacheConfig = build_cache_config(self.cache_path, self.mongo_uri, "perspectiveapi")
cache_config: CacheConfig = build_cache_config(self.cache_path, "perspectiveapi")
return PerspectiveAPIClient(self.credentials.get("perspectiveApiKey", ""), cache_config)

def get_moderation_api_client(self):
Expand Down Expand Up @@ -178,7 +176,7 @@ def get_critique_client(self) -> CritiqueClient:
if not surgeai_credentials:
raise ValueError("surgeaiApiKey credentials are required for SurgeAICritiqueClient")
self._critique_client = SurgeAICritiqueClient(
surgeai_credentials, build_cache_config(self.cache_path, self.mongo_uri, "surgeai")
surgeai_credentials, build_cache_config(self.cache_path, "surgeai")
)
elif critique_type == "model":
from helm.proxy.critique.model_critique_client import ModelCritiqueClient
Expand All @@ -198,7 +196,7 @@ def get_critique_client(self) -> CritiqueClient:
if not scale_credentials:
raise ValueError("scaleApiKey is required for ScaleCritiqueClient")
self._critique_client = ScaleCritiqueClient(
scale_credentials, build_cache_config(self.cache_path, self.mongo_uri, "scale"), scale_project
scale_credentials, build_cache_config(self.cache_path, "scale"), scale_project
)
else:
raise ValueError(
Expand Down
19 changes: 8 additions & 11 deletions src/helm/proxy/clients/microsoft_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Optional, Dict

from filelock import FileLock
from openai.api_resources.abstract import engine_api_resource
import openai as turing
from threading import Lock

from helm.common.cache import CacheConfig
from helm.common.request import (
Expand All @@ -17,6 +17,12 @@
from .openai_client import ORIGINAL_COMPLETION_ATTRIBUTES


# The Microsoft Turing server only allows a single request at a time, so acquire a
# thread-safe lock before making a request.
# https://github.com/microsoft/turing-academic-TNLG#rate-limitations
_LOCK = Lock()


class MicrosoftClient(CachingClient):
"""
Client for the Microsoft's Megatron-Turing NLG models (https://arxiv.org/abs/2201.11990).
Expand Down Expand Up @@ -45,7 +51,6 @@ def convert_to_raw_request(request: Request) -> Dict:

def __init__(
self,
lock_file_path: str,
cache_config: CacheConfig,
api_key: Optional[str] = None,
org_id: Optional[str] = None,
Expand All @@ -65,14 +70,6 @@ def class_url(
self.api_base: str = "https://turingnlg-turingnlg-mstap-v2.turingase.p.azurewebsites.net"
self.completion_attributes = (EngineAPIResource,) + ORIGINAL_COMPLETION_ATTRIBUTES[1:]

# The Microsoft Turing server only allows a single request at a time, so acquire a
# process-safe lock before making a request.
# https://github.com/microsoft/turing-academic-TNLG#rate-limitations
#
# Since the model will generate roughly three tokens per second and the max context window
# is 2048 tokens, we expect the maximum time for a request to be fulfilled to be 700 seconds.
self._lock = FileLock(lock_file_path, timeout=700)

def make_request(self, request: Request) -> RequestResult:
"""
Make a request for the Microsoft MT-NLG models.
Expand Down Expand Up @@ -110,7 +107,7 @@ def make_request(self, request: Request) -> RequestResult:
try:

def do_it():
with self._lock:
with _LOCK:
# Following https://beta.openai.com/docs/api-reference/authentication
# `organization` can be set to None.
turing.organization = self.org_id
Expand Down
10 changes: 9 additions & 1 deletion src/helm/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
register_builtin_configs_from_helm_package,
)
from helm.common.authentication import Authentication
from helm.common.general import ensure_directory_exists
from helm.common.hierarchical_logger import hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import Request
from helm.common.perspective_api_request import PerspectiveAPIRequest
from helm.common.moderations_api_request import ModerationAPIRequest
from helm.common.tokenization_request import TokenizationRequest, DecodeRequest
from helm.proxy.services.service import CACHE_DIR
from .accounts import Account
from .services.server_service import ServerService
from .query import Query
Expand Down Expand Up @@ -258,7 +260,13 @@ def main():
register_builtin_configs_from_helm_package()
register_configs_from_directory(args.base_path)

service = ServerService(base_path=args.base_path, mongo_uri=args.mongo_uri)
cache_path: str
if args.mongo_uri:
cache_path = args.mongo_uri
else:
cache_path = os.path.join(args.base_path, CACHE_DIR)
ensure_directory_exists(cache_path)
service = ServerService(base_path=args.base_path, cache_path=cache_path)

gunicorn_args = {
"workers": args.workers,
Expand Down
Loading

0 comments on commit a6fc79a

Please sign in to comment.