Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Dec 15, 2023
1 parent 0f39de8 commit b0c8151
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/helm/benchmark/adaptation/adapters/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import shutil
import tempfile

from helm.common.authentication import Authentication
from helm.proxy.services.service import CACHE_DIR
from helm.proxy.services.server_service import ServerService
from helm.benchmark.window_services.tokenizer_service import TokenizerService

Expand All @@ -13,7 +15,7 @@ class TestAdapter:

def setup_method(self):
self.path: str = tempfile.mkdtemp()
service = ServerService(base_path=self.path, root_mode=True)
service = ServerService(base_path=self.path, root_mode=True, cache_path=os.path.join(self.path, CACHE_DIR))
self.tokenizer_service = TokenizerService(service, Authentication("test"))

def teardown_method(self, _):
Expand Down
4 changes: 3 additions & 1 deletion src/helm/benchmark/window_services/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from typing import List

from helm.common.authentication import Authentication
from helm.proxy.services.server_service import ServerService
from helm.proxy.services.service import CACHE_DIR
from helm.benchmark.metrics.metric_service import MetricService
from .tokenizer_service import TokenizerService

Expand Down Expand Up @@ -228,5 +230,5 @@


def get_tokenizer_service(local_path: str) -> TokenizerService:
service = ServerService(base_path=local_path, root_mode=True)
service = ServerService(base_path=local_path, root_mode=True, cache_path=os.path.join(local_path, CACHE_DIR))
return MetricService(service, Authentication("test"))
9 changes: 7 additions & 2 deletions src/helm/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from helm.common.request import Request
from helm.common.perspective_api_request import PerspectiveAPIRequest
from helm.common.tokenization_request import TokenizationRequest, DecodeRequest
from helm.proxy.services.service import CACHE_DIR
from .accounts import Account
from .services.server_service import ServerService
from .query import Query
Expand Down Expand Up @@ -224,8 +225,12 @@ def main():
default="",
)
args = parser.parse_args()

service = ServerService(base_path=args.base_path, mongo_uri=args.mongo_uri)
cache_path: str
if args.mongo_uri:
cache_path = args.mongo_uri
else:
cache_path = os.path.join(args.base_path, CACHE_DIR)
service = ServerService(base_path=args.base_path, cache_path=cache_path)

gunicorn_args = {
"workers": args.workers,
Expand Down

0 comments on commit b0c8151

Please sign in to comment.